diff --git a/filterapi/filterconfig.go b/filterapi/filterconfig.go
index f4e64d74de..da5201c3f1 100644
--- a/filterapi/filterconfig.go
+++ b/filterapi/filterconfig.go
@@ -186,14 +186,24 @@ type APIKeyAuth struct {
 }
 
 // UnmarshalConfigYaml reads the file at the given path and unmarshals it into a Config struct.
-func UnmarshalConfigYaml(path string) (*Config, error) {
+func UnmarshalConfigYaml(path string) (*Config, []byte, error) {
 	raw, err := os.ReadFile(path)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 	var cfg Config
 	if err := yaml.Unmarshal(raw, &cfg); err != nil {
-		return nil, err
+		return nil, nil, err
 	}
-	return &cfg, nil
+	return &cfg, raw, nil
+}
+
+// MustLoadDefaultConfig unmarshals DefaultConfig and returns it along with its raw YAML bytes.
+// It panics if DefaultConfig cannot be unmarshalled.
+func MustLoadDefaultConfig() (*Config, []byte) {
+	var cfg Config
+	if err := yaml.Unmarshal([]byte(DefaultConfig), &cfg); err != nil {
+		panic(err)
+	}
+	return &cfg, []byte(DefaultConfig)
 }
diff --git a/filterapi/filterconfig_test.go b/filterapi/filterconfig_test.go
index d0cd64a5da..1f71285bce 100644
--- a/filterapi/filterconfig_test.go
+++ b/filterapi/filterconfig_test.go
@@ -12,7 +12,6 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/require"
-	"k8s.io/apimachinery/pkg/util/yaml"
 
 	"github.com/envoyproxy/ai-gateway/filterapi"
 	"github.com/envoyproxy/ai-gateway/internal/extproc"
@@ -23,11 +22,15 @@ func TestDefaultConfig(t *testing.T) {
 	require.NoError(t, err)
 	require.NotNil(t, server)
 
-	var cfg filterapi.Config
-	err = yaml.Unmarshal([]byte(filterapi.DefaultConfig), &cfg)
-	require.NoError(t, err)
+	cfg, raw := filterapi.MustLoadDefaultConfig()
+	require.Equal(t, []byte(filterapi.DefaultConfig), raw)
+	require.Equal(t, &filterapi.Config{
+		Schema:                   filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI},
+		SelectedBackendHeaderKey: "x-ai-eg-selected-backend",
+		ModelNameHeaderKey:       "x-ai-eg-model",
+	}, cfg)
 
-	err = server.LoadConfig(t.Context(), &cfg)
+	err = server.LoadConfig(t.Context(), cfg)
 	require.NoError(t, err)
 }
 
