Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/cli/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,10 @@ def _run_list_command(service_client: client.ServiceClient, args: argparse.Names
Args:
args: Parsed command line arguments.
"""
params = {'name': args.name, 'user': args.user, 'all_users': args.all, 'buckets': args.bucket,
params = {'user': args.user, 'all_users': args.all, 'buckets': args.bucket,
'count': args.count, 'order': args.order.upper()}
if args.name:
params['name'] = args.name
result = service_client.request(
client.RequestMethod.GET,
'api/bucket/list_dataset',
Expand Down
10 changes: 10 additions & 0 deletions src/lib/data/storage/backends/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,16 @@ def data_auth(

def _validate_auth():
arn = sts_client.get_caller_identity()['Arn']
# SimulatePrincipalPolicy requires an IAM role ARN, not an STS
# assumed-role session ARN. Convert when running under IRSA or
# instance roles where STS returns an assumed-role ARN.
# arn:aws:sts::<acct>:assumed-role/<role>/<session>
# -> arn:aws:iam::<acct>:role/<role>
if ':assumed-role/' in arn:
parts = arn.split(':')
account_id = parts[4]
role_name = parts[5].split('/')[1]
arn = f'arn:aws:iam::{account_id}:role/{role_name}'
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
path = f'{self.container}/{self.path if self.path else "*"}'

if path.endswith('/'):
Expand Down
47 changes: 47 additions & 0 deletions src/lib/data/storage/backends/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,53 @@ def test_workflow_config_with_null_credential(self):
# Assert
self.assertIsNone(config.workflow_data.credential)

def test_workflow_config_data_with_default_credential(self):
    """Verify a DefaultDataCredential survives a round trip through DataConfig.

    The credential is placed in ``WorkflowConfig.workflow_data`` and read
    back; its concrete type, endpoint and region must be unchanged.
    """
    expected_endpoint = 's3://bucket.io/workflows'
    expected_region = 'us-west-2'
    ambient_cred = credentials.DefaultDataCredential(
        endpoint=expected_endpoint, region=expected_region)
    data_config = postgres.DataConfig(credential=ambient_cred)

    workflow_config = postgres.WorkflowConfig(workflow_data=data_config)

    stored = workflow_config.workflow_data.credential
    self.assertIsInstance(stored, credentials.DefaultDataCredential)
    # Plain assert repeats the check so static type checkers narrow `stored`.
    assert isinstance(stored, credentials.DefaultDataCredential)
    self.assertEqual(stored.endpoint, expected_endpoint)
    self.assertEqual(stored.region, expected_region)

def test_workflow_config_log_with_default_credential(self):
    """Verify LogConfig keeps a DefaultDataCredential as a DefaultDataCredential."""
    ambient_cred = credentials.DefaultDataCredential(
        endpoint='s3://log-bucket.io/logs', region='us-east-1')
    log_config = postgres.LogConfig(credential=ambient_cred)

    workflow_config = postgres.WorkflowConfig(workflow_log=log_config)

    stored = workflow_config.workflow_log.credential
    self.assertIsInstance(stored, credentials.DefaultDataCredential)

def test_default_credential_to_decrypted_dict_no_keys(self):
    """Verify to_decrypted_dict() on a DefaultDataCredential exposes no static keys.

    The endpoint and region must be present in the resulting dict, while
    neither ``access_key_id`` nor ``access_key`` may appear.
    """
    expected_endpoint = 's3://bucket.io/data'
    expected_region = 'us-west-2'
    ambient_cred = credentials.DefaultDataCredential(
        endpoint=expected_endpoint, region=expected_region)

    decrypted = ambient_cred.to_decrypted_dict()

    self.assertEqual(decrypted['endpoint'], expected_endpoint)
    self.assertEqual(decrypted['region'], expected_region)
    for key_field in ('access_key_id', 'access_key'):
        self.assertNotIn(key_field, decrypted)


if __name__ == '__main__':
unittest.main()
13 changes: 11 additions & 2 deletions src/runtime/pkg/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,17 @@ func MountURL(downloadType string, credentialInfo ConfigInfo, urlPath string,
osmoChan <- fmt.Sprintf("Missing data credential for %s.", storageBackend.GetProfile())
return isEmpty
}
os.Setenv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId)
os.Setenv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey)
// Only set static key env vars when keys are provided.
// When using DefaultDataCredential (ambient credentials via Pod Identity,
// IRSA, etc.), keys are empty — setting empty env vars would clobber the
// SDK's default credential chain.
if dataCredential.AccessKeyId != "" {
os.Setenv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId)
os.Setenv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey)
}
if dataCredential.Region != "" {
os.Setenv("AWS_REGION", dataCredential.Region)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

var commandArgs []string

Expand Down
3 changes: 1 addition & 2 deletions src/service/core/data/data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,7 @@ def get_bucket_info(default_only: bool = False,
path=bucket_info.dataset_path,
description=bucket_info.description,
mode=bucket_info.mode,
default_cred=bucket_info.default_credential is not None\
and bucket_info.default_credential.access_key_id != '')\
default_cred=bucket_info.default_credential is not None)\
for bucket_name, bucket_info in dataset_configs.buckets.items()
}

Expand Down
50 changes: 30 additions & 20 deletions src/utils/connectors/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,7 @@ def func(new_encrypted: str):
self.execute_commit_command(cmd, (cmd_args[0], new_encrypted) + cmd_args[1:])
return func

def get_data_cred(self, user: str, profile: str) -> credentials.StaticDataCredential | None:
def get_data_cred(self, user: str, profile: str) -> credentials.DataCredential | None:
""" Fetch data credentials by profile. """
select_data_cmd = PostgresSelectCommand(
table='credential',
Expand All @@ -1522,26 +1522,20 @@ def get_data_cred(self, user: str, profile: str) -> credentials.StaticDataCreden
bucket_info = storage.construct_storage_backend(bucket.dataset_path)
if bucket_info.profile == profile:
if bucket.default_credential:
return credentials.StaticDataCredential(
region=bucket.region,
access_key_id=bucket.default_credential.access_key_id,
access_key=bucket.default_credential.access_key,
endpoint=bucket_info.profile,
override_url=bucket.default_credential.override_url,
)
return _resolve_bucket_credential(bucket, bucket_info.profile)
break

return None

def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCredential]:
def get_all_data_creds(self, user: str) -> Dict[str, credentials.DataCredential]:
""" Fetch all data credentials for user. """
select_data_cmd = PostgresSelectCommand(
table='credential',
conditions=['user_name = %s', 'cred_type = %s'],
condition_args=[user, CredentialType.DATA.value])
rows = self.execute_fetch_command(*select_data_cmd.get_args())

user_creds = {
user_creds: Dict[str, credentials.DataCredential] = {
cred.profile: credentials.StaticDataCredential(
endpoint=cred.profile,
**self.decrypt_credential(cred),
Expand All @@ -1553,13 +1547,8 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCrede
for bucket in self.get_dataset_configs().buckets.values():
bucket_info = storage.construct_storage_backend(bucket.dataset_path)
if bucket_info.profile not in user_creds and bucket.default_credential:
user_creds[bucket_info.profile] = credentials.StaticDataCredential(
region=bucket.region,
access_key_id=bucket.default_credential.access_key_id,
access_key=bucket.default_credential.access_key,
endpoint=bucket_info.profile,
override_url=bucket.default_credential.override_url,
)
user_creds[bucket_info.profile] = _resolve_bucket_credential(
bucket, bucket_info.profile)
return user_creds

def get_generic_cred(self, user: str, cred_name: str) -> Any:
Expand Down Expand Up @@ -2686,7 +2675,7 @@ def construct_path(endpoint: str, bucket: str, path: str):

class LogConfig(ExtraArgBaseModel):
""" Config for storing information about data. """
credential: credentials.StaticDataCredential | None = None
credential: credentials.DataCredential | None = None


class WorkflowInfo(ExtraArgBaseModel):
Expand All @@ -2703,7 +2692,7 @@ def validate_name(self, name: str):

class DataConfig(ExtraArgBaseModel):
""" Config for storing information about data. """
credential: credentials.StaticDataCredential | None = None
credential: credentials.DataCredential | None = None

base_url: str = ''
# Timeout in mins for osmo-ctrl to retry connecting to the OSMO service until exiting the task
Expand All @@ -2727,6 +2716,27 @@ class BucketMode(enum.Enum):
READ_WRITE = 'read-write'


def _resolve_bucket_credential(
    bucket: 'BucketConfig',
    profile: str,
) -> credentials.DataCredential:
    """Resolve a bucket's default_credential into the appropriate DataCredential type.

    The returned credential is rebuilt rather than passed through so that its
    ``endpoint`` is the given profile and its ``region`` is the bucket's region.

    Args:
        bucket: Bucket config whose ``default_credential`` and ``region`` are read.
        profile: Storage profile, used as the ``endpoint`` of the result.

    Returns:
        A ``StaticDataCredential`` carrying the access key pair when the
        bucket's default credential is static; otherwise a
        ``DefaultDataCredential`` with no keys.

    Raises:
        ValueError: If the bucket has no default credential configured.
    """
    credential = bucket.default_credential
    # default_credential is optional on BucketConfig; fail with a clear message
    # instead of an AttributeError on `None.override_url` further down.
    if credential is None:
        raise ValueError(
            f'Bucket has no default credential to resolve for profile {profile}.')

    if isinstance(credential, credentials.StaticDataCredential):
        return credentials.StaticDataCredential(
            region=bucket.region,
            access_key_id=credential.access_key_id,
            access_key=credential.access_key,
            endpoint=profile,
            override_url=credential.override_url,
        )

    # Any non-static credential resolves to the key-less default credential.
    return credentials.DefaultDataCredential(
        endpoint=profile,
        region=bucket.region,
        override_url=credential.override_url,
    )


class BucketConfig(ExtraArgBaseModel):
"""
Class to store the name of the bucket and the dataset path
Expand All @@ -2739,7 +2749,7 @@ class BucketConfig(ExtraArgBaseModel):
# Default cred to use when the user doesn't have one of their own
# Only applies to workflow operations, NOT user cli since we cannot forward the credential
# to the user
default_credential: credentials.StaticDataCredential | None = None
default_credential: credentials.DataCredential | None = None

def valid_access(self, bucket_name: str, access_type: BucketModeAccess):
if not ((access_type == BucketModeAccess.READ and\
Expand Down
9 changes: 5 additions & 4 deletions src/utils/job/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import re
import secrets
import time
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from urllib.parse import urlencode

Expand Down Expand Up @@ -102,7 +103,7 @@ def create_login_dict(user: str,


def create_config_dict(
data_info: dict[str, credentials.StaticDataCredential],
data_info: Mapping[str, credentials.DataCredential],
) -> dict:
'''
Creates the config dict where the input should be a dict containing key values like:
Expand Down Expand Up @@ -2699,7 +2700,7 @@ def convert_to_pod_spec(
service_config: connectors.ServiceConfig | None = None,
dataset_config: connectors.DatasetConfig | None = None,
pool_info: connectors.Pool | None = None,
data_endpoints: Dict[str, credentials.StaticDataCredential] | None = None,
data_endpoints: Mapping[str, credentials.DataCredential] | None = None,
skip_refresh_token: bool = False,
auth_token: str | None = None,
) -> Tuple[Dict, Dict[str, kb_objects.FileMount], Optional[Tuple[str, str]]]:
Expand Down Expand Up @@ -3193,10 +3194,10 @@ def decode_hstore(tasks: str) -> Set[str]:

def fetch_creds(
user: str,
data_creds: dict[str, credentials.StaticDataCredential],
data_creds: Mapping[str, credentials.DataCredential],
path: str,
disabled_data: list[str] | None = None,
) -> credentials.StaticDataCredential | None:
) -> credentials.DataCredential | None:
backend_info = storage.construct_storage_backend(path)

if backend_info.profile not in data_creds:
Expand Down
1 change: 1 addition & 0 deletions src/utils/job/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ py_test(
tags = ["manual"],
deps = [
"//src/lib/utils:common",
"//src/lib/utils:credentials",
"//src/utils/connectors",
"//src/utils/job",
]
Expand Down
64 changes: 63 additions & 1 deletion src/utils/job/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from unittest import mock
import unittest

from src.lib.utils import common
from src.lib.utils import common, credentials
from src.utils.job import task, kb_objects
from src.utils import connectors

Expand Down Expand Up @@ -869,5 +869,67 @@ def test_file_mount_credential_not_in_secrets(self):
self.assertEqual(len(user_secret_call), 0)


class CreateConfigDictTest(unittest.TestCase):
    """Exercise create_config_dict with static, default and mixed credentials."""

    def test_static_credential(self):
        """A StaticDataCredential is emitted with its keys, endpoint and region."""
        endpoint = 's3://my-bucket'
        key_id = 'AKIAIOSFODNN7EXAMPLE'
        secret_key = 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY'
        region = 'us-east-1'
        cred = credentials.StaticDataCredential(
            endpoint=endpoint,
            access_key_id=key_id,
            access_key=secret_key,
            region=region,
        )

        config = task.create_config_dict({endpoint: cred})

        entry = config['auth']['data'][endpoint]
        self.assertEqual(entry['access_key_id'], key_id)
        self.assertEqual(entry['access_key'], secret_key)
        self.assertEqual(entry['endpoint'], endpoint)
        self.assertEqual(entry['region'], region)

    def test_default_credential(self):
        """A DefaultDataCredential is emitted with endpoint and region but no keys."""
        endpoint = 's3://ambient-bucket'
        region = 'us-west-2'
        cred = credentials.DefaultDataCredential(endpoint=endpoint, region=region)

        config = task.create_config_dict({endpoint: cred})

        entry = config['auth']['data'][endpoint]
        self.assertEqual(entry['endpoint'], endpoint)
        self.assertEqual(entry['region'], region)
        for key_field in ('access_key_id', 'access_key'):
            self.assertNotIn(key_field, entry)

    def test_mixed_credentials(self):
        """Static and default credentials in one call each keep their own shape."""
        static_endpoint = 's3://static-bucket'
        ambient_endpoint = 's3://ambient-bucket'
        creds_by_endpoint = {
            static_endpoint: credentials.StaticDataCredential(
                endpoint=static_endpoint,
                access_key_id='AKIAIOSFODNN7EXAMPLE',
                access_key='secret',
            ),
            ambient_endpoint: credentials.DefaultDataCredential(
                endpoint=ambient_endpoint,
                region='eu-west-1',
            ),
        }

        data_section = task.create_config_dict(creds_by_endpoint)['auth']['data']

        # The static entry must expose both halves of the key pair.
        static_entry = data_section[static_endpoint]
        for key_field in ('access_key_id', 'access_key'):
            self.assertIn(key_field, static_entry)

        # The ambient entry must expose neither, but still carry its region.
        ambient_entry = data_section[ambient_endpoint]
        for key_field in ('access_key_id', 'access_key'):
            self.assertNotIn(key_field, ambient_entry)
        self.assertEqual(ambient_entry['region'], 'eu-west-1')


if __name__ == '__main__':
unittest.main()
Loading