diff --git a/config.example.yaml b/config.example.yaml index e0d61830..9814b050 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -90,6 +90,9 @@ apiKeys: # - macro names must not be a reserved name: PORT or MODEL_ID # - macro values can be numbers, bools, or strings # - macros can contain other macros, but they must be defined before they are used +# - environment variables can be referenced with ${env.VAR_NAME} syntax +# - env macros are substituted first, before regular macros +# - if the env var is not set, config loading will fail with an error macros: # Example of a multi-line macro "latest-llama": > @@ -102,6 +105,11 @@ macros: # but they must be previously declared. "default_args": "--ctx-size ${default_ctx}" + # Example of environment variable macros + # - ${env.VAR_NAME} pulls the value from the system environment + # - useful for paths, secrets, or machine-specific configuration + "models_dir": "${env.HOME}/models" + # models: a dictionary of model configurations # - required # - each key is the model's ID, used in API requests diff --git a/proxy/config/config.go b/proxy/config/config.go index 078d27fd..c1af8654 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -87,6 +87,7 @@ type GroupConfig struct { var ( macroNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) macroPatternRegex = regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`) + envMacroRegex = regexp.MustCompile(`\$\{env\.([a-zA-Z_][a-zA-Z0-9_]*)\}`) ) // set default values for GroupConfig @@ -237,6 +238,17 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } + // Process environment variable macros in global macro values first + for i, macro := range config.Macros { + if strVal, ok := macro.Value.(string); ok { + newVal, err := substituteEnvMacros(strVal) + if err != nil { + return Config{}, fmt.Errorf("global macro '%s': %w", macro.Name, err) + } + config.Macros[i].Value = newVal + } + } + // Get and sort all model IDs first, makes testing more consistent modelIds := make([]string, 0, len(config.Models)) for modelId := range config.Models { @@ -252,6 +264,48 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { modelConfig.Cmd = StripComments(modelConfig.Cmd) modelConfig.CmdStop = StripComments(modelConfig.CmdStop) + // Substitute environment variable macros in model fields + modelConfig.Cmd, err = substituteEnvMacros(modelConfig.Cmd) + if err != nil { + return Config{}, fmt.Errorf("model %s cmd: %w", modelId, err) + } + modelConfig.CmdStop, err = substituteEnvMacros(modelConfig.CmdStop) + if err != nil { + return Config{}, fmt.Errorf("model %s cmdStop: %w", modelId, err) + } + modelConfig.Proxy, err = substituteEnvMacros(modelConfig.Proxy) + if err != nil { + return Config{}, fmt.Errorf("model %s proxy: %w", modelId, err) + } + modelConfig.CheckEndpoint, err = substituteEnvMacros(modelConfig.CheckEndpoint) + if err != nil { + return Config{}, fmt.Errorf("model %s checkEndpoint: %w", modelId, err) + } + modelConfig.Filters.StripParams, err = substituteEnvMacros(modelConfig.Filters.StripParams) + if err != nil { + return Config{}, fmt.Errorf("model %s filters.stripParams: %w", modelId, err) + } + + // Substitute env macros in model-level macro values + for i, macro := range modelConfig.Macros { + if strVal, ok := macro.Value.(string); ok { + newVal, err := substituteEnvMacros(strVal) + if err != nil { + return Config{}, fmt.Errorf("model %s macro '%s': %w", modelId, macro.Name, err) + } + modelConfig.Macros[i].Value = newVal + } + } + + // Substitute env macros in metadata + if len(modelConfig.Metadata) > 0 { + result, err := substituteEnvMacrosInValue(modelConfig.Metadata) + if err != nil { + return Config{}, fmt.Errorf("model %s metadata: %w", modelId, err) + } + modelConfig.Metadata = result.(map[string]any) + } + // validate model macros for _, macro := range modelConfig.Macros { if err = validateMacro(macro.Name, macro.Value); err != nil { @@ -362,6 +416,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { // Any other macro is unknown return Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName) } + + // Check for unsubstituted env macros + envMatches := envMacroRegex.FindAllStringSubmatch(fieldValue, -1) + for _, match := range envMatches { + varName := match[1] + return Config{}, fmt.Errorf("environment variable '%s' not set (found in %s.%s)", varName, modelId, fieldName) + } } // Check for unknown macros in metadata @@ -574,6 +635,12 @@ func validateMetadataForUnknownMacros(value any, modelId string) error { macroName := match[1] return fmt.Errorf("model %s metadata: unknown macro '${%s}'", modelId, macroName) } + // Check for unsubstituted env macros + envMatches := envMacroRegex.FindAllStringSubmatch(v, -1) + for _, match := range envMatches { + varName := match[1] + return fmt.Errorf("model %s metadata: environment variable '%s' not set", modelId, varName) + } return nil case map[string]any: @@ -645,3 +712,54 @@ func substituteMacroInValue(value any, macroName string, macroValue any) (any, e return value, nil } } + +// substituteEnvMacros replaces ${env.VAR_NAME} with environment variable values +// Returns error if any env var is not set +func substituteEnvMacros(s string) (string, error) { + result := s + matches := envMacroRegex.FindAllStringSubmatch(s, -1) + for _, match := range matches { + fullMatch := match[0] // ${env.VAR_NAME} + varName := match[1] // VAR_NAME + + value, exists := os.LookupEnv(varName) + if !exists { + return "", fmt.Errorf("environment variable '%s' is not set", varName) + } + result = strings.ReplaceAll(result, fullMatch, value) + } + return result, nil +} + +// substituteEnvMacrosInValue recursively substitutes env macros in nested structures +func substituteEnvMacrosInValue(value any) (any, error) { + switch v := value.(type) { + case string: + return substituteEnvMacros(v) + + case map[string]any: + newMap := make(map[string]any) + for key, val := range v { + newVal, err := substituteEnvMacrosInValue(val) + if err != nil { + return nil, err + } + newMap[key] = newVal + } + return newMap, nil + + case []any: + newSlice := make([]any, len(v)) + for i, val := range v { + newVal, err := substituteEnvMacrosInValue(val) + if err != nil { + return nil, err + } + newSlice[i] = newVal + } + return newSlice, nil + + default: + return value, nil + } +} diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go index ab358e66..855cef50 100644 --- a/proxy/config/config_test.go +++ b/proxy/config/config_test.go @@ -809,3 +809,213 @@ func TestConfig_APIKeys_Invalid(t *testing.T) { }) } } + +func TestConfig_EnvMacros(t *testing.T) { + t.Run("basic env substitution in cmd", func(t *testing.T) { + t.Setenv("TEST_MODEL_PATH", "/opt/models") + + content := ` +models: + test: + cmd: "${env.TEST_MODEL_PATH}/llama-server" + proxy: "http://localhost:8080" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "/opt/models/llama-server", config.Models["test"].Cmd) + }) + + t.Run("env substitution in multiple fields", func(t *testing.T) { + t.Setenv("TEST_HOST", "myserver") + t.Setenv("TEST_PORT", "9999") + + content := ` +models: + test: + cmd: "server --host ${env.TEST_HOST}" + proxy: "http://${env.TEST_HOST}:${env.TEST_PORT}" + checkEndpoint: "http://${env.TEST_HOST}/health" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "server --host myserver", config.Models["test"].Cmd) + assert.Equal(t, "http://myserver:9999", config.Models["test"].Proxy) + assert.Equal(t, "http://myserver/health", config.Models["test"].CheckEndpoint) + }) + + t.Run("env in global macro value", func(t *testing.T) { + t.Setenv("TEST_BASE_PATH", "/usr/local") + + content := ` +macros: + SERVER_PATH: "${env.TEST_BASE_PATH}/bin/server" +models: + test: + cmd: "${SERVER_PATH} --port 8080" + proxy: "http://localhost:8080" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "/usr/local/bin/server --port 8080", config.Models["test"].Cmd) + }) + + t.Run("env in model-level macro value", func(t *testing.T) { + t.Setenv("TEST_MODEL_DIR", "/models/llama") + + content := ` +models: + test: + macros: + MODEL_FILE: "${env.TEST_MODEL_DIR}/model.gguf" + cmd: "server --model ${MODEL_FILE}" + proxy: "http://localhost:8080" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "server --model /models/llama/model.gguf", config.Models["test"].Cmd) + }) + + t.Run("env in metadata", func(t *testing.T) { + t.Setenv("TEST_API_KEY", "secret123") + + content := ` +models: + test: + cmd: "server" + proxy: "http://localhost:8080" + metadata: + api_key: "${env.TEST_API_KEY}" + nested: + key: "${env.TEST_API_KEY}" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "secret123", config.Models["test"].Metadata["api_key"]) + nested := config.Models["test"].Metadata["nested"].(map[string]any) + assert.Equal(t, "secret123", nested["key"]) + }) + + t.Run("env in filters.stripParams", func(t *testing.T) { + t.Setenv("TEST_STRIP_PARAMS", "temperature,top_p") + + content := ` +models: + test: + cmd: "server" + proxy: "http://localhost:8080" + filters: + stripParams: "${env.TEST_STRIP_PARAMS}" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "temperature,top_p", config.Models["test"].Filters.StripParams) + }) + + t.Run("env in cmdStop", func(t *testing.T) { + t.Setenv("TEST_KILL_SIGNAL", "SIGTERM") + + content := ` +models: + test: + cmd: "server --port ${PORT}" + cmdStop: "kill -${env.TEST_KILL_SIGNAL} ${PID}" + proxy: "http://localhost:${PORT}" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Contains(t, config.Models["test"].CmdStop, "-SIGTERM") + }) + + t.Run("missing env var returns error", func(t *testing.T) { + content := ` +models: + test: + cmd: "${env.UNDEFINED_VAR_12345}/server" + proxy: "http://localhost:8080" +` + _, err := LoadConfigFromReader(strings.NewReader(content)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "UNDEFINED_VAR_12345") + assert.Contains(t, err.Error(), "not set") + } + }) + + t.Run("missing env var in global macro", func(t *testing.T) { + content := ` +macros: + PATH: "${env.UNDEFINED_GLOBAL_VAR}" +models: + test: + cmd: "server" + proxy: "http://localhost:8080" +` + _, err := LoadConfigFromReader(strings.NewReader(content)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "UNDEFINED_GLOBAL_VAR") + assert.Contains(t, err.Error(), "not set") + } + }) + + t.Run("missing env var in model macro", func(t *testing.T) { + content := ` +models: + test: + macros: + MY_PATH: "${env.UNDEFINED_MODEL_VAR}" + cmd: "server" + proxy: "http://localhost:8080" +` + _, err := LoadConfigFromReader(strings.NewReader(content)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "UNDEFINED_MODEL_VAR") + assert.Contains(t, err.Error(), "not set") + } + }) + + t.Run("missing env var in metadata", func(t *testing.T) { + content := ` +models: + test: + cmd: "server" + proxy: "http://localhost:8080" + metadata: + key: "${env.UNDEFINED_META_VAR}" +` + _, err := LoadConfigFromReader(strings.NewReader(content)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "UNDEFINED_META_VAR") + assert.Contains(t, err.Error(), "not set") + } + }) + + t.Run("env combined with regular macros", func(t *testing.T) { + t.Setenv("TEST_ROOT", "/data") + + content := ` +macros: + MODEL_BASE: "${env.TEST_ROOT}/models" +models: + test: + cmd: "server --model ${MODEL_BASE}/${MODEL_ID}.gguf" + proxy: "http://localhost:8080" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "server --model /data/models/test.gguf", config.Models["test"].Cmd) + }) + + t.Run("multiple env vars in same string", func(t *testing.T) { + t.Setenv("TEST_USER", "admin") + t.Setenv("TEST_PASS", "secret") + + content := ` +models: + test: + cmd: "server --auth ${env.TEST_USER}:${env.TEST_PASS}" + proxy: "http://localhost:8080" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + assert.Equal(t, "server --auth admin:secret", config.Models["test"].Cmd) + }) +}