diff --git a/src/runtime/pkg/data/data.go b/src/runtime/pkg/data/data.go index 21b11fdca..920940864 100644 --- a/src/runtime/pkg/data/data.go +++ b/src/runtime/pkg/data/data.go @@ -70,12 +70,57 @@ const ( DatasetOperation string = "Dataset" ) -func setOrUnsetEnv(key, value string) { - if value != "" { - os.Setenv(key, value) - } else { - os.Unsetenv(key) +func envSliceToMap(env []string) map[string]string { + envMap := make(map[string]string, len(env)) + for _, entry := range env { + key, value, found := strings.Cut(entry, "=") + if found { + envMap[key] = value + } + } + return envMap +} + +func envMapToSlice(envMap map[string]string) []string { + env := make([]string, 0, len(envMap)) + for key, value := range envMap { + env = append(env, key+"="+value) } + return env +} + +func awsCredentialCommandEnv( + baseEnv []string, + dataCredential DataCredential, + hasExplicitCredential bool, +) []string { + envMap := envSliceToMap(baseEnv) + if !hasExplicitCredential { + return envMapToSlice(envMap) + } + + usesStaticCredential := dataCredential.AccessKeyId != "" || dataCredential.AccessKey != "" + if usesStaticCredential { + if dataCredential.AccessKeyId != "" { + envMap["AWS_ACCESS_KEY_ID"] = dataCredential.AccessKeyId + } else { + delete(envMap, "AWS_ACCESS_KEY_ID") + } + if dataCredential.AccessKey != "" { + envMap["AWS_SECRET_ACCESS_KEY"] = dataCredential.AccessKey + } else { + delete(envMap, "AWS_SECRET_ACCESS_KEY") + } + delete(envMap, "AWS_SESSION_TOKEN") + } + + if dataCredential.Region != "" { + envMap["AWS_REGION"] = dataCredential.Region + } else if usesStaticCredential { + delete(envMap, "AWS_REGION") + } + + return envMapToSlice(envMap) } func awsForcePathStyleEnabled() bool { @@ -503,19 +548,7 @@ func MountURL(downloadType string, credentialInfo ConfigInfo, urlPath string, isEmpty := true storageBackend := ParseStorageBackend(urlPath) - dataCredential, ok := credentialInfo.Auth.Data[storageBackend.GetProfile()] - if ok { - setOrUnsetEnv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId) - setOrUnsetEnv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey) - setOrUnsetEnv("AWS_REGION", dataCredential.Region) - } else { - // No explicit credential — clear any stale values and let the - // SDK resolve ambient credentials (IRSA, pod identity, etc.). - os.Unsetenv("AWS_ACCESS_KEY_ID") - os.Unsetenv("AWS_SECRET_ACCESS_KEY") - os.Unsetenv("AWS_REGION") - } - os.Unsetenv("AWS_SESSION_TOKEN") + dataCredential, hasExplicitCredential := credentialInfo.Auth.Data[storageBackend.GetProfile()] var commandArgs []string @@ -542,6 +575,7 @@ func MountURL(downloadType string, credentialInfo ConfigInfo, urlPath string, mountS3Path := common.ResolveCommandPath("MOUNT_S3_PATH", "mount-s3", "/usr/bin/mount-s3") cmd := exec.Command(mountS3Path, commandArgs...) + cmd.Env = awsCredentialCommandEnv(os.Environ(), dataCredential, hasExplicitCredential) cmd.Stderr = log if err = cmd.Run(); err != nil { if strings.Contains(err.Error(), "Timeout") { diff --git a/src/runtime/pkg/data/data_test.go b/src/runtime/pkg/data/data_test.go index 38cd70715..ecc1b664d 100644 --- a/src/runtime/pkg/data/data_test.go +++ b/src/runtime/pkg/data/data_test.go @@ -20,9 +20,88 @@ package data import ( "slices" + "strings" "testing" ) +func testEnvMap(env []string) map[string]string { + envMap := make(map[string]string, len(env)) + for _, entry := range env { + key, value, found := strings.Cut(entry, "=") + if found { + envMap[key] = value + } + } + return envMap +} + +func requireEnvValue(t *testing.T, envMap map[string]string, key string, expected string) { + t.Helper() + if got := envMap[key]; got != expected { + t.Fatalf("expected %s=%q, got %q", key, expected, got) + } +} + +func requireEnvMissing(t *testing.T, envMap map[string]string, key string) { + t.Helper() + if value, ok := envMap[key]; ok { + t.Fatalf("expected %s to be unset, got %q", key, value) + } +} + +func TestAwsCredentialCommandEnvStaticCredentialOverridesAmbient(t *testing.T) { + env := awsCredentialCommandEnv([]string{ + "AWS_ACCESS_KEY_ID=ambient-id", + "AWS_SECRET_ACCESS_KEY=ambient-secret", + "AWS_SESSION_TOKEN=ambient-session", + "AWS_REGION=us-east-1", + "OSMO_TEST=value", + }, DataCredential{ + AccessKeyId: "static-id", + AccessKey: "static-secret", + Region: "us-west-2", + }, true) + + envMap := testEnvMap(env) + requireEnvValue(t, envMap, "AWS_ACCESS_KEY_ID", "static-id") + requireEnvValue(t, envMap, "AWS_SECRET_ACCESS_KEY", "static-secret") + requireEnvValue(t, envMap, "AWS_REGION", "us-west-2") + requireEnvValue(t, envMap, "OSMO_TEST", "value") + requireEnvMissing(t, envMap, "AWS_SESSION_TOKEN") +} + +func TestAwsCredentialCommandEnvDefaultCredentialPreservesAmbientAuth(t *testing.T) { + env := awsCredentialCommandEnv([]string{ + "AWS_ACCESS_KEY_ID=ambient-id", + "AWS_SECRET_ACCESS_KEY=ambient-secret", + "AWS_SESSION_TOKEN=ambient-session", + "AWS_REGION=us-east-1", + }, DataCredential{ + Region: "us-west-2", + }, true) + + envMap := testEnvMap(env) + requireEnvValue(t, envMap, "AWS_ACCESS_KEY_ID", "ambient-id") + requireEnvValue(t, envMap, "AWS_SECRET_ACCESS_KEY", "ambient-secret") + requireEnvValue(t, envMap, "AWS_SESSION_TOKEN", "ambient-session") + requireEnvValue(t, envMap, "AWS_REGION", "us-west-2") +} + +func TestAwsCredentialCommandEnvNoCredentialPreservesAmbientAuth(t *testing.T) { + env := awsCredentialCommandEnv([]string{ + "AWS_ACCESS_KEY_ID=ambient-id", + "AWS_SECRET_ACCESS_KEY=ambient-secret", + "AWS_SESSION_TOKEN=ambient-session", + "AWS_REGION=us-east-1", + }, DataCredential{}, false) + + envMap := testEnvMap(env) + requireEnvValue(t, envMap, "AWS_ACCESS_KEY_ID", "ambient-id") + requireEnvValue(t, envMap, "AWS_SECRET_ACCESS_KEY", "ambient-secret") + requireEnvValue(t, envMap, "AWS_SESSION_TOKEN", "ambient-session") + requireEnvValue(t, envMap, "AWS_REGION", "us-east-1") +} + func TestBuildMountCommandArgsS3VirtualCustomEndpoint(t *testing.T) { backend := ParseStorageBackend("s3://coreweave-bucket/datasets") credential := DataCredential{ diff --git a/src/utils/connectors/postgres.py b/src/utils/connectors/postgres.py index 9700b6009..a003094b7 100644 --- a/src/utils/connectors/postgres.py +++ b/src/utils/connectors/postgres.py @@ -1528,7 +1528,7 @@ def get_data_cred(self, user: str, profile: str) -> credentials.DataCredential | return None - def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCredential]: + def get_all_data_creds(self, user: str) -> Dict[str, credentials.DataCredential]: """ Fetch all data credentials for user. """ select_data_cmd = PostgresSelectCommand( table='credential', @@ -1536,7 +1536,7 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCrede condition_args=[user, CredentialType.DATA.value]) rows = self.execute_fetch_command(*select_data_cmd.get_args()) - user_creds: Dict[str, credentials.StaticDataCredential] = { + user_creds: Dict[str, credentials.DataCredential] = { cred.profile: credentials.StaticDataCredential( endpoint=cred.profile, **self.decrypt_credential(cred), diff --git a/src/utils/connectors/tests/test_cli_config.py b/src/utils/connectors/tests/test_cli_config.py index 719d23635..a016b95ed 100644 --- a/src/utils/connectors/tests/test_cli_config.py +++ b/src/utils/connectors/tests/test_cli_config.py @@ -17,6 +17,7 @@ """ import unittest +from src.lib.data.storage import credentials from src.lib.utils import osmo_errors from src.utils import connectors @@ -46,5 +47,31 @@ def test_rejects_malformed_min_supported_version(self): connectors.CliConfig(min_supported_version=bad) +class TestWorkflowStorageConfig(unittest.TestCase): + """Validation tests for workflow storage credentials.""" + + def test_workflow_log_accepts_default_data_credential(self): + log_config = connectors.LogConfig(credential={ + 'endpoint': 's3://osmo-logs', + 'region': 'us-west-2', + }) + + credential = log_config.credential + if not isinstance(credential, credentials.DefaultDataCredential): + self.fail('workflow_log credential should use DefaultDataCredential') + self.assertEqual(credential.endpoint, 's3://osmo-logs') + + def test_workflow_data_accepts_default_data_credential(self): + data_config = connectors.DataConfig(credential={ + 'endpoint': 's3://osmo-data', + 'region': 'us-west-2', + }) + + credential = data_config.credential + if not isinstance(credential, credentials.DefaultDataCredential): + self.fail('workflow_data credential should use DefaultDataCredential') + self.assertEqual(credential.endpoint, 's3://osmo-data') + + if __name__ == '__main__': unittest.main() diff --git a/src/utils/job/task.py b/src/utils/job/task.py index 31b64e9bb..76d9b335f 100644 --- a/src/utils/job/task.py +++ b/src/utils/job/task.py @@ -103,7 +103,7 @@ def create_login_dict(user: str, def create_config_dict( - data_info: Mapping[str, credentials.StaticDataCredential | credentials.DefaultDataCredential], + data_info: Mapping[str, credentials.DataCredential], ) -> dict: ''' Creates the config dict where the input should be a dict containing key values like: @@ -2700,7 +2700,7 @@ def convert_to_pod_spec( service_config: connectors.ServiceConfig | None = None, dataset_config: connectors.DatasetConfig | None = None, pool_info: connectors.Pool | None = None, - data_endpoints: Mapping[str, credentials.StaticDataCredential] | None = None, + data_endpoints: Mapping[str, credentials.DataCredential] | None = None, skip_refresh_token: bool = False, auth_token: str | None = None, ) -> Tuple[Dict, Dict[str, kb_objects.FileMount], Optional[Tuple[str, str]]]: @@ -3194,10 +3194,10 @@ def decode_hstore(tasks: str) -> Set[str]: def fetch_creds( user: str, - data_creds: Mapping[str, credentials.StaticDataCredential], + data_creds: Mapping[str, credentials.DataCredential], path: str, disabled_data: list[str] | None = None, -) -> credentials.StaticDataCredential | None: +) -> credentials.DataCredential | None: backend_info = storage.construct_storage_backend(path) if backend_info.profile not in data_creds: