Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 52 additions & 18 deletions src/runtime/pkg/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,57 @@ const (
DatasetOperation string = "Dataset"
)

func setOrUnsetEnv(key, value string) {
if value != "" {
os.Setenv(key, value)
} else {
os.Unsetenv(key)
// envSliceToMap converts a list of "KEY=VALUE" environment entries into a map
// keyed by variable name.
//
// Each entry is split at its first "="; entries that contain no "=" are
// skipped. When the same key appears more than once, the later entry wins.
func envSliceToMap(env []string) map[string]string {
	result := make(map[string]string, len(env))
	for i := range env {
		name, val, ok := strings.Cut(env[i], "=")
		if !ok {
			continue
		}
		result[name] = val
	}
	return result
}

// envMapToSlice flattens a map of environment variables into "KEY=VALUE"
// entries.
//
// The entries are produced while ranging over the map, so their order is not
// fixed from one call to the next.
func envMapToSlice(envMap map[string]string) []string {
	entries := make([]string, len(envMap))
	next := 0
	for name, val := range envMap {
		entries[next] = name + "=" + val
		next++
	}
	return entries
}

// awsCredentialCommandEnv derives the environment for a child command from
// baseEnv and an optional data credential.
//
// When hasExplicitCredential is false, baseEnv is passed through the
// map round trip unchanged in content. Otherwise:
//   - If the credential carries an access key ID or an access key, both
//     AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set from the credential
//     (or removed when the corresponding field is empty), and
//     AWS_SESSION_TOKEN is removed.
//   - AWS_REGION is set when the credential has a region; it is removed only
//     when the credential has no region but does carry static keys.
//
// All other variables in baseEnv are carried over as parsed by
// envSliceToMap.
func awsCredentialCommandEnv(
	baseEnv []string,
	dataCredential DataCredential,
	hasExplicitCredential bool,
) []string {
	vars := envSliceToMap(baseEnv)

	if hasExplicitCredential {
		// assign sets name to val, treating an empty val as "remove".
		assign := func(name, val string) {
			if val == "" {
				delete(vars, name)
				return
			}
			vars[name] = val
		}

		static := !(dataCredential.AccessKeyId == "" && dataCredential.AccessKey == "")
		if static {
			assign("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId)
			assign("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey)
			// An inherited session token must not be paired with the
			// credential's own key pair.
			delete(vars, "AWS_SESSION_TOKEN")
		}

		switch {
		case dataCredential.Region != "":
			vars["AWS_REGION"] = dataCredential.Region
		case static:
			delete(vars, "AWS_REGION")
		}
	}

	return envMapToSlice(vars)
}

func awsForcePathStyleEnabled() bool {
Expand Down Expand Up @@ -503,19 +548,7 @@ func MountURL(downloadType string, credentialInfo ConfigInfo, urlPath string,
isEmpty := true
storageBackend := ParseStorageBackend(urlPath)

dataCredential, ok := credentialInfo.Auth.Data[storageBackend.GetProfile()]
if ok {
setOrUnsetEnv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId)
setOrUnsetEnv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey)
setOrUnsetEnv("AWS_REGION", dataCredential.Region)
} else {
// No explicit credential — clear any stale values and let the
// SDK resolve ambient credentials (IRSA, pod identity, etc.).
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Unsetenv("AWS_REGION")
}
os.Unsetenv("AWS_SESSION_TOKEN")
dataCredential, hasExplicitCredential := credentialInfo.Auth.Data[storageBackend.GetProfile()]

var commandArgs []string

Expand All @@ -542,6 +575,7 @@ func MountURL(downloadType string, credentialInfo ConfigInfo, urlPath string,

mountS3Path := common.ResolveCommandPath("MOUNT_S3_PATH", "mount-s3", "/usr/bin/mount-s3")
cmd := exec.Command(mountS3Path, commandArgs...)
cmd.Env = awsCredentialCommandEnv(os.Environ(), dataCredential, hasExplicitCredential)
cmd.Stderr = log
if err = cmd.Run(); err != nil {
if strings.Contains(err.Error(), "Timeout") {
Expand Down
79 changes: 79 additions & 0 deletions src/runtime/pkg/data/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,88 @@ package data

import (
"slices"
"strings"
"testing"
)

// testEnvMap parses "KEY=VALUE" entries into a map so tests can look up
// individual variables. Entries without "=" are ignored; a repeated key keeps
// its last value.
func testEnvMap(env []string) map[string]string {
	parsed := make(map[string]string, len(env))
	for _, kv := range env {
		if name, val, ok := strings.Cut(kv, "="); ok {
			parsed[name] = val
		}
	}
	return parsed
}

// requireEnvValue fails the test immediately unless envMap[key] equals
// expected. A key that is absent reads as the empty string.
func requireEnvValue(t *testing.T, envMap map[string]string, key string, expected string) {
	t.Helper()
	actual := envMap[key]
	if actual == expected {
		return
	}
	t.Fatalf("expected %s=%q, got %q", key, expected, actual)
}

// requireEnvMissing fails the test immediately if key is present in envMap,
// even when its value is empty.
func requireEnvMissing(t *testing.T, envMap map[string]string, key string) {
	t.Helper()
	present, found := envMap[key]
	if !found {
		return
	}
	t.Fatalf("expected %s to be unset, got %q", key, present)
}

func TestAwsCredentialCommandEnvStaticCredentialOverridesAmbient(t *testing.T) {
env := awsCredentialCommandEnv([]string{
"AWS_ACCESS_KEY_ID=ambient-id",
"AWS_SECRET_ACCESS_KEY=ambient-secret",
"AWS_SESSION_TOKEN=ambient-session",
"AWS_REGION=us-east-1",
"OSMO_TEST=value",
}, DataCredential{
AccessKeyId: "static-id",
AccessKey: "static-secret",
Region: "us-west-2",
}, true)

envMap := testEnvMap(env)
requireEnvValue(t, envMap, "AWS_ACCESS_KEY_ID", "static-id")
requireEnvValue(t, envMap, "AWS_SECRET_ACCESS_KEY", "static-secret")
requireEnvValue(t, envMap, "AWS_REGION", "us-west-2")
requireEnvValue(t, envMap, "OSMO_TEST", "value")
requireEnvMissing(t, envMap, "AWS_SESSION_TOKEN")
}

func TestAwsCredentialCommandEnvDefaultCredentialPreservesAmbientAuth(t *testing.T) {
env := awsCredentialCommandEnv([]string{
"AWS_ACCESS_KEY_ID=ambient-id",
"AWS_SECRET_ACCESS_KEY=ambient-secret",
"AWS_SESSION_TOKEN=ambient-session",
"AWS_REGION=us-east-1",
}, DataCredential{
Region: "us-west-2",
}, true)

envMap := testEnvMap(env)
requireEnvValue(t, envMap, "AWS_ACCESS_KEY_ID", "ambient-id")
requireEnvValue(t, envMap, "AWS_SECRET_ACCESS_KEY", "ambient-secret")
requireEnvValue(t, envMap, "AWS_SESSION_TOKEN", "ambient-session")
requireEnvValue(t, envMap, "AWS_REGION", "us-west-2")
}

func TestAwsCredentialCommandEnvNoCredentialPreservesAmbientAuth(t *testing.T) {
env := awsCredentialCommandEnv([]string{
"AWS_ACCESS_KEY_ID=ambient-id",
"AWS_SECRET_ACCESS_KEY=ambient-secret",
"AWS_SESSION_TOKEN=ambient-session",
"AWS_REGION=us-east-1",
}, DataCredential{}, false)

envMap := testEnvMap(env)
requireEnvValue(t, envMap, "AWS_ACCESS_KEY_ID", "ambient-id")
requireEnvValue(t, envMap, "AWS_SECRET_ACCESS_KEY", "ambient-secret")
requireEnvValue(t, envMap, "AWS_SESSION_TOKEN", "ambient-session")
requireEnvValue(t, envMap, "AWS_REGION", "us-east-1")
}

func TestBuildMountCommandArgsS3VirtualCustomEndpoint(t *testing.T) {
backend := ParseStorageBackend("s3://coreweave-bucket/datasets")
credential := DataCredential{
Expand Down
4 changes: 2 additions & 2 deletions src/utils/connectors/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,15 +1528,15 @@ def get_data_cred(self, user: str, profile: str) -> credentials.DataCredential |

return None

def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCredential]:
def get_all_data_creds(self, user: str) -> Dict[str, credentials.DataCredential]:
""" Fetch all data credentials for user. """
select_data_cmd = PostgresSelectCommand(
table='credential',
conditions=['user_name = %s', 'cred_type = %s'],
condition_args=[user, CredentialType.DATA.value])
rows = self.execute_fetch_command(*select_data_cmd.get_args())

user_creds: Dict[str, credentials.StaticDataCredential] = {
user_creds: Dict[str, credentials.DataCredential] = {
cred.profile: credentials.StaticDataCredential(
endpoint=cred.profile,
**self.decrypt_credential(cred),
Expand Down
27 changes: 27 additions & 0 deletions src/utils/connectors/tests/test_cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
import unittest

from src.lib.data.storage import credentials
from src.lib.utils import osmo_errors
from src.utils import connectors

Expand Down Expand Up @@ -46,5 +47,31 @@ def test_rejects_malformed_min_supported_version(self):
connectors.CliConfig(min_supported_version=bad)


class TestWorkflowStorageConfig(unittest.TestCase):
    """Checks how workflow storage configs parse an endpoint/region-only credential."""

    def test_workflow_log_accepts_default_data_credential(self):
        """LogConfig given only endpoint and region yields a DefaultDataCredential."""
        raw_credential = {
            'endpoint': 's3://osmo-logs',
            'region': 'us-west-2',
        }
        parsed = connectors.LogConfig(credential=raw_credential).credential

        # Explicit isinstance + fail keeps the failure message exactly as written.
        if not isinstance(parsed, credentials.DefaultDataCredential):
            self.fail('workflow_log credential should use DefaultDataCredential')
        self.assertEqual(parsed.endpoint, 's3://osmo-logs')

    def test_workflow_data_accepts_default_data_credential(self):
        """DataConfig given only endpoint and region yields a DefaultDataCredential."""
        raw_credential = {
            'endpoint': 's3://osmo-data',
            'region': 'us-west-2',
        }
        parsed = connectors.DataConfig(credential=raw_credential).credential

        # Explicit isinstance + fail keeps the failure message exactly as written.
        if not isinstance(parsed, credentials.DefaultDataCredential):
            self.fail('workflow_data credential should use DefaultDataCredential')
        self.assertEqual(parsed.endpoint, 's3://osmo-data')


if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions src/utils/job/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def create_login_dict(user: str,


def create_config_dict(
data_info: Mapping[str, credentials.StaticDataCredential | credentials.DefaultDataCredential],
data_info: Mapping[str, credentials.DataCredential],
) -> dict:
'''
Creates the config dict where the input should be a dict containing key values like:
Expand Down Expand Up @@ -2700,7 +2700,7 @@ def convert_to_pod_spec(
service_config: connectors.ServiceConfig | None = None,
dataset_config: connectors.DatasetConfig | None = None,
pool_info: connectors.Pool | None = None,
data_endpoints: Mapping[str, credentials.StaticDataCredential] | None = None,
data_endpoints: Mapping[str, credentials.DataCredential] | None = None,
skip_refresh_token: bool = False,
auth_token: str | None = None,
) -> Tuple[Dict, Dict[str, kb_objects.FileMount], Optional[Tuple[str, str]]]:
Expand Down Expand Up @@ -3194,10 +3194,10 @@ def decode_hstore(tasks: str) -> Set[str]:

def fetch_creds(
user: str,
data_creds: Mapping[str, credentials.StaticDataCredential],
data_creds: Mapping[str, credentials.DataCredential],
path: str,
disabled_data: list[str] | None = None,
) -> credentials.StaticDataCredential | None:
) -> credentials.DataCredential | None:
backend_info = storage.construct_storage_backend(path)

if backend_info.profile not in data_creds:
Expand Down