@@ -71,8 +74,9 @@ rules:
     value: gpt4.4444
 `
 	require.NoError(t, os.WriteFile(configPath, []byte(config), 0o600))
-	cfg, err := filterapi.UnmarshalConfigYaml(configPath)
+	cfg, raw, err := filterapi.UnmarshalConfigYaml(configPath)
 	require.NoError(t, err)
+	require.Equal(t, []byte(config), raw)
 	require.Equal(t, "ai_gateway_llm_ns", cfg.MetadataNamespace)
 	require.Equal(t, "token_usage_key", cfg.LLMRequestCosts[0].MetadataKey)
 	require.Equal(t, "OutputToken", string(cfg.LLMRequestCosts[0].Type))
@@ -92,13 +96,14 @@ rules:
 	require.Equal(t, "us-east-1", cfg.Rules[0].Backends[1].Auth.AWSAuth.Region)
 
 	t.Run("not found", func(t *testing.T) {
-		_, err := filterapi.UnmarshalConfigYaml("not-found.yaml")
+		_, _, err := filterapi.UnmarshalConfigYaml("not-found.yaml")
 		require.Error(t, err)
+		require.True(t, os.IsNotExist(err))
 	})
 	t.Run("invalid", func(t *testing.T) {
 		const invalidConfig = `{wefaf3q20,9u,f02`
 		require.NoError(t, os.WriteFile(configPath, []byte(invalidConfig), 0o600))
-		_, err := filterapi.UnmarshalConfigYaml(configPath)
+		_, _, err := filterapi.UnmarshalConfigYaml(configPath)
 		require.Error(t, err)
 	})
 }
diff --git a/internal/extproc/watcher.go b/internal/extproc/watcher.go
index f6bba43dd3..e3cd796d15 100644
--- a/internal/extproc/watcher.go
+++ b/internal/extproc/watcher.go
@@ -65,45 +65,47 @@ func (cw *configWatcher) watch(ctx context.Context, tick time.Duration) {
 // loadConfig loads a new config from the given path and updates the Receiver by
 // calling the [Receiver.Load].
func (cw *configWatcher) loadConfig(ctx context.Context) error { + var ( + cfg *filterapi.Config + raw []byte + ) + stat, err := os.Stat(cw.path) - if err != nil { + switch { + case err != nil && os.IsNotExist(err): + // If the file does not exist, do not fail (which could lead to the extproc process to terminate) + // Instead, load the default configuration and keep running unconfigured + cfg, raw = filterapi.MustLoadDefaultConfig() + case err != nil: return err } - if stat.ModTime().Sub(cw.lastMod) <= 0 { - return nil + + if cfg != nil { + cw.l.Info("config file does not exist; loading default config", slog.String("path", cw.path)) + cw.lastMod = time.Now() + } else { + cw.l.Info("loading a new config", slog.String("path", cw.path)) + if stat.ModTime().Sub(cw.lastMod) <= 0 { + return nil + } + cw.lastMod = stat.ModTime() + cfg, raw, err = filterapi.UnmarshalConfigYaml(cw.path) + if err != nil { + return err + } } - cw.lastMod = stat.ModTime() - cw.l.Info("loading a new config", slog.String("path", cw.path)) // Print the diff between the old and new config. if cw.l.Enabled(ctx, slog.LevelDebug) { // Re-hydrate the current config file for later diffing. previous := cw.current - cw.current, err = cw.getConfigString() - if err != nil { - return fmt.Errorf("failed to read the config file: %w", err) - } - + cw.current = string(raw) cw.diff(previous, cw.current) } - cfg, err := filterapi.UnmarshalConfigYaml(cw.path) - if err != nil { - return err - } return cw.rcv.LoadConfig(ctx, cfg) } -// getConfigString gets a string representation of the current config -// read from the path. This is only used for debug log path for diff prints. 
-func (cw *configWatcher) getConfigString() (string, error) {
-	currentByte, err := os.ReadFile(cw.path)
-	if err != nil {
-		return "", err
-	}
-	return string(currentByte), nil
-}
-
 func (cw *configWatcher) diff(oldConfig, newConfig string) {
 	if oldConfig == "" {
 		return
diff --git a/internal/extproc/watcher_test.go b/internal/extproc/watcher_test.go
index 0a7d742a71..27f919166a 100644
--- a/internal/extproc/watcher_test.go
+++ b/internal/extproc/watcher_test.go
@@ -8,6 +8,7 @@ package extproc
 import (
 	"bytes"
 	"context"
+	"io"
 	"log/slog"
 	"os"
 	"strings"
@@ -15,6 +16,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
 	"github.com/envoyproxy/ai-gateway/filterapi"
@@ -40,9 +42,30 @@ func (m *mockReceiver) getConfig() *filterapi.Config {
 	return m.cfg
 }
 
+var _ io.Writer = (*syncBuffer)(nil)
+
+// syncBuffer wraps a bytes.Buffer with a mutex so it is safe for concurrent reads and writes.
+// It is used only in tests, to read the log output in assertions without data races.
+type syncBuffer struct {
+	mu sync.RWMutex
+	b  *bytes.Buffer
+}
+
+func (s *syncBuffer) Write(p []byte) (n int, err error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.b.Write(p)
+}
+
+func (s *syncBuffer) String() string {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	return s.b.String()
+}
+
 // newTestLoggerWithBuffer creates a new logger with a buffer for testing and asserting the output.
-func newTestLoggerWithBuffer() (*slog.Logger, *bytes.Buffer) {
-	buf := &bytes.Buffer{}
+func newTestLoggerWithBuffer() (*slog.Logger, *syncBuffer) {
+	buf := &syncBuffer{b: &bytes.Buffer{}}
 	logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{
 		Level: slog.LevelDebug,
 	}))
@@ -54,7 +77,22 @@ func TestStartConfigWatcher(t *testing.T) {
 	path := tmpdir + "/config.yaml"
 	rcv := &mockReceiver{}
 
-	require.NoError(t, os.WriteFile(path, []byte{}, 0o600))
+	logger, buf := newTestLoggerWithBuffer()
+	err := StartConfigWatcher(t.Context(), path, rcv, logger, time.Millisecond*100)
+	require.NoError(t, err)
+
+	defaultCfg, _ := filterapi.MustLoadDefaultConfig()
+	// MustLoadDefaultConfig panics on failure, so there is no error to check here.
+
+	// Verify the default config has been loaded.
+	require.EventuallyWithT(t, func(c *assert.CollectT) {
+		assert.Equal(c, defaultCfg, rcv.getConfig())
+	}, 1*time.Second, 100*time.Millisecond)
+
+	// Verify the buffer contains the default config loading.
+	require.Eventually(t, func() bool {
+		return strings.Contains(buf.String(), "config file does not exist; loading default config")
+	}, 1*time.Second, 100*time.Millisecond, buf.String())
 
 	// Create the initial config file.
 	cfg := `
@@ -84,15 +122,10 @@ rules:
     value: gpt4.4444
 `
 	require.NoError(t, os.WriteFile(path, []byte(cfg), 0o600))
-	ctx, cancel := context.WithCancel(t.Context())
-	defer cancel()
-	logger, buf := newTestLoggerWithBuffer()
-	err := StartConfigWatcher(ctx, path, rcv, logger, time.Millisecond*100)
-	require.NoError(t, err)
 
 	// Initial loading should have happened.
 	require.Eventually(t, func() bool {
-		return rcv.getConfig() != nil
+		return !assert.ObjectsAreEqual(defaultCfg, rcv.getConfig())
 	}, 1*time.Second, 100*time.Millisecond)
 	firstCfg := rcv.getConfig()
 	require.NotNil(t, firstCfg)