diff --git a/service/cmd/migrate.go b/service/cmd/migrate.go index 5422c684a..6a15718b9 100644 --- a/service/cmd/migrate.go +++ b/service/cmd/migrate.go @@ -121,11 +121,7 @@ func migrateService(cmd *cobra.Command, args []string, migrationFunc func(*db.Cl func migrateDBClient(cmd *cobra.Command, opts ...db.OptsFunc) (*db.Client, error) { configFile, _ := cmd.Flags().GetString(configFileFlag) configKey, _ := cmd.Flags().GetString(configKeyFlag) - envLoader, err := config.NewEnvironmentValueLoader(configKey, nil) - if err != nil { - panic(fmt.Errorf("could not load config: %w", err)) - } - configFileLoader, err := config.NewConfigFileLoader(configKey, configFile) + legacyLoader, err := config.NewLegacyLoader(configKey, configFile) if err != nil { panic(fmt.Errorf("could not load config: %w", err)) } @@ -135,8 +131,7 @@ func migrateDBClient(cmd *cobra.Command, opts ...db.OptsFunc) (*db.Client, error } conf, err := config.Load( cmd.Context(), - envLoader, - configFileLoader, + legacyLoader, defaultSettingsLoader, ) if err != nil { diff --git a/service/cmd/policy.go b/service/cmd/policy.go index 8c84cd701..0b0593740 100644 --- a/service/cmd/policy.go +++ b/service/cmd/policy.go @@ -35,11 +35,7 @@ var ( Run: func(cmd *cobra.Command, _ []string) { configFile, _ := cmd.Flags().GetString(configFileFlag) configKey, _ := cmd.Flags().GetString(configKeyFlag) - envLoader, err := config.NewEnvironmentValueLoader(configKey, nil) - if err != nil { - panic(fmt.Errorf("could not load config: %w", err)) - } - configFileLoader, err := config.NewConfigFileLoader(configKey, configFile) + legacyLoader, err := config.NewLegacyLoader(configKey, configFile) if err != nil { panic(fmt.Errorf("could not load config: %w", err)) } @@ -49,8 +45,7 @@ var ( } cfg, err := config.Load( cmd.Context(), - envLoader, - configFileLoader, + legacyLoader, defaultSettingsLoader, ) if err != nil { diff --git a/service/cmd/provisionFixtures.go b/service/cmd/provisionFixtures.go index 29184fc7d..2e120998e 100644 --- a/service/cmd/provisionFixtures.go +++ b/service/cmd/provisionFixtures.go @@ -38,11 +38,7 @@ You can clear/recycle your database with 'docker compose down' and 'docker compo Run: func(cmd *cobra.Command, _ []string) { configFile, _ := cmd.Flags().GetString(configFileFlag) configKey, _ := cmd.Flags().GetString(configKeyFlag) - envLoader, err := config.NewEnvironmentValueLoader(configKey, nil) - if err != nil { - panic(fmt.Errorf("could not load config: %w", err)) - } - configFileLoader, err := config.NewConfigFileLoader(configKey, configFile) + legacyLoader, err := config.NewLegacyLoader(configKey, configFile) if err != nil { panic(fmt.Errorf("could not load config: %w", err)) } @@ -52,8 +48,7 @@ You can clear/recycle your database with 'docker compose down' and 'docker compo } cfg, err := config.Load( cmd.Context(), - envLoader, - configFileLoader, + legacyLoader, defaultSettingsLoader, ) if err != nil { diff --git a/service/pkg/config/config.go b/service/pkg/config/config.go index 26d1510b8..a8528e4f4 100644 --- a/service/pkg/config/config.go +++ b/service/pkg/config/config.go @@ -270,11 +270,7 @@ func (c SDKConfig) LogValue() slog.Value { // Deprecated: Use the `Load` method with your preferred loaders func LoadConfig(ctx context.Context, key, file string) (*Config, error) { - envLoader, err := NewEnvironmentValueLoader(key, nil) - if err != nil { - return nil, fmt.Errorf("could not load config: %w", err) - } - configFileLoader, err := NewConfigFileLoader(key, file) + legacyLoader, err := NewLegacyLoader(key, file) if err != nil { return nil, fmt.Errorf("could not load config: %w", err) } @@ -284,8 +280,7 @@ func LoadConfig(ctx context.Context, key, file string) (*Config, error) { } return Load( ctx, - envLoader, - configFileLoader, + legacyLoader, defaultSettingsLoader, ) } diff --git a/service/pkg/config/config_test.go b/service/pkg/config/config_test.go index 9d20afc3c..569b6c32f 100644 --- a/service/pkg/config/config_test.go +++ b/service/pkg/config/config_test.go @@ -10,6 +10,10 @@ import ( "github.com/stretchr/testify/require" ) +const ( + configKey = "test" +) + // Manual mock implementation of Loader type MockLoader struct { loadFn func(Config) error @@ -19,7 +23,6 @@ type MockLoader struct { closeFn func() error getNameFn func() string - loadCalled bool watchCalled bool closeCalled bool getNameCalled bool @@ -28,7 +31,6 @@ type MockLoader struct { } func (l *MockLoader) Load(mostRecentConfig Config) error { - l.loadCalled = true if l.loadFn != nil { return l.loadFn(mostRecentConfig) } @@ -36,7 +38,6 @@ func (l *MockLoader) Load(mostRecentConfig Config) error { } func (l *MockLoader) Get(key string) (any, error) { - l.loadCalled = true if l.getFn != nil { return l.getFn(key) } @@ -44,8 +45,7 @@ func (l *MockLoader) Get(key string) (any, error) { } func (l *MockLoader) GetConfigKeys() ([]string, error) { - l.loadCalled = true - if l.loadFn != nil { + if l.getConfigKeysFn != nil { return l.getConfigKeysFn() } return nil, nil @@ -75,7 +75,7 @@ func (l *MockLoader) Name() string { if l.getNameFn != nil { return l.getNameFn() } - return "" + return "mock" } func newMockLoader() *MockLoader { @@ -233,9 +233,9 @@ func TestConfig_OnChange(t *testing.T) { func TestLoadConfig_NoFileExistsInEnv(t *testing.T) { ctx := t.Context() - envLoader, err := NewEnvironmentValueLoader("test", nil) + envLoader, err := NewEnvironmentValueLoader(configKey, nil) require.NoError(t, err) - configFileLoader, err := NewConfigFileLoader("test", "non-existent-file") + configFileLoader, err := NewConfigFileLoader(configKey, "non-existent-file") require.NoError(t, err) _, err = Load( ctx, @@ -279,9 +279,9 @@ server: // Call LoadConfig with the temp file ctx := t.Context() - envLoader, err := NewEnvironmentValueLoader("test", nil) + envLoader, err := NewEnvironmentValueLoader(configKey, nil) require.NoError(t, err) - configFileLoader, err := NewConfigFileLoader("test", tempFile.Name()) + configFileLoader, err := NewConfigFileLoader(configKey, tempFile.Name()) require.NoError(t, err) config, err := Load( ctx, @@ -310,3 +310,339 @@ server: assert.Equal(t, "abc", config.Services["service_a"]["value1"]) assert.Equal(t, "def", config.Services["service_a"]["value2"]) } + +// TestLoad_Precedence is a matrix test that verifies the loading order +// and precedence of different configuration sources. +func TestLoad_Precedence(t *testing.T) { + // Helper to create a temp config file for tests + newTempConfigFile := func(t *testing.T, content string) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "config-*.yaml") + require.NoError(t, err) + _, err = f.WriteString(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + return f.Name() + } + + testCases := []struct { + name string + setupLoaders func(t *testing.T, configFile string) []Loader + envVars map[string]string + err error + fileContent string + asserts func(t *testing.T, cfg *Config) + }{ + { + name: "defaults only", + setupLoaders: func(t *testing.T, _ string) []Loader { + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + return []Loader{defaults} + }, + asserts: func(t *testing.T, cfg *Config) { + // Assert values from `default` struct tags + assert.Equal(t, []string{"all"}, cfg.Mode) + assert.Equal(t, "info", cfg.Logger.Level) + assert.Equal(t, 8080, cfg.Server.Port) + }, + }, + { + name: "file overrides defaults", + setupLoaders: func(t *testing.T, configFile string) []Loader { + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + // File loader comes first, so it has higher priority + return []Loader{file, defaults} + }, + fileContent: ` +server: + port: 9090 +logger: + level: warn +`, + asserts: func(t *testing.T, cfg *Config) { + // Values from file + assert.Equal(t, 9090, cfg.Server.Port) + assert.Equal(t, "warn", cfg.Logger.Level) + // Value from defaults + assert.Equal(t, []string{"all"}, cfg.Mode) + }, + }, + { + name: "file with extras and defaults", + setupLoaders: func(t *testing.T, configFile string) []Loader { + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + // File loader comes first, so it has higher priority + return []Loader{file, defaults} + }, + fileContent: ` +server: + port: 9090 + public_hostname: "test.host" +logger: + level: warn +special_key: + nested: + special_value: 123 +`, + asserts: func(t *testing.T, cfg *Config) { + // Values from file + assert.Equal(t, "test.host", cfg.Server.PublicHostname) + assert.Equal(t, 9090, cfg.Server.Port) + assert.Equal(t, "warn", cfg.Logger.Level) + // Value from defaults + assert.Equal(t, []string{"all"}, cfg.Mode) + }, + }, + { + name: "env overrides file and defaults except client_id", + setupLoaders: func(t *testing.T, configFile string) []Loader { + envLoader, err := NewEnvironmentValueLoader(configKey, nil) + require.NoError(t, err) + + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + // Order: env > file > defaults + return []Loader{envLoader, file, defaults} + }, + envVars: map[string]string{ + "TEST_SERVER_PORT": "9999", + "TEST_LOGGER_LEVEL": "debug", + "TEST_DB_HOST": "env.host", + "TEST_SDK_CONFIG_CLIENT_ID": "client-from-env", + "TEST_SDK_CONFIG_CLIENT_SECRET": "secret-from-env", + "TEST_SERVICES_FOO_BAR": "baz", + }, + fileContent: ` +server: + port: 9090 +logger: + level: warn +db: + host: file.host +sdk_config: + client_id: client-from-file + client_secret: secret-from-file +`, + asserts: func(t *testing.T, cfg *Config) { + // Values from env + assert.Equal(t, 9999, cfg.Server.Port) + assert.Equal(t, "debug", cfg.Logger.Level) + assert.Equal(t, "env.host", cfg.DB.Host) + + // Different from the LegacyLoader below + assert.Equal(t, "client-from-file", cfg.SDKConfig.ClientID) + assert.Equal(t, "secret-from-file", cfg.SDKConfig.ClientSecret) + + // Value from defaults (not overridden by file or env) + assert.Equal(t, []string{"all"}, cfg.Mode) + + // Value placed into service map in env + // Different from the LegacyLoader below + require.Contains(t, cfg.Services, "foo") + assert.Equal(t, "baz", cfg.Services["foo"]["bar"]) + }, + }, + { + name: "env from legacy overrides file and defaults", + setupLoaders: func(t *testing.T, configFile string) []Loader { + legacyLoader, err := NewLegacyLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + // Order: env > file > defaults + return []Loader{legacyLoader, defaults} + }, + envVars: map[string]string{ + "TEST_SERVER_PORT": "9999", + "TEST_LOGGER_LEVEL": "debug", + "TEST_DB_HOST": "env.host", + "TEST_SDK_CONFIG_CLIENT_ID": "client-from-env", + "TEST_SDK_CONFIG_CLIENT_SECRET": "secret-from-env", + "TEST_SERVICES_FOO_BAR": "baz", + }, + fileContent: ` +server: + port: 9090 +logger: + level: warn +db: + host: file.host +sdk_config: + client_id: client-from-file + client_secret: secret-from-file +`, + asserts: func(t *testing.T, cfg *Config) { + // Values from env + assert.Equal(t, 9999, cfg.Server.Port) + assert.Equal(t, "debug", cfg.Logger.Level) + assert.Equal(t, "env.host", cfg.DB.Host) + + // Different from the EnvironmentValueLoader above + assert.Equal(t, "client-from-env", cfg.SDKConfig.ClientID) + assert.Equal(t, "secret-from-env", cfg.SDKConfig.ClientSecret) + + // Value from defaults (not overridden by file or env) + assert.Equal(t, []string{"all"}, cfg.Mode) + + // Value not placed service map in env + // Different from the EnvironmentValueLoader above + require.NotContains(t, cfg.Services, "foo") + }, + }, + { + name: "env does not override undefined snake-case YAML keys", + setupLoaders: func(t *testing.T, configFile string) []Loader { + envLoader, err := NewEnvironmentValueLoader(configKey, nil) + require.NoError(t, err) + + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + // Order: env > file > defaults + return []Loader{envLoader, file, defaults} + }, + envVars: map[string]string{ + "TEST_SDK_CONFIG_CLIENT_ID": "client-from-env", + "TEST_SDK_CONFIG_CLIENT_SECRET": "secret-from-env", + }, + fileContent: ` +server: + port: 9090 +logger: + level: warn +db: + host: file.host +`, + asserts: func(t *testing.T, cfg *Config) { + // Same as the LegacyLoader below + assert.Empty(t, cfg.SDKConfig.ClientID) + assert.Empty(t, cfg.SDKConfig.ClientSecret) + }, + }, + { + name: "env from legacy does not override undefined snake-case YAML keys", + setupLoaders: func(t *testing.T, configFile string) []Loader { + legacyLoader, err := NewLegacyLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + // Order: env > file > defaults + return []Loader{legacyLoader, defaults} + }, + envVars: map[string]string{ + "TEST_SDK_CONFIG_CLIENT_ID": "client-from-env", + "TEST_SDK_CONFIG_CLIENT_SECRET": "secret-from-env", + }, + fileContent: ` +server: + port: 9090 +logger: + level: warn +db: + host: file.host +`, + asserts: func(t *testing.T, cfg *Config) { + // Same as the EnvironmentValueLoader above + assert.Empty(t, cfg.SDKConfig.ClientID) + assert.Empty(t, cfg.SDKConfig.ClientSecret) + }, + }, + { + name: "env with allow list allows key", + envVars: map[string]string{ + "TEST_SERVER_PORT": "9999", // This should be loaded + }, + setupLoaders: func(t *testing.T, configFile string) []Loader { + // Allow list only contains server.port + allowList := []string{"server.port"} + envLoader, err := NewEnvironmentValueLoader(configKey, allowList) + require.NoError(t, err) + + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + return []Loader{envLoader, file} + }, + fileContent: ` +server: + port: 8888 +`, + asserts: func(t *testing.T, cfg *Config) { + // The allowed env var should override the file value. + assert.Equal(t, 9999, cfg.Server.Port) + }, + }, + { + name: "env with allow list blocks key", + envVars: map[string]string{ + "TEST_SERVER_PORT": "9999", // This should be BLOCKED + "TEST_LOGGER_LEVEL": "debug", // This should be ALLOWED + }, + setupLoaders: func(t *testing.T, configFile string) []Loader { + // Allow list does NOT contain server.port + allowList := []string{"logger.level"} + envLoader, err := NewEnvironmentValueLoader(configKey, allowList) + require.NoError(t, err) + + file, err := NewConfigFileLoader(configKey, configFile) + require.NoError(t, err) + defaults, err := NewDefaultSettingsLoader() + require.NoError(t, err) + return []Loader{envLoader, file, defaults} + }, + fileContent: ` +server: + port: 8888 +logger: + level: info +`, + asserts: func(t *testing.T, cfg *Config) { + // The server.port env var was blocked, so the value from the file takes precedence. + assert.Equal(t, 8888, cfg.Server.Port) + // The logger.level env var was allowed, so it overrides the file value. + assert.Equal(t, "debug", cfg.Logger.Level) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup env vars + for k, v := range tc.envVars { + t.Setenv(k, v) + } + + // Setup config file + configFile := "" + if tc.fileContent != "" { + configFile = newTempConfigFile(t, tc.fileContent) + } + + // Setup loaders + loaders := tc.setupLoaders(t, configFile) + + // Load config + cfg, err := Load(t.Context(), loaders...) + + // Assertions + if tc.err != nil { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, cfg) + } + if tc.asserts != nil { + tc.asserts(t, cfg) + } + }) + } +} diff --git a/service/pkg/config/environment_value_loader.go b/service/pkg/config/environment_value_loader.go index ef57c54bf..d2821362c 100644 --- a/service/pkg/config/environment_value_loader.go +++ b/service/pkg/config/environment_value_loader.go @@ -3,10 +3,8 @@ package config import ( "context" "fmt" - "log/slog" + "os" "strings" - - "github.com/spf13/viper" ) const LoaderNameEnvironmentValue = "environment-value" @@ -14,21 +12,15 @@ const LoaderNameEnvironmentValue = "environment-value" // EnvironmentValueLoader implements Loader using Viper type EnvironmentValueLoader struct { allowListMap map[string]struct{} - viper *viper.Viper + envKeyPrefix string + loadedKeys []string + loadedValues map[string]string } // NewEnvironmentValueLoader creates a new Viper-based configuration loader // to load from environment variables, from a default or specified file // (or k8s config map), or some combination func NewEnvironmentValueLoader(key string, allowList []string) (*EnvironmentValueLoader, error) { - // Set paths and config file info - v := viper.NewWithOptions(viper.WithLogger(slog.Default())) - - // Environment variable settings - v.SetEnvPrefix(key) - v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - v.AutomaticEnv() - var allowListMap map[string]struct{} if allowList != nil || len(allowList) > 0 { allowListMap = make(map[string]struct{}) @@ -39,7 +31,7 @@ func NewEnvironmentValueLoader(key string, allowList []string) (*EnvironmentValu result := &EnvironmentValueLoader{ allowListMap: allowListMap, - viper: v, + envKeyPrefix: strings.ToUpper(key) + "_", } return result, nil } @@ -51,17 +43,54 @@ func (l *EnvironmentValueLoader) Get(key string) (any, error) { return nil, fmt.Errorf("environment value %s is not allowed", key) } } - return l.viper.Get(key), nil + value, found := l.loadedValues[key] + if found { + return value, nil + } + return nil, nil //nolint:nilnil // Not an error, value doesn't exist } // GetConfigKeys returns all the configuration keys found in the environment variables. func (l *EnvironmentValueLoader) GetConfigKeys() ([]string, error) { - return l.viper.AllKeys(), nil + return l.loadedKeys, nil } // Load loads the configuration into the provided struct func (l *EnvironmentValueLoader) Load(_ Config) error { - // For environment variables, Viper's `AutomaticEnv` handles this, so no explicit load is needed here. + var loadedKeys []string + loadedValues := make(map[string]string) + + env := os.Environ() + for _, kv := range env { + upperKV := strings.ToUpper(kv) + if strings.HasPrefix(upperKV, l.envKeyPrefix) { + eqIdx := strings.Index(upperKV, "=") + if eqIdx == -1 { + continue + } + envKey := kv[len(l.envKeyPrefix):eqIdx] + envValue := kv[eqIdx+1:] + dottedKey := strings.ToLower(strings.ReplaceAll(envKey, "_", ".")) + if l.allowListMap != nil { + if _, keyInAllowList := l.allowListMap[dottedKey]; !keyInAllowList { + // This key is not allowed, skip it + continue + } + } + + loadedKeys = append(loadedKeys, dottedKey) + loadedValues[dottedKey] = envValue + } + } + + if len(loadedKeys) > 0 { + l.loadedKeys = loadedKeys + l.loadedValues = loadedValues + } else { + l.loadedKeys = nil + l.loadedValues = nil + } + return nil } diff --git a/service/pkg/config/legacy_loader.go b/service/pkg/config/legacy_loader.go new file mode 100644 index 000000000..6c5f651a4 --- /dev/null +++ b/service/pkg/config/legacy_loader.go @@ -0,0 +1,99 @@ +package config + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "strings" + + "github.com/fsnotify/fsnotify" + "github.com/spf13/viper" +) + +const LoaderNameLegacy = "legacy" + +// LegacyLoader enables loading values from a YAML file and the environment together +type LegacyLoader struct { + viper *viper.Viper +} + +// NewLegacyLoader creates a new Viper-based configuration loader +// to load from a default or specified file. +func NewLegacyLoader(key, file string) (*LegacyLoader, error) { + homedir, err := os.UserHomeDir() + if err != nil { + return nil, errors.Join(err, ErrLoadingConfig) + } + + // Set paths and config file info + v := viper.NewWithOptions(viper.WithLogger(slog.Default())) + v.AddConfigPath(fmt.Sprintf("%s/."+key, homedir)) + v.AddConfigPath("." + key) + v.AddConfigPath(".") + v.SetConfigName(key) + v.SetConfigType("yaml") + + // Default config values (non-zero) + v.SetDefault("server.auth.cache_refresh_interval", "15m") + + // Environment variable settings + v.SetEnvPrefix(key) + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.AutomaticEnv() + + // Allow for a custom config file to be passed in + // This takes precedence over the AddConfigPath/SetConfigName + if file != "" { + v.SetConfigFile(file) + } + + return &LegacyLoader{viper: v}, nil +} + +// Get fetches a particular config value by dot-delimited key from the source +func (l *LegacyLoader) Get(key string) (any, error) { + return l.viper.Get(key), nil +} + +// GetConfigKeys returns all the configuration keys found in the config file. +func (l *LegacyLoader) GetConfigKeys() ([]string, error) { + return l.viper.AllKeys(), nil +} + +// Load is called to load/refresh the configuration from its source +func (l *LegacyLoader) Load(cfg Config) error { + // Read the config file + if err := l.viper.ReadInConfig(); err != nil { + return errors.Join(err, ErrLoadingConfig) + } + + err := l.viper.Unmarshal(&cfg) + + return err +} + +// Watch starts watching the config file for configuration changes +func (l *LegacyLoader) Watch(ctx context.Context, _ *Config, onChange func(context.Context) error) error { + l.viper.WatchConfig() + + // If config changes, trigger the main config reload function + l.viper.OnConfigChange(func(e fsnotify.Event) { + slog.DebugContext(ctx, "config file changed, triggering reload", slog.String("file", e.Name)) + + if err := onChange(ctx); err != nil { + slog.ErrorContext(ctx, "error processing config file change", slog.Any("error", err)) + } + }) + + return nil +} + +func (l *LegacyLoader) Name() string { + return LoaderNameLegacy +} + +func (l *LegacyLoader) Close() error { + return nil +} diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 8d5adc8f4..5539b45c8 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -47,8 +47,7 @@ func Start(f ...StartOptions) error { slog.Debug("loading configuration from environment") loaderOrder := []string{ - config.LoaderNameEnvironmentValue, - config.LoaderNameFile, + config.LoaderNameLegacy, config.LoaderNameDefaultSettings, } if startConfig.configLoaderOrder != nil { @@ -85,6 +84,11 @@ func Start(f ...StartOptions) error { if err != nil { return err } + case config.LoaderNameLegacy: + loader, err = config.NewLegacyLoader(startConfig.ConfigKey, startConfig.ConfigFile) + if err != nil { + return err + } default: mappedLoader, ok := additionalLoaderMap[loaderName] if !ok {