Skip to content
220 changes: 215 additions & 5 deletions service/pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,21 @@ type MockLoader struct {
}

func (l *MockLoader) Load(mostRecentConfig Config) error {
l.loadCalled = true
if l.loadFn != nil {
return l.loadFn(mostRecentConfig)
}
return nil
}

func (l *MockLoader) Get(key string) (any, error) {
l.loadCalled = true
if l.getFn != nil {
return l.getFn(key)
}
return nil, errors.New("not setup for Get")
}

func (l *MockLoader) GetConfigKeys() ([]string, error) {
l.loadCalled = true
if l.loadFn != nil {
if l.getConfigKeysFn != nil {
return l.getConfigKeysFn()
}
return nil, nil
Expand Down Expand Up @@ -75,7 +72,7 @@ func (l *MockLoader) Name() string {
if l.getNameFn != nil {
return l.getNameFn()
}
return ""
return "mock"
}

func newMockLoader() *MockLoader {
Expand Down Expand Up @@ -310,3 +307,216 @@ server:
assert.Equal(t, "abc", config.Services["service_a"]["value1"])
assert.Equal(t, "def", config.Services["service_a"]["value2"])
}

// TestLoad_Precedence is a matrix test that verifies the loading order
// and precedence of different configuration sources.
func TestLoad_Precedence(t *testing.T) {
// Helper to create a temp config file for tests
newTempConfigFile := func(t *testing.T, content string) string {
t.Helper()
f, err := os.CreateTemp(t.TempDir(), "config-*.yaml")
require.NoError(t, err)
_, err = f.WriteString(content)
require.NoError(t, err)
require.NoError(t, f.Close())
return f.Name()
}

testCases := []struct {
name string
setupLoaders func(t *testing.T, configFile string) []Loader
envVars map[string]string
fileContent string
asserts func(t *testing.T, cfg *Config)
}{
{
name: "defaults only",
setupLoaders: func(t *testing.T, _ string) []Loader {
defaults, err := NewDefaultSettingsLoader()
require.NoError(t, err)
return []Loader{defaults}
},
asserts: func(t *testing.T, cfg *Config) {
// Assert values from `default` struct tags
assert.Equal(t, []string{"all"}, cfg.Mode)
assert.Equal(t, "info", cfg.Logger.Level)
assert.Equal(t, 8080, cfg.Server.Port)
},
},
{
name: "file overrides defaults",
setupLoaders: func(t *testing.T, configFile string) []Loader {
file, err := NewConfigFileLoader("test", configFile)
require.NoError(t, err)
defaults, err := NewDefaultSettingsLoader()
require.NoError(t, err)
// File loader comes first, so it has higher priority
return []Loader{file, defaults}
},
fileContent: `
server:
port: 9090
logger:
level: warn
`,
asserts: func(t *testing.T, cfg *Config) {
// Values from file
assert.Equal(t, 9090, cfg.Server.Port)
assert.Equal(t, "warn", cfg.Logger.Level)
// Value from defaults
assert.Equal(t, []string{"all"}, cfg.Mode)
},
},
{
name: "file with extras and defaults",
setupLoaders: func(t *testing.T, configFile string) []Loader {
file, err := NewConfigFileLoader("test", configFile)
require.NoError(t, err)
defaults, err := NewDefaultSettingsLoader()
require.NoError(t, err)
// File loader comes first, so it has higher priority
return []Loader{file, defaults}
},
fileContent: `
server:
port: 9090
public_hostname: "test.host"
logger:
level: warn
special_key:
nested:
special_value: 123
`,
asserts: func(t *testing.T, cfg *Config) {
// Values from file
assert.Equal(t, "test.host", cfg.Server.PublicHostname)
assert.Equal(t, 9090, cfg.Server.Port)
assert.Equal(t, "warn", cfg.Logger.Level)
// Value from defaults
assert.Equal(t, []string{"all"}, cfg.Mode)
},
},
{
name: "mocked env overrides file and defaults",
envVars: map[string]string{
"TEST_SERVER_PORT": "9999",
"TEST_LOGGER_LEVEL": "debug",
"TEST_DB_HOST": "env.host",
"TEST_SERVICES_FOO_BAR": "baz",
},
setupLoaders: func(t *testing.T, configFile string) []Loader {
// Use a mock for env to test the Reload precedence logic
envLoader, err := NewEnvironmentValueLoader("TEST", nil)
require.NoError(t, err)

file, err := NewConfigFileLoader("test", configFile)
require.NoError(t, err)
defaults, err := NewDefaultSettingsLoader()
require.NoError(t, err)
// Order: env > file > defaults
return []Loader{envLoader, file, defaults}
},
fileContent: `
server:
port: 9090
logger:
level: warn
db:
host: file.host
`,
asserts: func(t *testing.T, cfg *Config) {
// Values from mocked env
assert.Equal(t, 9999, cfg.Server.Port)
assert.Equal(t, "debug", cfg.Logger.Level)
assert.Equal(t, "env.host", cfg.DB.Host)

// Value from defaults (not overridden by file or env)
assert.Equal(t, []string{"all"}, cfg.Mode)

// Value from service map in mocked env
require.Contains(t, cfg.Services, "foo")
assert.Equal(t, "baz", cfg.Services["foo"]["bar"])
},
}, {
name: "env with allow list allows key",
envVars: map[string]string{
"TEST_SERVER_PORT": "9999", // This should be loaded
},
setupLoaders: func(t *testing.T, configFile string) []Loader {
// Allow list only contains server.port
allowList := []string{"server.port"}
envLoader, err := NewEnvironmentValueLoader("TEST", allowList)
require.NoError(t, err)

file, err := NewConfigFileLoader("test", configFile)
require.NoError(t, err)
return []Loader{envLoader, file}
},
fileContent: `
server:
port: 8888
`,
asserts: func(t *testing.T, cfg *Config) {
// The allowed env var should override the file value.
assert.Equal(t, 9999, cfg.Server.Port)
},
},
{
name: "env with allow list blocks key",
envVars: map[string]string{
"TEST_SERVER_PORT": "9999", // This should be BLOCKED
"TEST_LOGGER_LEVEL": "debug", // This should be ALLOWED
},
setupLoaders: func(t *testing.T, configFile string) []Loader {
// Allow list does NOT contain server.port
allowList := []string{"logger.level"}
envLoader, err := NewEnvironmentValueLoader("TEST", allowList)
require.NoError(t, err)

file, err := NewConfigFileLoader("test", configFile)
require.NoError(t, err)
defaults, err := NewDefaultSettingsLoader()
require.NoError(t, err)
return []Loader{envLoader, file, defaults}
},
fileContent: `
server:
port: 8888
logger:
level: info
`,
asserts: func(t *testing.T, cfg *Config) {
// The server.port env var was blocked, so the value from the file takes precedence.
assert.Equal(t, 8888, cfg.Server.Port)
// The logger.level env var was allowed, so it overrides the file value.
assert.Equal(t, "debug", cfg.Logger.Level)
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Setup env vars
for k, v := range tc.envVars {
t.Setenv(k, v)
}

// Setup config file
configFile := ""
if tc.fileContent != "" {
configFile = newTempConfigFile(t, tc.fileContent)
}

// Setup loaders
loaders := tc.setupLoaders(t, configFile)

// Load config
cfg, err := Load(context.Background(), loaders...)

// Assertions
require.NoError(t, err)
require.NotNil(t, cfg)
tc.asserts(t, cfg)
})
}
}
61 changes: 45 additions & 16 deletions service/pkg/config/environment_value_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,24 @@ package config
import (
"context"
"fmt"
"log/slog"
"os"
"strings"

"github.com/spf13/viper"
)

const LoaderNameEnvironmentValue = "environment-value"

// EnvironmentValueLoader implements Loader using Viper
type EnvironmentValueLoader struct {
allowListMap map[string]struct{}
viper *viper.Viper
envKeyPrefix string
loadedKeys []string
loadedValues map[string]string
}

// NewEnvironmentValueLoader creates a new Viper-based configuration loader
// to load from environment variables, from a default or specified file
// (or k8s config map), or some combination
func NewEnvironmentValueLoader(key string, allowList []string) (*EnvironmentValueLoader, error) {
// Set paths and config file info
v := viper.NewWithOptions(viper.WithLogger(slog.Default()))

// Environment variable settings
v.SetEnvPrefix(key)
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()

var allowListMap map[string]struct{}
if allowList != nil || len(allowList) > 0 {
allowListMap = make(map[string]struct{})
Expand All @@ -39,7 +31,7 @@ func NewEnvironmentValueLoader(key string, allowList []string) (*EnvironmentValu

result := &EnvironmentValueLoader{
allowListMap: allowListMap,
viper: v,
envKeyPrefix: strings.ToUpper(key) + "_",
}
return result, nil
}
Expand All @@ -51,17 +43,54 @@ func (l *EnvironmentValueLoader) Get(key string) (any, error) {
return nil, fmt.Errorf("environment value %s is not allowed", key)
}
}
return l.viper.Get(key), nil
value, found := l.loadedValues[key]
if found {
return value, nil
}
return nil, nil
}

// GetConfigKeys returns all the configuration keys found in the environment variables.
func (l *EnvironmentValueLoader) GetConfigKeys() ([]string, error) {
return l.viper.AllKeys(), nil
return l.loadedKeys, nil
}

// Load loads the configuration into the provided struct
func (l *EnvironmentValueLoader) Load(_ Config) error {
// For environment variables, Viper's `AutomaticEnv` handles this, so no explicit load is needed here.
var loadedKeys []string
loadedValues := make(map[string]string)

env := os.Environ()
for _, kv := range env {
upperKV := strings.ToUpper(kv)
if strings.HasPrefix(upperKV, l.envKeyPrefix) {
eqIdx := strings.Index(upperKV, "=")
if eqIdx == -1 {
continue
}
envKey := kv[len(l.envKeyPrefix):eqIdx]
envValue := kv[eqIdx+1:]
dottedKey := strings.ToLower(strings.ReplaceAll(envKey, "_", "."))
if l.allowListMap != nil {
if _, keyInAllowList := l.allowListMap[dottedKey]; !keyInAllowList {
// This key is not allowed, skip it
continue
}
}

loadedKeys = append(loadedKeys, dottedKey)
loadedValues[dottedKey] = envValue
}
}

if len(loadedKeys) > 0 {
l.loadedKeys = loadedKeys
l.loadedValues = loadedValues
} else {
l.loadedKeys = nil
l.loadedValues = nil
}

return nil
}

Expand Down
Loading