diff --git a/Makefile b/Makefile index 8e01898b..0bce0882 100644 --- a/Makefile +++ b/Makefile @@ -24,10 +24,11 @@ proxy/ui_dist/placeholder.txt: touch $@ test: proxy/ui_dist/placeholder.txt - go test -short -v -count=1 ./proxy + go test -short ./proxy/... +# for CI - full test (takes longer) test-all: proxy/ui_dist/placeholder.txt - go test -v -count=1 ./proxy + go test -count=1 ./proxy/... ui/node_modules: cd ui && npm install @@ -81,4 +82,4 @@ release: git tag "$$new_tag"; # Phony targets -.PHONY: all clean ui mac linux windows simple-responder +.PHONY: all clean ui mac linux windows simple-responder test test-all diff --git a/llama-swap.go b/llama-swap.go index 3bdcef13..bf93a41a 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -16,6 +16,7 @@ import ( "github.com/gin-gonic/gin" "github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/proxy" + "github.com/mostlygeek/llama-swap/proxy/config" ) var ( @@ -38,13 +39,13 @@ func main() { os.Exit(0) } - config, err := proxy.LoadConfig(*configPath) + conf, err := config.LoadConfig(*configPath) if err != nil { fmt.Printf("Error loading config: %v\n", err) os.Exit(1) } - if len(config.Profiles) > 0 { + if len(conf.Profiles) > 0 { fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. 
See the README for more information.") } @@ -67,7 +68,7 @@ func main() { // Support for watching config and reloading when it changes reloadProxyManager := func() { if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok { - config, err = proxy.LoadConfig(*configPath) + conf, err = config.LoadConfig(*configPath) if err != nil { fmt.Printf("Warning, unable to reload configuration: %v\n", err) return @@ -75,7 +76,7 @@ func main() { fmt.Println("Configuration Changed") currentPM.Shutdown() - srv.Handler = proxy.New(config) + srv.Handler = proxy.New(conf) fmt.Println("Configuration Reloaded") // wait a few seconds and tell any UI to reload @@ -85,12 +86,12 @@ func main() { }) }) } else { - config, err = proxy.LoadConfig(*configPath) + conf, err = config.LoadConfig(*configPath) if err != nil { fmt.Printf("Error, unable to load configuration: %v\n", err) os.Exit(1) } - srv.Handler = proxy.New(config) + srv.Handler = proxy.New(conf) } } diff --git a/proxy/config.go b/proxy/config/config.go similarity index 99% rename from proxy/config.go rename to proxy/config/config.go index 65be8d06..a269b0b7 100644 --- a/proxy/config.go +++ b/proxy/config/config.go @@ -1,4 +1,4 @@ -package proxy +package config import ( "fmt" @@ -154,6 +154,7 @@ type Config struct { Models map[string]ModelConfig `yaml:"models"` /* key is model ID */ Profiles map[string][]string `yaml:"profiles"` Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ + Peers map[string]PeerConfig `yaml:"peers"` /* key is peer ID */ // for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint Macros map[string]string `yaml:"macros"` diff --git a/proxy/config_posix_test.go b/proxy/config/config_posix_test.go similarity index 92% rename from proxy/config_posix_test.go rename to proxy/config/config_posix_test.go index 122d3511..05a0fea3 100644 --- a/proxy/config_posix_test.go +++ b/proxy/config/config_posix_test.go @@ -1,10 +1,11 @@ //go:build !windows -package proxy +package config import ( 
"os" "path/filepath" + "regexp" "strings" "testing" @@ -148,6 +149,14 @@ groups: persistent: true members: - "model4" +peers: + desktop: + name: "Desktop" + description: "runs Linux" + baseURL: "http://10.0.4.11:8080" + apikey: "secret-key" + priority: 10 + filters: [] ` if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { @@ -232,6 +241,17 @@ groups: Members: []string{"model4"}, }, }, + Peers: map[string]PeerConfig{ + "desktop": { + Name: "Desktop", + Description: "runs Linux", + BaseURL: "http://10.0.4.11:8080", + ApiKey: "secret-key", + Priority: 10, + Filters: []string{}, + reFilters: []*regexp.Regexp{}, /* leave blank, test in peer_test.go */ + }, + }, } assert.Equal(t, expected, config) diff --git a/proxy/config_test.go b/proxy/config/config_test.go similarity index 99% rename from proxy/config_test.go rename to proxy/config/config_test.go index 505d80b4..0bee0141 100644 --- a/proxy/config_test.go +++ b/proxy/config/config_test.go @@ -1,4 +1,4 @@ -package proxy +package config import ( "slices" diff --git a/proxy/config_windows_test.go b/proxy/config/config_windows_test.go similarity index 98% rename from proxy/config_windows_test.go rename to proxy/config/config_windows_test.go index 6902da02..6a5ab9ed 100644 --- a/proxy/config_windows_test.go +++ b/proxy/config/config_windows_test.go @@ -1,6 +1,6 @@ //go:build windows -package proxy +package config import ( "os" @@ -221,6 +221,7 @@ groups: Members: []string{"model4"}, }, }, + Peers: nil, // empty here, see config_posix_test.go } assert.Equal(t, expected, config) diff --git a/proxy/config/peer.go b/proxy/config/peer.go new file mode 100644 index 00000000..395a8f5b --- /dev/null +++ b/proxy/config/peer.go @@ -0,0 +1,46 @@ +package config + +import ( + "fmt" + "regexp" +) + +type PeerConfig struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + BaseURL string `yaml:"baseURL"` + ApiKey string `yaml:"apikey"` + Priority int `yaml:"priority"` + Filters []string 
`yaml:"filters"` + reFilters []*regexp.Regexp `yaml:"-"` +} + +// set default values for PeerConfig +func (c *PeerConfig) UnmarshalYAML(unmarshal func(any) error) error { + type rawConfig PeerConfig + defaults := rawConfig{ + Name: "", + Description: "", + BaseURL: "", + ApiKey: "", + Priority: 0, + Filters: []string{}, + reFilters: []*regexp.Regexp{}, + } + + if err := unmarshal(&defaults); err != nil { + return err + } + + // compile regex filters and store compiled patterns in reFilters + for _, pat := range defaults.Filters { + r, err := regexp.Compile(pat) + if err != nil { + return fmt.Errorf("failed to compile peer filter %q: %w", pat, err) + } + defaults.reFilters = append(defaults.reFilters, r) + } + + *c = PeerConfig(defaults) + return nil +} diff --git a/proxy/config/peer_test.go b/proxy/config/peer_test.go new file mode 100644 index 00000000..de9a9f11 --- /dev/null +++ b/proxy/config/peer_test.go @@ -0,0 +1,99 @@ +package config + +import ( + "strings" + "testing" + + "gopkg.in/yaml.v3" +) + +// Tests that defaults are set when unmarshaling an empty/minimal YAML. 
+func TestPeerConfig_Defaults(t *testing.T) { + var pc PeerConfig + data := `{}` + + if err := yaml.Unmarshal([]byte(data), &pc); err != nil { + t.Fatalf("unexpected unmarshal error: %v", err) + } + + if pc.Name != "" { + t.Errorf("Name expected %q, got %q", "", pc.Name) + } + if pc.Description != "" { + t.Errorf("Description expected %q, got %q", "", pc.Description) + } + if pc.BaseURL != "" { + t.Errorf("BaseURL expected %q, got %q", "", pc.BaseURL) + } + if pc.ApiKey != "" { + t.Errorf("ApiKey expected %q, got %q", "", pc.ApiKey) + } + if pc.Priority != 0 { + t.Errorf("Priority expected %d, got %d", 0, pc.Priority) + } + if len(pc.Filters) != 0 { + t.Errorf("Filters expected length %d, got %d", 0, len(pc.Filters)) + } + if len(pc.reFilters) != 0 { + t.Errorf("reFilters expected length %d, got %d", 0, len(pc.reFilters)) + } +} + +// Tests that valid regex patterns in Filters are compiled into reFilters and work as expected. +func TestPeerConfig_RegexCompileSuccess(t *testing.T) { + var pc PeerConfig + data := ` +filters: + - "^foo.*" + - "ba[rz]$" +` + + if err := yaml.Unmarshal([]byte(data), &pc); err != nil { + t.Fatalf("unexpected unmarshal error: %v", err) + } + + if len(pc.Filters) != 2 { + t.Fatalf("expected Filters length 2, got %d", len(pc.Filters)) + } + if len(pc.reFilters) != 2 { + t.Fatalf("expected reFilters length 2, got %d", len(pc.reFilters)) + } + + // first pattern ^foo.* + if !pc.reFilters[0].MatchString("foobar") { + t.Errorf("expected pattern %q to match %q", pc.Filters[0], "foobar") + } + if pc.reFilters[0].MatchString("barfoo") { + t.Errorf("expected pattern %q NOT to match %q", pc.Filters[0], "barfoo") + } + + // second pattern ba[rz]$ + if !pc.reFilters[1].MatchString("bar") { + t.Errorf("expected pattern %q to match %q", pc.Filters[1], "bar") + } + if !pc.reFilters[1].MatchString("baz") { + t.Errorf("expected pattern %q to match %q", pc.Filters[1], "baz") + } + if pc.reFilters[1].MatchString("bax") { + t.Errorf("expected pattern %q NOT 
to match %q", pc.Filters[1], "bax") + } +} + +// Tests that an invalid regex produces an error during Unmarshal. +func TestPeerConfig_RegexCompileFailure(t *testing.T) { + var pc PeerConfig + data := ` +filters: + - "(" +` + + err := yaml.Unmarshal([]byte(data), &pc) + if err == nil { + t.Fatalf("expected error compiling invalid regex, got nil") + } + // Optionally ensure our error message path was used + if !strings.Contains(err.Error(), "failed to compile peer filter") { + t.Logf("warning: error did not contain expected text; full error: %v", err) + t.Fail() + } +} diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go index bcb5acbd..95b8e6bd 100644 --- a/proxy/helpers_test.go +++ b/proxy/helpers_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/mostlygeek/llama-swap/proxy/config" "gopkg.in/yaml.v3" ) @@ -65,18 +66,18 @@ func getTestPort() int { return port } -func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { +func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig { return getTestSimpleResponderConfigPort(expectedMessage, getTestPort()) } -func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { +func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig { // Create a YAML string with just the values we want to set yamlStr := fmt.Sprintf(` cmd: '%s --port %d --silent --respond %s' proxy: "http://127.0.0.1:%d" `, simpleResponderPath, port, expectedMessage, port) - var cfg ModelConfig + var cfg config.ModelConfig if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil { panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr)) } diff --git a/proxy/metrics_monitor.go b/proxy/metrics_monitor.go index ee11f2ac..826870f0 100644 --- a/proxy/metrics_monitor.go +++ b/proxy/metrics_monitor.go @@ -6,6 +6,7 @@ import ( "time" "github.com/mostlygeek/llama-swap/event" + 
"github.com/mostlygeek/llama-swap/proxy/config" ) // TokenMetrics represents parsed token statistics from llama-server logs @@ -38,7 +39,7 @@ type MetricsMonitor struct { nextID int } -func NewMetricsMonitor(config *Config) *MetricsMonitor { +func NewMetricsMonitor(config *config.Config) *MetricsMonitor { maxMetrics := config.MetricsMaxInMemory if maxMetrics <= 0 { maxMetrics = 1000 // Default fallback diff --git a/proxy/process.go b/proxy/process.go index 94c004bc..51a5bc61 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -16,6 +16,7 @@ import ( "time" "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/proxy/config" ) type ProcessState string @@ -39,7 +40,7 @@ const ( type Process struct { ID string - config ModelConfig + config config.ModelConfig cmd *exec.Cmd // PR #155 called to cancel the upstream process @@ -74,7 +75,7 @@ type Process struct { failedStartCount int } -func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { +func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { concurrentLimit := 10 if config.ConcurrencyLimit > 0 { concurrentLimit = config.ConcurrencyLimit @@ -539,7 +540,7 @@ func (p *Process) cmdStopUpstreamProcess() error { if p.config.CmdStop != "" { // replace ${PID} with the pid of the process - stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid))) + stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid))) if err != nil { p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err) return err diff --git a/proxy/process_test.go b/proxy/process_test.go index da4a3804..574c5d9e 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + 
"github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/assert" ) @@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) { // test that the automatic start returns the expected error type func TestProcess_BrokenModelConfig(t *testing.T) { // Create a process configuration - config := ModelConfig{ + config := config.ModelConfig{ Cmd: "nonexistent-command", Proxy: "http://127.0.0.1:9913", CheckEndpoint: "/health", @@ -325,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) { // should run and exit but interrupt the long checkHealthTimeout checkHealthTimeout := 5 - config := ModelConfig{ + config := config.ModelConfig{ Cmd: "sleep 1", Proxy: "http://127.0.0.1:9913", CheckEndpoint: "/health", @@ -402,7 +403,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) { binaryPath := getSimpleResponderPath() port := getTestPort() - config := ModelConfig{ + conf := config.ModelConfig{ // note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent // to force the process to exit Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage), @@ -410,7 +411,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) + process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger) defer process.Stop() // reduce to make testing go faster @@ -450,15 +451,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) { } func TestProcess_StopCmd(t *testing.T) { - config := getTestSimpleResponderConfig("test_stop_cmd") + conf := getTestSimpleResponderConfig("test_stop_cmd") if runtime.GOOS == "windows" { - config.CmdStop = "taskkill /f /t /pid ${PID}" + conf.CmdStop = "taskkill /f /t /pid ${PID}" } else { - config.CmdStop = "kill -TERM ${PID}" + conf.CmdStop = "kill -TERM ${PID}" } - process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger) + 
process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger) defer process.Stop() err := process.start() @@ -470,15 +471,15 @@ func TestProcess_StopCmd(t *testing.T) { func TestProcess_EnvironmentSetCorrectly(t *testing.T) { expectedMessage := "test_env_not_emptied" - config := getTestSimpleResponderConfig(expectedMessage) + conf := getTestSimpleResponderConfig(expectedMessage) // ensure that the the default config does not blank out the inherited environment - configWEnv := config + configWEnv := conf // ensure the additiona variables are appended to the process' environment configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2") - process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger) + process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger) process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger) process1.start() diff --git a/proxy/processgroup.go b/proxy/processgroup.go index cca48c10..cef540e2 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -5,12 +5,14 @@ import ( "net/http" "slices" "sync" + + "github.com/mostlygeek/llama-swap/proxy/config" ) type ProcessGroup struct { sync.Mutex - config Config + config config.Config id string swap bool exclusive bool @@ -24,7 +26,7 @@ type ProcessGroup struct { lastUsedProcess string } -func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { +func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { groupConfig, ok := config.Groups[id] if !ok { panic("Unable to find configuration for group id: " + id) diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go index 791a5a94..f10293ae 100644 --- a/proxy/processgroup_test.go +++ b/proxy/processgroup_test.go @@ -7,19 +7,20 @@ import ( "sync" "testing" + "github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/assert" ) -var 
processGroupTestConfig = AddDefaultGroupToConfig(Config{ +var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), "model3": getTestSimpleResponderConfig("model3"), "model4": getTestSimpleResponderConfig("model4"), "model5": getTestSimpleResponderConfig("model5"), }, - Groups: map[string]GroupConfig{ + Groups: map[string]config.GroupConfig{ "G1": { Swap: true, Exclusive: true, @@ -34,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{ }) func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) { - pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) assert.True(t, pg.HasMember("model5")) } @@ -48,9 +49,9 @@ func TestProcessGroup_HasMember(t *testing.T) { // TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true // and multiple requests are made in parallel, only one process is running at a time. 
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { - var processGroupTestConfig = AddDefaultGroupToConfig(Config{ + var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ // use the same listening so if a model is already running, it will fail // this is a way to test that swap isolation is working // properly when there are parallel requests made at the @@ -61,7 +62,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { "model4": getTestSimpleResponderConfigPort("model4", 9832), "model5": getTestSimpleResponderConfigPort("model5", 9832), }, - Groups: map[string]GroupConfig{ + Groups: map[string]config.GroupConfig{ "G1": { Swap: true, Members: []string{"model1", "model2", "model3", "model4", "model5"}, diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 9383f06a..8c6a8095 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -16,6 +16,7 @@ import ( "github.com/gin-gonic/gin" "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/proxy/config" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -27,7 +28,7 @@ const ( type ProxyManager struct { sync.Mutex - config Config + config config.Config ginEngine *gin.Engine // logging @@ -44,7 +45,7 @@ type ProxyManager struct { shutdownCancel context.CancelFunc } -func New(config Config) *ProxyManager { +func New(config config.Config) *ProxyManager { // set up loggers stdoutLogger := NewLogMonitorWriter(os.Stdout) upstreamLogger := NewLogMonitorWriter(stdoutLogger) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index e5cea5ec..662a705f 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -16,14 +16,15 @@ import ( "time" "github.com/mostlygeek/llama-swap/event" + "github.com/mostlygeek/llama-swap/proxy/config" "github.com/stretchr/testify/assert" "github.com/tidwall/gjson" ) 
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), }, @@ -44,14 +45,14 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { } } func TestProxyManager_SwapMultiProcess(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), }, LogLevel: "error", - Groups: map[string]GroupConfig{ + Groups: map[string]config.GroupConfig{ "G1": { Swap: true, Exclusive: false, @@ -89,14 +90,14 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) { // Test that a persistent group is not affected by the swapping behaviour of // other groups. 
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), // goes into the default group "model2": getTestSimpleResponderConfig("model2"), }, LogLevel: "error", - Groups: map[string]GroupConfig{ + Groups: map[string]config.GroupConfig{ // the forever group is persistent and should not be affected by model1 "forever": { Swap: true, @@ -133,9 +134,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { t.Skip("skipping slow test") } - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), "model3": getTestSimpleResponderConfig("model3"), @@ -196,9 +197,9 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { model2Config.Name = " " // empty whitespace only strings will get ignored model2Config.Description = " " - config := Config{ + config := config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": model1Config, "model2": model2Config, "model3": getTestSimpleResponderConfig("model3"), @@ -283,13 +284,13 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { // Intentionally add models in non-sorted order and with an unlisted model - config := Config{ + config := config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "zeta": getTestSimpleResponderConfig("zeta"), "alpha": getTestSimpleResponderConfig("alpha"), "beta": getTestSimpleResponderConfig("beta"), - 
"hidden": func() ModelConfig { + "hidden": func() config.ModelConfig { mc := getTestSimpleResponderConfig("hidden") mc.Unlisted = true return mc @@ -337,15 +338,15 @@ func TestProxyManager_Shutdown(t *testing.T) { model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config.Proxy = "http://localhost:10003/" - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": model1Config, "model2": model2Config, "model3": model3Config, }, LogLevel: "error", - Groups: map[string]GroupConfig{ + Groups: map[string]config.GroupConfig{ "test": { Swap: false, Members: []string{"model1", "model2", "model3"}, @@ -380,21 +381,21 @@ func TestProxyManager_Shutdown(t *testing.T) { } func TestProxyManager_Unload(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + conf := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", }) - proxy := New(config) + proxy := New(conf) reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() proxy.ServeHTTP(w, req) - assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) + assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) req = httptest.NewRequest("GET", "/unload", nil) w = httptest.NewRecorder() proxy.ServeHTTP(w, req) @@ -403,15 +404,15 @@ func TestProxyManager_Unload(t *testing.T) { // give it a bit of time to stop <-time.After(time.Millisecond * 250) - assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) + assert.Equal(t, 
proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) } // Test issue #61 `Listing the current list of models and the loaded model.` func TestProxyManager_RunningEndpoint(t *testing.T) { // Shared configuration - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), }, @@ -474,9 +475,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { } func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), }, LogLevel: "error", @@ -527,15 +528,15 @@ func TestProxyManager_UseModelName(t *testing.T) { modelConfig := getTestSimpleResponderConfig(upstreamModelName) modelConfig.UseModelName = upstreamModelName - config := AddDefaultGroupToConfig(Config{ + conf := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": modelConfig, }, LogLevel: "error", }) - proxy := New(config) + proxy := New(conf) defer proxy.StopProcesses(StopWaitForInflightRequest) requestedModel := "model1" @@ -590,9 +591,9 @@ func TestProxyManager_UseModelName(t *testing.T) { } func TestProxyManager_CORSOptionsHandler(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", @@ -666,7 +667,7 @@ models: aliases: [model-alias] `, 
getSimpleResponderPath()) - config, err := LoadConfigFromReader(strings.NewReader(configStr)) + config, err := config.LoadConfigFromReader(strings.NewReader(configStr)) assert.NoError(t, err) proxy := New(config) @@ -689,9 +690,9 @@ models: } func TestProxyManager_ChatContentLength(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", @@ -714,14 +715,14 @@ func TestProxyManager_ChatContentLength(t *testing.T) { func TestProxyManager_FiltersStripParams(t *testing.T) { modelConfig := getTestSimpleResponderConfig("model1") - modelConfig.Filters = ModelFilters{ + modelConfig.Filters = config.ModelFilters{ StripParams: "temperature, model, stream", } - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, LogLevel: "error", - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": modelConfig, }, }) @@ -747,9 +748,9 @@ func TestProxyManager_FiltersStripParams(t *testing.T) { } func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", @@ -782,9 +783,9 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) { } func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: 
"error", @@ -817,9 +818,9 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) { } func TestProxyManager_HealthEndpoint(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", @@ -836,9 +837,9 @@ func TestProxyManager_HealthEndpoint(t *testing.T) { // Ensure the custom llama-server /completion endpoint proxies correctly func TestProxyManager_CompletionEndpoint(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", @@ -882,7 +883,7 @@ models: `, "${simpleresponderpath}", simpleResponderPath, -1) // Create a test model configuration - config, err := LoadConfigFromReader(strings.NewReader(configStr)) + config, err := config.LoadConfigFromReader(strings.NewReader(configStr)) if !assert.NoError(t, err, "Invalid configuration") { return } @@ -916,9 +917,9 @@ models: } func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: map[string]config.ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", @@ -955,9 +956,9 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { } func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ + config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ + Models: 
map[string]config.ModelConfig{ "streaming-model": getTestSimpleResponderConfig("streaming-model"), }, LogLevel: "error",