diff --git a/src/cli/dataset.py b/src/cli/dataset.py index 0a2228ecc..21081d74f 100644 --- a/src/cli/dataset.py +++ b/src/cli/dataset.py @@ -483,8 +483,13 @@ def _run_list_command(service_client: client.ServiceClient, args: argparse.Names Args: args: Parsed command line arguments. """ - params = {'name': args.name, 'user': args.user, 'all_users': args.all, 'buckets': args.bucket, - 'count': args.count, 'order': args.order.upper()} + params = {'all_users': args.all, 'count': args.count, 'order': args.order.upper()} + if args.name: + params['name'] = args.name + if args.user: + params['user'] = args.user + if args.bucket: + params['buckets'] = args.bucket result = service_client.request( client.RequestMethod.GET, 'api/bucket/list_dataset', diff --git a/src/lib/data/storage/backends/backends.py b/src/lib/data/storage/backends/backends.py index f48c6b3ae..810e533a4 100644 --- a/src/lib/data/storage/backends/backends.py +++ b/src/lib/data/storage/backends/backends.py @@ -485,6 +485,16 @@ def data_auth( self._validate_bucket_access(data_cred=data_cred) return + # Ambient credentials (IRSA, EC2 instance profile, EKS pod identity, etc.) rely on + # resource-based policies (bucket policies) in addition to identity-based IAM policies. + # SimulatePrincipalPolicy only evaluates identity policies, so it produces false + # negatives when access is granted exclusively via a bucket policy. For ambient + # credentials we fall back to a real head_bucket check, which is the true arbiter + # of whether the caller can reach the bucket. 
+ if isinstance(data_cred, credentials.DefaultDataCredential): + self._validate_bucket_access(data_cred=data_cred) + return + action = [] if access_type == common.AccessType.READ: action.append('s3:GetObject') @@ -493,25 +503,28 @@ def data_auth( elif access_type == common.AccessType.DELETE: action.append('s3:DeleteObject') - match data_cred: - case credentials.StaticDataCredential(): - session = boto3.Session( - aws_access_key_id=data_cred.access_key_id, - aws_secret_access_key=data_cred.access_key.get_secret_value(), - region_name=self.region(data_cred), - ) - case credentials.DefaultDataCredential(): - session = boto3.Session( - region_name=self.region(data_cred), - ) - case _ as unreachable: - assert_never(unreachable) + session = boto3.Session( + aws_access_key_id=data_cred.access_key_id, + aws_secret_access_key=data_cred.access_key.get_secret_value(), + region_name=self.region(data_cred), + ) iam_client: mypy_boto3_iam.client.IAMClient = session.client('iam') sts_client: mypy_boto3_sts.client.STSClient = session.client('sts') def _validate_auth(): arn = sts_client.get_caller_identity()['Arn'] + # SimulatePrincipalPolicy requires an IAM role ARN, not an STS + # assumed-role session ARN. Convert when running under IRSA or + # instance roles where STS returns an assumed-role ARN. 
+ # arn:{partition}:sts::{account_id}:assumed-role/{role_path}/{session_name} + # -> arn:{partition}:iam::{account_id}:role/{role_path} + if ':assumed-role/' in arn: + parts = arn.split(':') + partition = parts[1] + account_id = parts[4] + role_path = parts[5].split('assumed-role/', 1)[1].rsplit('/', 1)[0] + arn = f'arn:{partition}:iam::{account_id}:role/{role_path}' path = f'{self.container}/{self.path if self.path else "*"}' if path.endswith('/'): diff --git a/src/lib/data/storage/backends/tests/test_backends.py b/src/lib/data/storage/backends/tests/test_backends.py index 2597779c4..1441dffc2 100644 --- a/src/lib/data/storage/backends/tests/test_backends.py +++ b/src/lib/data/storage/backends/tests/test_backends.py @@ -534,6 +534,53 @@ def test_workflow_config_with_null_credential(self): # Assert self.assertIsNone(config.workflow_data.credential) + def test_workflow_config_data_with_default_credential(self): + """Test WorkflowConfig DataConfig accepts DefaultDataCredential.""" + default_cred = credentials.DefaultDataCredential( + endpoint='s3://bucket.io/workflows', + region='us-west-2', + ) + + config = postgres.WorkflowConfig( + workflow_data=postgres.DataConfig(credential=default_cred), + ) + + credential = config.workflow_data.credential + self.assertIsInstance(credential, credentials.DefaultDataCredential) + assert isinstance(credential, credentials.DefaultDataCredential) + self.assertEqual(credential.endpoint, 's3://bucket.io/workflows') + self.assertEqual(credential.region, 'us-west-2') + + def test_workflow_config_log_with_default_credential(self): + """Test WorkflowConfig LogConfig accepts DefaultDataCredential.""" + default_cred = credentials.DefaultDataCredential( + endpoint='s3://log-bucket.io/logs', + region='us-east-1', + ) + + config = postgres.WorkflowConfig( + workflow_log=postgres.LogConfig(credential=default_cred), + ) + + self.assertIsInstance( + config.workflow_log.credential, + credentials.DefaultDataCredential, + ) + + def test_default_credential_to_decrypted_dict_no_keys(self): + """Test 
DefaultDataCredential.to_decrypted_dict has no access keys.""" + default_cred = credentials.DefaultDataCredential( + endpoint='s3://bucket.io/data', + region='us-west-2', + ) + + result = default_cred.to_decrypted_dict() + + self.assertEqual(result['endpoint'], 's3://bucket.io/data') + self.assertEqual(result['region'], 'us-west-2') + self.assertNotIn('access_key_id', result) + self.assertNotIn('access_key', result) + if __name__ == '__main__': unittest.main() diff --git a/src/runtime/pkg/data/data.go b/src/runtime/pkg/data/data.go index 22a6f0e95..18f8cdfed 100644 --- a/src/runtime/pkg/data/data.go +++ b/src/runtime/pkg/data/data.go @@ -62,6 +62,14 @@ const ( DatasetOperation string = "Dataset" ) +func setOrUnsetEnv(key, value string) { + if value != "" { + os.Setenv(key, value) + } else { + os.Unsetenv(key) + } +} + type VersionInfo struct { Size int Checksum string @@ -410,12 +418,18 @@ func MountURL(downloadType string, credentialInfo ConfigInfo, urlPath string, storageBackend := ParseStorageBackend(urlPath) dataCredential, ok := credentialInfo.Auth.Data[storageBackend.GetProfile()] - if !ok { - osmoChan <- fmt.Sprintf("Missing data credential for %s.", storageBackend.GetProfile()) - return isEmpty + if ok { + setOrUnsetEnv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId) + setOrUnsetEnv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey) + setOrUnsetEnv("AWS_REGION", dataCredential.Region) + } else { + // No explicit credential — clear any stale values and let the + // SDK resolve ambient credentials (IRSA, pod identity, etc.). 
+ os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + os.Unsetenv("AWS_REGION") } - os.Setenv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId) - os.Setenv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey) + os.Unsetenv("AWS_SESSION_TOKEN") var commandArgs []string diff --git a/src/service/core/data/data_service.py b/src/service/core/data/data_service.py index b210c563f..d571a2d8d 100755 --- a/src/service/core/data/data_service.py +++ b/src/service/core/data/data_service.py @@ -467,8 +467,7 @@ def get_bucket_info(default_only: bool = False, path=bucket_info.dataset_path, description=bucket_info.description, mode=bucket_info.mode, - default_cred=bucket_info.default_credential is not None\ - and bucket_info.default_credential.access_key_id != '')\ + default_cred=bucket_info.default_credential is not None)\ for bucket_name, bucket_info in dataset_configs.buckets.items() } diff --git a/src/utils/connectors/postgres.py b/src/utils/connectors/postgres.py index 77f479131..f34fde95b 100644 --- a/src/utils/connectors/postgres.py +++ b/src/utils/connectors/postgres.py @@ -1504,7 +1504,7 @@ def func(new_encrypted: str): self.execute_commit_command(cmd, (cmd_args[0], new_encrypted) + cmd_args[1:]) return func - def get_data_cred(self, user: str, profile: str) -> credentials.StaticDataCredential | None: + def get_data_cred(self, user: str, profile: str) -> credentials.DataCredential | None: """ Fetch data credentials by profile. 
""" select_data_cmd = PostgresSelectCommand( table='credential', @@ -1522,13 +1522,7 @@ def get_data_cred(self, user: str, profile: str) -> credentials.StaticDataCreden bucket_info = storage.construct_storage_backend(bucket.dataset_path) if bucket_info.profile == profile: if bucket.default_credential: - return credentials.StaticDataCredential( - region=bucket.region, - access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.access_key, - endpoint=bucket_info.profile, - override_url=bucket.default_credential.override_url, - ) + return _resolve_bucket_credential(bucket, bucket_info.profile) break return None @@ -1541,7 +1535,7 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCrede condition_args=[user, CredentialType.DATA.value]) rows = self.execute_fetch_command(*select_data_cmd.get_args()) - user_creds = { + user_creds: Dict[str, credentials.StaticDataCredential] = { cred.profile: credentials.StaticDataCredential( endpoint=cred.profile, **self.decrypt_credential(cred), @@ -1553,13 +1547,8 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCrede for bucket in self.get_dataset_configs().buckets.values(): bucket_info = storage.construct_storage_backend(bucket.dataset_path) if bucket_info.profile not in user_creds and bucket.default_credential: - user_creds[bucket_info.profile] = credentials.StaticDataCredential( - region=bucket.region, - access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.access_key, - endpoint=bucket_info.profile, - override_url=bucket.default_credential.override_url, - ) + user_creds[bucket_info.profile] = _resolve_bucket_credential( + bucket, bucket_info.profile) return user_creds def get_generic_cred(self, user: str, cred_name: str) -> Any: @@ -2686,7 +2675,7 @@ def construct_path(endpoint: str, bucket: str, path: str): class LogConfig(ExtraArgBaseModel): """ Config for storing information about data. 
""" - credential: credentials.StaticDataCredential | None = None + credential: credentials.DataCredential | None = None class WorkflowInfo(ExtraArgBaseModel): @@ -2703,7 +2692,7 @@ def validate_name(self, name: str): class DataConfig(ExtraArgBaseModel): """ Config for storing information about data. """ - credential: credentials.StaticDataCredential | None = None + credential: credentials.DataCredential | None = None base_url: str = '' # Timeout in mins for osmo-ctrl to retry connecting to the OSMO service until exiting the task @@ -2727,6 +2716,23 @@ class BucketMode(enum.Enum): READ_WRITE = 'read-write' +def _resolve_bucket_credential( + bucket: 'BucketConfig', + profile: str, +) -> credentials.StaticDataCredential: + """Resolve a bucket's default_credential, rebinding it to the bucket's profile and region.""" + credential = bucket.default_credential + if credential is None: + raise ValueError(f'No default credential configured for bucket with profile {profile}') + return credentials.StaticDataCredential( + endpoint=profile, + region=bucket.region, + access_key_id=credential.access_key_id, + access_key=credential.access_key, + override_url=credential.override_url, + ) + + class BucketConfig(ExtraArgBaseModel): """ Class to store the name of the bucket and the dataset path @@ -2736,7 +2742,7 @@ class BucketConfig(ExtraArgBaseModel): description: str = '' # Mode for read-only or read-write or write-only mode: str = BucketMode.READ_WRITE.value - # Default cred to use doesn't have one + # Default cred to use is a static credential # Only applies to workflow operations, NOT user cli since we cannot forward the credential # to the user default_credential: credentials.StaticDataCredential | None = None diff --git a/src/utils/job/task.py b/src/utils/job/task.py index 48dc11881..b1ab835fd 100644 --- a/src/utils/job/task.py +++ b/src/utils/job/task.py @@ -26,6 +26,7 @@ import re import secrets import time +from collections.abc import Mapping from typing import Any, Dict, 
List, Optional, Set, Tuple, Union from urllib.parse import urlencode @@ -102,7 +103,7 @@ def create_login_dict(user: str, def create_config_dict( - data_info: dict[str, credentials.StaticDataCredential], + data_info: Mapping[str, credentials.StaticDataCredential | credentials.DefaultDataCredential], ) -> dict: ''' Creates the config dict where the input should be a dict containing key values like: @@ -2699,7 +2700,7 @@ def convert_to_pod_spec( service_config: connectors.ServiceConfig | None = None, dataset_config: connectors.DatasetConfig | None = None, pool_info: connectors.Pool | None = None, - data_endpoints: Dict[str, credentials.StaticDataCredential] | None = None, + data_endpoints: Mapping[str, credentials.StaticDataCredential] | None = None, skip_refresh_token: bool = False, auth_token: str | None = None, ) -> Tuple[Dict, Dict[str, kb_objects.FileMount], Optional[Tuple[str, str]]]: @@ -3193,13 +3194,15 @@ def decode_hstore(tasks: str) -> Set[str]: def fetch_creds( user: str, - data_creds: dict[str, credentials.StaticDataCredential], + data_creds: Mapping[str, credentials.StaticDataCredential], path: str, disabled_data: list[str] | None = None, ) -> credentials.StaticDataCredential | None: backend_info = storage.construct_storage_backend(path) if backend_info.profile not in data_creds: + if backend_info.supports_environment_auth: + return None if not disabled_data or backend_info.scheme not in disabled_data: raise osmo_errors.OSMOCredentialError( f'Could not find {backend_info.profile} credential for user {user}.') diff --git a/src/utils/job/tests/BUILD b/src/utils/job/tests/BUILD index df70621ed..f333ef74b 100644 --- a/src/utils/job/tests/BUILD +++ b/src/utils/job/tests/BUILD @@ -67,6 +67,7 @@ py_test( tags = ["manual"], deps = [ "//src/lib/utils:common", + "//src/lib/utils:credentials", "//src/utils/connectors", "//src/utils/job", ] diff --git a/src/utils/job/tests/test_task.py b/src/utils/job/tests/test_task.py index 63de6cdcc..6fb81779b 100644 --- 
a/src/utils/job/tests/test_task.py +++ b/src/utils/job/tests/test_task.py @@ -21,7 +21,7 @@ from unittest import mock import unittest -from src.lib.utils import common +from src.lib.utils import common, credentials from src.utils.job import task, kb_objects from src.utils import connectors @@ -869,5 +869,67 @@ def test_file_mount_credential_not_in_secrets(self): self.assertEqual(len(user_secret_call), 0) +class CreateConfigDictTest(unittest.TestCase): + """Tests for create_config_dict with different credential types.""" + + def test_static_credential(self): + """Test create_config_dict with StaticDataCredential.""" + static_cred = credentials.StaticDataCredential( + endpoint='s3://my-bucket', + access_key_id='AKIAIOSFODNN7EXAMPLE', + access_key='wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', + region='us-east-1', + ) + + result = task.create_config_dict({'s3://my-bucket': static_cred}) + + data_entry = result['auth']['data']['s3://my-bucket'] + self.assertEqual(data_entry['access_key_id'], 'AKIAIOSFODNN7EXAMPLE') + self.assertEqual(data_entry['access_key'], 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY') + self.assertEqual(data_entry['endpoint'], 's3://my-bucket') + self.assertEqual(data_entry['region'], 'us-east-1') + + def test_default_credential(self): + """Test create_config_dict with DefaultDataCredential produces no access keys.""" + default_cred = credentials.DefaultDataCredential( + endpoint='s3://ambient-bucket', + region='us-west-2', + ) + + result = task.create_config_dict({'s3://ambient-bucket': default_cred}) + + data_entry = result['auth']['data']['s3://ambient-bucket'] + self.assertEqual(data_entry['endpoint'], 's3://ambient-bucket') + self.assertEqual(data_entry['region'], 'us-west-2') + self.assertNotIn('access_key_id', data_entry) + self.assertNotIn('access_key', data_entry) + + def test_mixed_credentials(self): + """Test create_config_dict with both credential types.""" + static_cred = credentials.StaticDataCredential( + endpoint='s3://static-bucket', + 
access_key_id='AKIAIOSFODNN7EXAMPLE', + access_key='secret', + ) + default_cred = credentials.DefaultDataCredential( + endpoint='s3://ambient-bucket', + region='eu-west-1', + ) + + result = task.create_config_dict({ + 's3://static-bucket': static_cred, + 's3://ambient-bucket': default_cred, + }) + + static_entry = result['auth']['data']['s3://static-bucket'] + self.assertIn('access_key_id', static_entry) + self.assertIn('access_key', static_entry) + + ambient_entry = result['auth']['data']['s3://ambient-bucket'] + self.assertNotIn('access_key_id', ambient_entry) + self.assertNotIn('access_key', ambient_entry) + self.assertEqual(ambient_entry['region'], 'eu-west-1') + + if __name__ == '__main__': unittest.main() diff --git a/src/utils/job/workflow.py b/src/utils/job/workflow.py index 16929c605..7be81c50e 100644 --- a/src/utils/job/workflow.py +++ b/src/utils/job/workflow.py @@ -30,6 +30,7 @@ import requests # type: ignore from src.lib.data import storage +from src.lib.data.storage import credentials from src.lib.utils import (common, jinja_sandbox, osmo_errors, priority as wf_priority, workflow as workflow_utils) from src.utils import connectors, notify @@ -653,7 +654,7 @@ def validate_data(self, user: str, dataset_config: connectors.DatasetConfig, seen_bucket_input: Set[str], seen_bucket_output: Set[str], default_user_bucket: str | None, default_service_bucket: str, - user_creds: Dict[str, Any]): + user_creds: Dict[str, credentials.StaticDataCredential]): def _validate_input_output(data_spec: Union[task.InputType, task.OutputType, task.TaskKPI], is_input: bool): @@ -710,24 +711,35 @@ def _fetch_bucket_info(dataset_info: common.DatasetStructure)\ if bucket_info.scheme in disabled_data: return - data_cred = task.fetch_creds(user, user_creds, bucket_info.uri) + access_type = storage.AccessType.READ if is_input else storage.AccessType.WRITE + uri_cache = seen_uri_input if is_input else seen_uri_output - if data_cred is None: - # User does not have any credentials, 
check if the backend - # supports environment authentication - if not bucket_info.supports_environment_auth: - raise osmo_errors.OSMOCredentialError( - f'Could not validate access to {bucket_info.uri} for user {user}.') - else: - # Check if user credentials have access to READ - if is_input and bucket_info.uri not in seen_uri_input: - bucket_info.data_auth(data_cred, storage.AccessType.READ) - seen_uri_input.add(bucket_info.uri) - - # Check if user credentials have access to WRITE - if not is_input and bucket_info.uri not in seen_uri_output: - bucket_info.data_auth(data_cred, storage.AccessType.WRITE) - seen_uri_output.add(bucket_info.uri) + if bucket_info.uri in uri_cache: + return + + # Credential resolution — strict priority, no cross-tier fallback: + # 1. User explicit credential (from their stored credentials in DB). + # 2. System default_credential on the bucket config (used only when the user + # has no explicit credential for this profile — already enforced by + # get_all_data_creds which merges them with user taking priority). + # Both tiers are represented by user_creds.get(profile): if the user has their + # own credential it wins; otherwise the bucket default is returned; None means + # neither tier has anything configured. + # + # 3. No configured credential + backend does not support ambient → error. + # 4. No configured credential + backend supports ambient → skip submit-time + # validation entirely and let the runtime (osmo-ctrl) perform the real check. + configured_cred = user_creds.get(bucket_info.profile) + + if configured_cred is not None: + bucket_info.data_auth(configured_cred, access_type) + uri_cache.add(bucket_info.uri) + elif not bucket_info.supports_environment_auth: + raise osmo_errors.OSMOCredentialError( + f'No credentials configured for {bucket_info.uri} (user {user}) ' + f'and the backend does not support environment authentication. 
' + f'Run `osmo credential set` to add a credential.') + # else: ambient — no submit-time check; osmo-ctrl validates at runtime for input_data_spec in group_task.inputs: _validate_input_output(input_data_spec, True)