diff --git a/Makefile b/Makefile index f7d18586a..043238fc9 100644 --- a/Makefile +++ b/Makefile @@ -35,11 +35,12 @@ test: proxy/ui_dist/placeholder.txt test-all: proxy/ui_dist/placeholder.txt go test -race -count=1 ./proxy/... -ui/node_modules: +ui-svelte/node_modules: ui-svelte/package.json ui-svelte/package-lock.json cd ui-svelte && npm install + touch ui-svelte/node_modules -# build react UI -ui: ui/node_modules +# build svelte UI +ui: ui-svelte/node_modules cd ui-svelte && npm run build # Build OSX binary diff --git a/README.md b/README.md index ca74baffa..b813fc10a 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,11 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and - ✅ API Key support - define keys to restrict access to API endpoints - ✅ Customizable - Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - - Automatic unloading of models after timeout by setting a `ttl` + - Automatic unloading of models after idle timeout by setting a `ttl` + - Request timeout protection with `requestTimeout` to prevent runaway inference - Reliable Docker and Podman support using `cmd` and `cmdStop` together - Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235)) + - RPC health checking for distributed inference - conditionally expose models based on RPC server availability ### Web UI @@ -189,6 +191,7 @@ Almost all configuration settings are optional and can be added one step at a ti - `useModelName` to override model names sent to upstream servers - `${PORT}` automatic port variables for dynamic port assignment - `filters` rewrite parts of requests before sending to the upstream server + - `rpcHealthCheck` monitor RPC server health for distributed inference models See the [configuration documentation](docs/configuration.md) for all options. 
diff --git a/config-schema.json b/config-schema.json index 87cde486e..e5200de38 100644 --- a/config-schema.json +++ b/config-schema.json @@ -237,10 +237,21 @@ "type": "boolean", "description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting." }, + "requestTimeout": { + "type": "integer", + "minimum": 0, + "default": 0, + "description": "Maximum time in seconds for a single request to complete before forcefully killing the model process. This prevents runaway inference processes from blocking the GPU indefinitely. 0 disables timeout (default). When exceeded, the process is terminated and must be restarted for the next request." + }, "unlisted": { "type": "boolean", "default": false, "description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests." + }, + "rpcHealthCheck": { + "type": "boolean", + "default": false, + "description": "Enable TCP health checks for RPC endpoints specified in cmd. When enabled, parses --rpc host:port[,host:port,...] from cmd and performs health checks every 30 seconds. Models with unhealthy RPC endpoints are filtered from /v1/models and return 503 on inference requests." 
} } } diff --git a/config.example.yaml b/config.example.yaml index 35f74c126..5d0df2dd7 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -280,6 +280,16 @@ models: # - recommended to be omitted and the default used concurrencyLimit: 0 + # requestTimeout: maximum time in seconds for a single request to complete + # - optional, default: 0 (no timeout) + # - useful for preventing runaway inference processes that never complete + # - when exceeded, the model process is forcefully stopped + # - protects against GPU overheating and blocking from stuck processes + # - the process must be restarted for the next request + # - set to 0 to disable timeout + # - recommended for models that may have infinite loops or excessive generation + requestTimeout: 0 # disabled by default, set to e.g., 300 for 5 minutes + # sendLoadingState: overrides the global sendLoadingState setting for this model # - optional, default: undefined (use global setting) sendLoadingState: false @@ -293,6 +303,24 @@ models: unlisted: true cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0 + # RPC health check example for distributed inference: + "qwen-distributed": + # rpcHealthCheck: enable TCP health checks for RPC endpoints + # - optional, default: false + # - when enabled, parses --rpc host:port[,host:port,...] 
from cmd + # - performs TCP connectivity checks every 30 seconds + # - model is only listed in /v1/models when ALL RPC endpoints are healthy + # - inference requests to unhealthy models return HTTP 503 + # - useful for distributed inference with llama.cpp's rpc-server + rpcHealthCheck: true + cmd: | + llama-server --port ${PORT} + --rpc 192.168.1.10:50051,192.168.1.11:50051 + -m Qwen2.5-32B-Instruct-Q4_K_M.gguf + -ngl 99 + name: "Qwen 32B (Distributed)" + description: "Large model using distributed RPC inference" + # Docker example: # container runtimes like Docker and Podman can be used reliably with # a combination of cmd, cmdStop, and ${MODEL_ID} diff --git a/config_embed.go b/config_embed.go new file mode 100644 index 000000000..2ba3ee60c --- /dev/null +++ b/config_embed.go @@ -0,0 +1,14 @@ +package main + +import ( + "bytes" + _ "embed" +) + +//go:embed config.example.yaml +var configExampleYAML []byte + +// GetConfigExampleYAML returns the embedded example config file +func GetConfigExampleYAML() []byte { + return bytes.Clone(configExampleYAML) +} diff --git a/docs/configuration.md b/docs/configuration.md index 5aac2706c..13b747d22 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -72,16 +72,17 @@ models: llama-swap supports many more features to customize how you want to manage your environment. 
-| Feature | Description | -| --------- | ---------------------------------------------- | -| `ttl` | automatic unloading of models after a timeout | -| `macros` | reusable snippets to use in configurations | -| `groups` | run multiple models at a time | -| `hooks` | event driven functionality | -| `env` | define environment variables per model | -| `aliases` | serve a model with different names | -| `filters` | modify requests before sending to the upstream | -| `...` | And many more tweaks | +| Feature | Description | +| ----------------- | ------------------------------------------------------- | +| `ttl` | automatic unloading of models after a timeout | +| `macros` | reusable snippets to use in configurations | +| `groups` | run multiple models at a time | +| `hooks` | event driven functionality | +| `env` | define environment variables per model | +| `aliases` | serve a model with different names | +| `filters` | modify requests before sending to the upstream | +| `rpcHealthCheck` | monitor RPC server health for distributed inference | +| `...` | And many more tweaks | ## Full Configuration Example @@ -319,6 +320,16 @@ models: # - recommended to be omitted and the default used concurrencyLimit: 0 + # requestTimeout: maximum time in seconds for a single request to complete + # - optional, default: 0 (no timeout) + # - useful for preventing runaway inference processes that never complete + # - when exceeded, the model process is forcefully stopped + # - protects against GPU overheating and blocking from stuck processes + # - the process must be restarted for the next request + # - set to 0 to disable timeout + # - recommended for models that may have infinite loops or excessive generation + requestTimeout: 300 # 5 minutes + # sendLoadingState: overrides the global sendLoadingState setting for this model # - optional, default: undefined (use global setting) sendLoadingState: false diff --git a/llama-swap.go b/llama-swap.go index 9706e07d1..1c68a25ce 100644 --- 
a/llama-swap.go +++ b/llama-swap.go @@ -97,6 +97,8 @@ func main() { currentPM.Shutdown() newPM := proxy.New(conf) newPM.SetVersion(date, commit, version) + newPM.SetConfigPath(*configPath) + newPM.SetConfigExample(GetConfigExampleYAML()) srv.Handler = newPM fmt.Println("Configuration Reloaded") @@ -114,6 +116,8 @@ func main() { } newPM := proxy.New(conf) newPM.SetVersion(date, commit, version) + newPM.SetConfigPath(*configPath) + newPM.SetConfigExample(GetConfigExampleYAML()) srv.Handler = newPM } } @@ -121,13 +125,15 @@ func main() { // load the initial proxy manager reloadProxyManager() debouncedReload := debounce(time.Second, reloadProxyManager) - if *watchConfig { - defer event.On(func(e proxy.ConfigFileChangedEvent) { - if e.ReloadingState == proxy.ReloadingStateStart { - debouncedReload() - } - })() + // Always listen for API-triggered config changes + defer event.On(func(e proxy.ConfigFileChangedEvent) { + if e.ReloadingState == proxy.ReloadingStateStart { + debouncedReload() + } + })() + + if *watchConfig { fmt.Println("Watching Configuration for changes") go func() { absConfigPath, err := filepath.Abs(*configPath) diff --git a/proxy/config/config.go b/proxy/config/config.go index c474c0894..e5b834cb2 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "io" + "net" "net/url" "os" "regexp" @@ -596,6 +597,70 @@ func SanitizeCommand(cmdStr string) ([]string, error) { return args, nil } +// ParseRPCEndpoints extracts RPC endpoints from command string +// Handles: --rpc host:port,host2:port2 or --rpc=host:port or -rpc host:port +func ParseRPCEndpoints(cmdStr string) ([]string, error) { + args, err := SanitizeCommand(cmdStr) + if err != nil { + return nil, err + } + + var endpoints []string + for i, arg := range args { + if arg == "--rpc" || arg == "-rpc" { + // Collect all non-flag arguments after --rpc + // This handles Windows where shlex splits single-quoted strings with spaces + var parts []string + 
for j := i + 1; j < len(args) && !strings.HasPrefix(args[j], "-"); j++ { + parts = append(parts, args[j]) + } + if len(parts) > 0 { + // Join parts with space and parse as a single endpoint list + endpoints = parseEndpointList(strings.Join(parts, " ")) + } + } else if strings.HasPrefix(arg, "--rpc=") { + endpoints = parseEndpointList(strings.TrimPrefix(arg, "--rpc=")) + } else if strings.HasPrefix(arg, "-rpc=") { + endpoints = parseEndpointList(strings.TrimPrefix(arg, "-rpc=")) + } + } + + // Validate each endpoint + for _, ep := range endpoints { + if _, _, err := net.SplitHostPort(ep); err != nil { + return nil, fmt.Errorf("invalid RPC endpoint %q: %w", ep, err) + } + } + + return endpoints, nil +} + +func parseEndpointList(s string) []string { + s = strings.TrimSpace(s) + + // Strip surrounding quotes (both single and double) from the whole string + // if they match. This handles cases like: "host:port,host2:port2" + if len(s) >= 2 { + if (s[0] == '\'' && s[len(s)-1] == '\'') || (s[0] == '"' && s[len(s)-1] == '"') { + s = s[1 : len(s)-1] + } + } + + parts := strings.Split(s, ",") + var result []string + for _, p := range parts { + p = strings.TrimSpace(p) + // Strip any remaining leading/trailing quotes from individual parts + // This handles Windows where shlex doesn't handle single quotes and + // may split 'host:port, host2:port' into "'host:port," and "host2:port'" + p = strings.Trim(p, "'\"") + if p != "" { + result = append(result, p) + } + } + return result +} + func StripComments(cmdStr string) string { var cleanedLines []string for _, line := range strings.Split(cmdStr, "\n") { diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go index 2ea8e4608..2867389b7 100644 --- a/proxy/config/config_test.go +++ b/proxy/config/config_test.go @@ -1438,3 +1438,108 @@ models: }) } + +func TestParseRPCEndpoints_ValidFormats(t *testing.T) { + tests := []struct { + name string + cmd string + expected []string + }{ + { + name: "single endpoint with 
--rpc", + cmd: "llama-server --rpc localhost:50051 -ngl 99", + expected: []string{"localhost:50051"}, + }, + { + name: "single endpoint with --rpc=", + cmd: "llama-server --rpc=192.168.1.100:50051 -ngl 99", + expected: []string{"192.168.1.100:50051"}, + }, + { + name: "single endpoint with -rpc", + cmd: "llama-server -rpc localhost:50051 -ngl 99", + expected: []string{"localhost:50051"}, + }, + { + name: "single endpoint with -rpc=", + cmd: "llama-server -rpc=localhost:50051 -ngl 99", + expected: []string{"localhost:50051"}, + }, + { + name: "multiple endpoints comma-separated", + cmd: "llama-server --rpc 192.168.1.10:50051,192.168.1.11:50051 -ngl 99", + expected: []string{"192.168.1.10:50051", "192.168.1.11:50051"}, + }, + { + name: "multiple endpoints with spaces trimmed", + cmd: "llama-server --rpc '192.168.1.10:50051, 192.168.1.11:50051' -ngl 99", + expected: []string{"192.168.1.10:50051", "192.168.1.11:50051"}, + }, + { + name: "IPv6 endpoint", + cmd: "llama-server --rpc [::1]:50051 -ngl 99", + expected: []string{"[::1]:50051"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoints, err := ParseRPCEndpoints(tt.cmd) + assert.NoError(t, err) + assert.Equal(t, tt.expected, endpoints) + }) + } +} + +func TestParseRPCEndpoints_NoRPCFlag(t *testing.T) { + cmd := "llama-server -ngl 99 -m model.gguf" + endpoints, err := ParseRPCEndpoints(cmd) + assert.NoError(t, err) + assert.Empty(t, endpoints) +} + +func TestParseRPCEndpoints_InvalidFormats(t *testing.T) { + tests := []struct { + name string + cmd string + wantErr string + }{ + { + name: "missing port", + cmd: "llama-server --rpc localhost -ngl 99", + wantErr: "invalid RPC endpoint", + }, + { + name: "invalid host:port format", + cmd: "llama-server --rpc not-a-valid-endpoint -ngl 99", + wantErr: "invalid RPC endpoint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseRPCEndpoints(tt.cmd) + assert.Error(t, err) + assert.Contains(t, 
err.Error(), tt.wantErr) + }) + } +} + +func TestParseRPCEndpoints_EmptyEndpointsFiltered(t *testing.T) { + // Empty strings after commas are filtered out + cmd := "llama-server --rpc 'localhost:50051,,' -ngl 99" + endpoints, err := ParseRPCEndpoints(cmd) + assert.NoError(t, err) + assert.Equal(t, []string{"localhost:50051"}, endpoints) +} + +func TestParseRPCEndpoints_MultilineCommand(t *testing.T) { + cmd := `llama-server \ + --rpc localhost:50051 \ + -ngl 99 \ + -m model.gguf` + + endpoints, err := ParseRPCEndpoints(cmd) + assert.NoError(t, err) + assert.Equal(t, []string{"localhost:50051"}, endpoints) +} diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go index 685687bab..464f8ed3c 100644 --- a/proxy/config/model_config.go +++ b/proxy/config/model_config.go @@ -40,6 +40,12 @@ type ModelConfig struct { // override global setting SendLoadingState *bool `yaml:"sendLoadingState"` + + // RPC health checking + RPCHealthCheck bool `yaml:"rpcHealthCheck"` + // Maximum time in seconds for a request to complete before killing the process + // 0 means no timeout (default) + RequestTimeout int `yaml:"requestTimeout"` } func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { @@ -57,6 +63,8 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { ConcurrencyLimit: 0, Name: "", Description: "", + RPCHealthCheck: false, + RequestTimeout: 0, } // the default cmdStop to taskkill /f /t /pid ${PID} diff --git a/proxy/process.go b/proxy/process.go index 414270595..8ab5faa7b 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -49,6 +49,7 @@ type Process struct { // PR #155 called to cancel the upstream process cmdMutex sync.RWMutex cancelUpstream context.CancelFunc + cmdStarted bool // tracks if cmd.Start() completed successfully // closed when command exits cmdWaitChan chan struct{} @@ -79,18 +80,25 @@ type Process struct { // track the number of failed starts failedStartCount int + + // RPC health checking + 
rpcEndpoints []string + rpcHealthy atomic.Bool + rpcHealthTicker *time.Ticker + rpcHealthCancel context.CancelFunc + shutdownCtx context.Context // from ProxyManager for graceful shutdown } -func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { +func NewProcess(ID string, healthCheckTimeout int, modelConfig config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor, shutdownCtx context.Context) *Process { concurrentLimit := 10 - if config.ConcurrencyLimit > 0 { - concurrentLimit = config.ConcurrencyLimit + if modelConfig.ConcurrencyLimit > 0 { + concurrentLimit = modelConfig.ConcurrencyLimit } // Setup the reverse proxy. - proxyURL, err := url.Parse(config.Proxy) + proxyURL, err := url.Parse(modelConfig.Proxy) if err != nil { - proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err) + proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, modelConfig.Proxy, err) } var reverseProxy *httputil.ReverseProxy @@ -105,9 +113,9 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr } } - return &Process{ + p := &Process{ ID: ID, - config: config, + config: modelConfig, cmd: nil, reverseProxy: reverseProxy, cancelUpstream: nil, @@ -124,7 +132,25 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr // stop timeout gracefulStopTimeout: 10 * time.Second, cmdWaitChan: make(chan struct{}), + shutdownCtx: shutdownCtx, } + + // Parse RPC endpoints if health checking enabled + if modelConfig.RPCHealthCheck { + endpoints, err := config.ParseRPCEndpoints(modelConfig.Cmd) + if err != nil { + proxyLogger.Errorf("<%s> failed to parse RPC endpoints: %v", ID, err) + } else if len(endpoints) == 0 { + proxyLogger.Warnf("<%s> rpcHealthCheck enabled but no --rpc flag found in cmd", ID) + } else { + p.rpcEndpoints = endpoints + p.rpcHealthy.Store(false) // start unhealthy until first check passes + // Start health 
checker immediately - runs independent of process state + p.startRPCHealthChecker() + } + } + + return p } // LogMonitor returns the log monitor associated with the process. @@ -250,26 +276,43 @@ func (p *Process) start() error { defer p.waitStarting.Done() cmdContext, ctxCancelUpstream := context.WithCancel(context.Background()) - p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...) - p.cmd.Stdout = p.processLogger - p.cmd.Stderr = p.processLogger - p.cmd.Env = append(p.cmd.Environ(), p.config.Env...) - p.cmd.Cancel = p.cmdStopUpstreamProcess - p.cmd.WaitDelay = p.gracefulStopTimeout - setProcAttributes(p.cmd) + cmd := exec.CommandContext(cmdContext, args[0], args[1:]...) + cmd.Stdout = p.processLogger + cmd.Stderr = p.processLogger + cmd.Env = append(cmd.Environ(), p.config.Env...) + cmd.Cancel = p.cmdStopUpstreamProcess + cmd.WaitDelay = p.gracefulStopTimeout + setProcAttributes(cmd) + + p.failedStartCount++ // this will be reset to zero when the process has successfully started + // Initialize cancelUpstream and cmdWaitChan before Start() so stopCommand() can use them + // if called during startup (e.g., due to timeout) p.cmdMutex.Lock() + p.cmd = cmd p.cancelUpstream = ctxCancelUpstream p.cmdWaitChan = make(chan struct{}) p.cmdMutex.Unlock() - p.failedStartCount++ // this will be reset to zero when the process has successfully started - p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", ")) err = p.cmd.Start() + // Set cmdStarted flag under lock after Start() completes + // This prevents data race with stopCommand() which checks cmdStarted instead of cmd.Process + p.cmdMutex.Lock() + if err == nil { + p.cmdStarted = true + } + p.cmdMutex.Unlock() + // Set process state to failed if err != nil { + // Close cmdWaitChan to prevent stopCommand() from hanging if a timeout + // transitions StateStarting -> StateStopping before Start() completes + p.cmdMutex.Lock() + 
close(p.cmdWaitChan) + p.cmdMutex.Unlock() + if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil { p.forceState(StateStopped) // force it into a stopped state return fmt.Errorf( @@ -381,14 +424,28 @@ func (p *Process) Stop() { // StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. // If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. func (p *Process) StopImmediately() { - if !isValidTransition(p.CurrentState(), StateStopping) { - return - } + // Try to transition from current state to StateStopping + // Process might be in StateReady or StateStarting when timeout fires + // Retry on ErrExpectedStateMismatch to handle transient state changes + for { + currentState := p.CurrentState() + if !isValidTransition(currentState, StateStopping) { + return + } - p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState()) - if curState, err := p.swapState(StateReady, StateStopping); err != nil { - p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) - return + p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, currentState) + + if _, err := p.swapState(currentState, StateStopping); err != nil { + if err == ErrExpectedStateMismatch { + // State changed between CurrentState() and swapState(), retry + continue + } + p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v", p.ID, currentState, err) + return + } + + // Successfully transitioned to StateStopping + break } p.stopCommand() @@ -422,14 +479,28 @@ func (p *Process) stopCommand() { p.cmdMutex.RLock() cancelUpstream := p.cancelUpstream cmdWaitChan := p.cmdWaitChan + cmdStarted := p.cmdStarted p.cmdMutex.RUnlock() + // If cancelUpstream is nil, the process was never actually started + // (e.g., forceState() was used in tests). Just return silently. 
if cancelUpstream == nil { - p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID) + p.proxyLogger.Debugf("<%s> stopCommand: cancelUpstream is nil, process was never started", p.ID) return } + // Always cancel the context to stop the command cancelUpstream() + + // If cmdStarted is false, the process never actually started + // (cmd.Start() was never called or failed), so skip waiting on cmdWaitChan + // to avoid hanging. This can happen if a timeout transitions StateStarting + // to StateStopping before cmd.Start() completes. + if !cmdStarted { + p.proxyLogger.Debugf("<%s> stopCommand: process never started (cmdStarted is false), skipping wait", p.ID) + return + } + <-cmdWaitChan } @@ -500,6 +571,44 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { p.inFlightRequests.Done() }() + // Start timeout monitoring if requestTimeout is configured + var requestCancel context.CancelFunc + var timeoutCancel context.CancelFunc + + if p.config.RequestTimeout > 0 { + timeoutDuration := time.Duration(p.config.RequestTimeout) * time.Second + + // Add timeout to request context to cancel the request when exceeded + requestCtx, cancel := context.WithTimeout(r.Context(), timeoutDuration) + requestCancel = cancel + r = r.Clone(requestCtx) + + // Create a separate timeout context for monitoring only + // Use context.Background() to ensure we detect our configured timeout, + // not parent-imposed deadlines that would cause misattribution + timeoutCtx, cancel := context.WithTimeout(context.Background(), timeoutDuration) + timeoutCancel = cancel + + go func() { + <-timeoutCtx.Done() + if timeoutCtx.Err() == context.DeadlineExceeded { + p.proxyLogger.Warnf("<%s> Request timeout exceeded (%v), force stopping process to prevent GPU blocking", p.ID, timeoutDuration) + // Force stop the process - this will kill the underlying inference process + p.StopImmediately() + } + }() + + // Ensure both timeouts are cancelled when request completes + 
defer func() { + if requestCancel != nil { + requestCancel() + } + if timeoutCancel != nil { + timeoutCancel() + } + }() + } + // for #366 // - extract streaming param from request context, should have been set by proxymanager var srw *statusResponseWriter @@ -606,6 +715,7 @@ func (p *Process) waitForCmd() { p.cmdMutex.Lock() close(p.cmdWaitChan) + p.cmdStarted = false p.cmdMutex.Unlock() } @@ -877,3 +987,72 @@ func (s *statusResponseWriter) Flush() { flusher.Flush() } } + +// startRPCHealthChecker launches background goroutine for RPC health monitoring. +// Runs independently of process state - checks RPC endpoints regardless of whether +// the model is loaded, starting, stopped, etc. +func (p *Process) startRPCHealthChecker() { + if !p.config.RPCHealthCheck || len(p.rpcEndpoints) == 0 { + return + } + + ctx, cancel := context.WithCancel(p.shutdownCtx) + p.rpcHealthCancel = cancel + p.rpcHealthTicker = time.NewTicker(30 * time.Second) + + go func() { + defer p.rpcHealthTicker.Stop() + + // Run initial check immediately + p.checkRPCHealth() + + for { + select { + case <-ctx.Done(): + p.proxyLogger.Debugf("<%s> RPC health checker shutting down", p.ID) + return + case <-p.rpcHealthTicker.C: + // Check regardless of process state + p.checkRPCHealth() + } + } + }() +} + +func (p *Process) checkRPCHealth() { + allHealthy := true + + for _, endpoint := range p.rpcEndpoints { + dialer := net.Dialer{Timeout: 3 * time.Second} + conn, err := dialer.Dial("tcp", endpoint) + if err != nil { + // Ignore I/O timeout errors - don't mark as unhealthy + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + p.proxyLogger.Debugf("<%s> RPC endpoint %s timeout (ignoring): %v", p.ID, endpoint, err) + continue + } + p.proxyLogger.Warnf("<%s> RPC endpoint %s unhealthy: %v", p.ID, endpoint, err) + allHealthy = false + break + } + conn.Close() + } + + wasHealthy := p.rpcHealthy.Load() + p.rpcHealthy.Store(allHealthy) + + // Log state changes + if wasHealthy && !allHealthy { + 
p.proxyLogger.Infof("<%s> RPC endpoints now UNHEALTHY", p.ID) + } else if !wasHealthy && allHealthy { + p.proxyLogger.Infof("<%s> RPC endpoints now HEALTHY", p.ID) + } +} + +// IsRPCHealthy returns true if RPC health checking is disabled or all endpoints healthy +func (p *Process) IsRPCHealthy() bool { + if !p.config.RPCHealthCheck || len(p.rpcEndpoints) == 0 { + return true // not using RPC health checks + } + return p.rpcHealthy.Load() +} diff --git a/proxy/process_rpc_health_test.go b/proxy/process_rpc_health_test.go new file mode 100644 index 000000000..3e055329f --- /dev/null +++ b/proxy/process_rpc_health_test.go @@ -0,0 +1,124 @@ +package proxy + +import ( + "context" + "io" + "testing" + + "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/stretchr/testify/assert" +) + +func TestProcess_RPCHealthIndependentOfState(t *testing.T) { + testLogger := NewLogMonitorWriter(io.Discard) + proxyLogger := NewLogMonitorWriter(io.Discard) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + modelConfig := config.ModelConfig{ + Cmd: "llama-server --rpc 127.0.0.1:50051", + Proxy: "http://localhost:8080", + RPCHealthCheck: true, + } + + process := NewProcess("test-model", 5, modelConfig, testLogger, proxyLogger, ctx) + + // Verify endpoints were parsed + assert.NotEmpty(t, process.rpcEndpoints, "RPC endpoints should be parsed from cmd") + assert.Equal(t, []string{"127.0.0.1:50051"}, process.rpcEndpoints) + + // Initially should be unhealthy (false) until first check + assert.False(t, process.rpcHealthy.Load(), "RPC health should start as false") + + // Health checker should be running regardless of process state + assert.NotNil(t, process.rpcHealthTicker, "Health checker ticker should be running") + assert.NotNil(t, process.rpcHealthCancel, "Health checker should have cancel func") + + // Process state should not affect health checking + assert.Equal(t, StateStopped, process.CurrentState(), "Process should be in stopped state") + + // 
Health check runs independently - simulate RPC becoming healthy + process.rpcHealthy.Store(true) + assert.True(t, process.IsRPCHealthy(), "Process should report healthy regardless of state") +} + +func TestProcess_RPCHealthCheckDisabled(t *testing.T) { + testLogger := NewLogMonitorWriter(io.Discard) + proxyLogger := NewLogMonitorWriter(io.Discard) + ctx := context.Background() + + modelConfig := config.ModelConfig{ + Cmd: "llama-server --rpc 127.0.0.1:50051", + Proxy: "http://localhost:8080", + RPCHealthCheck: false, // Disabled + } + + process := NewProcess("test-model", 5, modelConfig, testLogger, proxyLogger, ctx) + + // Should always return healthy when disabled + assert.True(t, process.IsRPCHealthy(), "Should return true when RPC health check is disabled") +} + +func TestProcess_RPCHealthCheckNoEndpoints(t *testing.T) { + testLogger := NewLogMonitorWriter(io.Discard) + proxyLogger := NewLogMonitorWriter(io.Discard) + ctx := context.Background() + + modelConfig := config.ModelConfig{ + Cmd: "llama-server --port 8080", // No --rpc flag + Proxy: "http://localhost:8080", + RPCHealthCheck: true, // Enabled but no endpoints + } + + process := NewProcess("test-model", 5, modelConfig, testLogger, proxyLogger, ctx) + + // Should have no endpoints + assert.Empty(t, process.rpcEndpoints, "Should have no RPC endpoints when --rpc flag is missing") + + // Should return healthy when no endpoints configured (treat as not using RPC) + assert.True(t, process.IsRPCHealthy(), "Should return true when no RPC endpoints found") + + // Health checker should NOT start when no endpoints + assert.Nil(t, process.rpcHealthTicker, "Health checker should not run without endpoints") + assert.Nil(t, process.rpcHealthCancel, "Health checker cancel should be nil") +} + +func TestProcess_RPCHealthCheckTimeoutIgnored(t *testing.T) { + testLogger := NewLogMonitorWriter(io.Discard) + proxyLogger := NewLogMonitorWriter(io.Discard) + ctx, cancel := context.WithCancel(context.Background()) + defer 
cancel() + + // Use an IP address that will timeout (non-routable IP) + // 192.0.2.0/24 is reserved for documentation/testing (RFC 5737) + modelConfig := config.ModelConfig{ + Cmd: "llama-server --rpc 192.0.2.1:50051", + Proxy: "http://localhost:8080", + RPCHealthCheck: true, + } + + process := NewProcess("test-model", 5, modelConfig, testLogger, proxyLogger, ctx) + + // Verify endpoints were parsed + assert.NotEmpty(t, process.rpcEndpoints, "RPC endpoints should be parsed from cmd") + assert.Equal(t, []string{"192.0.2.1:50051"}, process.rpcEndpoints) + + // Initially should be unhealthy (false) until first check + assert.False(t, process.rpcHealthy.Load(), "RPC health should start as false") + + // Manually run health check - this should timeout but not mark as unhealthy + process.checkRPCHealth() + + // After timeout, should remain at initial state (false) but not be marked unhealthy + // The key is that timeout doesn't change the state - it's effectively a no-op + // To test this properly, let's set it to healthy first, then see if timeout changes it + process.rpcHealthy.Store(true) + initialState := process.rpcHealthy.Load() + assert.True(t, initialState, "Should be healthy before timeout check") + + // Run health check that will timeout + process.checkRPCHealth() + + // After timeout, should still be healthy (timeout is ignored) + assert.True(t, process.rpcHealthy.Load(), "Should remain healthy after timeout") +} diff --git a/proxy/process_test.go b/proxy/process_test.go index dd9e9d8ab..034646157 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -18,6 +19,15 @@ var ( debugLogger = NewLogMonitorWriter(os.Stdout) ) +// getSleepCommand returns a platform-appropriate command to sleep for the given number of seconds +func getSleepCommand(seconds int) string { + if runtime.GOOS == "windows" { + // Use full path to avoid conflict with GNU coreutils timeout in Git 
Bash + return fmt.Sprintf("C:\\Windows\\System32\\timeout.exe /t %d /nobreak", seconds) + } + return fmt.Sprintf("sleep %d", seconds) +} + func init() { // flip to help with debugging tests if false { @@ -33,7 +43,7 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { config := getTestSimpleResponderConfig(expectedMessage) // Create a process - process := NewProcess("test-process", 5, config, debugLogger, debugLogger) + process := NewProcess("test-process", 5, config, debugLogger, debugLogger, context.Background()) defer process.Stop() req := httptest.NewRequest("GET", "/test", nil) @@ -69,7 +79,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) { expectedMessage := "testing91931" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("test-process", 5, config, debugLogger, debugLogger) + process := NewProcess("test-process", 5, config, debugLogger, debugLogger, context.Background()) defer process.Stop() var wg sync.WaitGroup @@ -97,7 +107,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("broken", 1, config, debugLogger, debugLogger) + process := NewProcess("broken", 1, config, debugLogger, debugLogger, context.Background()) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -122,7 +132,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) { conf.UnloadAfter = 3 // seconds assert.Equal(t, 3, conf.UnloadAfter) - process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger) + process := NewProcess("ttl_test", 2, conf, debugLogger, debugLogger, context.Background()) defer process.Stop() // this should take 4 seconds @@ -164,7 +174,7 @@ func TestProcess_LowTTLValue(t *testing.T) { conf.UnloadAfter = 1 // second assert.Equal(t, 1, conf.UnloadAfter) - process := NewProcess("ttl", 2, conf, debugLogger, debugLogger) + process := NewProcess("ttl", 2, conf, debugLogger, debugLogger, context.Background()) defer process.Stop() for i := 0; i < 100; 
i++ { @@ -191,7 +201,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { expectedMessage := "12345" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("t", 10, config, debugLogger, debugLogger) + process := NewProcess("t", 10, config, debugLogger, debugLogger, context.Background()) defer process.Stop() results := map[string]string{ @@ -264,7 +274,7 @@ func TestProcess_SwapState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) + p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger, context.Background()) p.state = test.currentState resultState, err := p.swapState(test.expectedState, test.newState) @@ -297,7 +307,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { config.Proxy = "http://localhost:9998/test" healthCheckTTLSeconds := 30 - process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) + process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger, context.Background()) // make it a lot faster process.healthCheckLoopInterval = time.Second @@ -332,7 +342,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger) + process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger, context.Background()) process.healthCheckLoopInterval = time.Second // make it faster err := process.start() assert.Equal(t, "upstream command exited prematurely but successfully", err.Error()) @@ -350,7 +360,7 @@ func TestProcess_ConcurrencyLimit(t *testing.T) { // only allow 1 concurrent request at a time config.ConcurrencyLimit = 1 - process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) + process := NewProcess("ttl_test", 2, config, 
debugLogger, debugLogger, context.Background()) assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore)) defer process.Stop() @@ -375,7 +385,7 @@ func TestProcess_StopImmediately(t *testing.T) { expectedMessage := "test_stop_immediate" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) + process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger, context.Background()) defer process.Stop() err := process.start() @@ -415,7 +425,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger) + process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger, context.Background()) defer process.Stop() // reduce to make testing go faster @@ -465,7 +475,7 @@ func TestProcess_StopCmd(t *testing.T) { conf.CmdStop = "kill -TERM ${PID}" } - process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger) + process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger, context.Background()) defer process.Stop() err := process.start() @@ -485,8 +495,8 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) { // ensure the additiona variables are appended to the process' environment configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2") - process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger) - process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger) + process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger, context.Background()) + process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger, context.Background()) process1.start() defer process1.Stop() @@ -521,7 +531,7 @@ func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) { expectedMessage := "panic_test" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("panic-test", 5, config, debugLogger, 
debugLogger) + process := NewProcess("panic-test", 5, config, debugLogger, debugLogger, context.Background()) defer process.Stop() // Start the process @@ -569,3 +579,106 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) { } return w.ResponseRecorder.Write(b) } + +// TestProcess_StopCommandDoesNotHangWhenStartFails verifies that stopCommand() +// does not hang when cmd.Start() fails or hasn't completed yet. This can happen +// when a timeout transitions StateStarting -> StateStopping before cmd.Start() +// completes. +func TestProcess_StopCommandDoesNotHangWhenStartFails(t *testing.T) { + // Create a process with a command that will fail to start + config := config.ModelConfig{ + Cmd: "nonexistent-command-that-will-fail", + Proxy: "http://127.0.0.1:9999", + CheckEndpoint: "/health", + } + + process := NewProcess("fail-test", 1, config, debugLogger, debugLogger, context.Background()) + + // Try to start the process - this will fail + err := process.start() + assert.Error(t, err) + assert.Contains(t, err.Error(), "start() failed for command") + assert.Equal(t, StateStopped, process.CurrentState()) + + // Now try to stop the process - this should not hang + // Create a channel to track if stopCommand completes + done := make(chan struct{}) + go func() { + process.stopCommand() + close(done) + }() + + // Wait for stopCommand to complete with a timeout + select { + case <-done: + // Success - stopCommand completed without hanging + case <-time.After(2 * time.Second): + t.Fatal("stopCommand() hung when process never started") + } +} + +// TestProcess_StopImmediatelyDuringStartup verifies that StopImmediately() +// can safely interrupt a process during StateStarting without hanging. 
+func TestProcess_StopImmediatelyDuringStartup(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow test") + } + + // Use a platform-appropriate command that takes a while to start but won't respond to health checks + cmd := getSleepCommand(10) + config := config.ModelConfig{ + Cmd: cmd, + Proxy: "http://127.0.0.1:9999", + CheckEndpoint: "/health", + } + + process := NewProcess("interrupt-test", 20, config, debugLogger, debugLogger, context.Background()) + process.healthCheckLoopInterval = 100 * time.Millisecond + + // Start the process in a goroutine (it will be in StateStarting) + startDone := make(chan struct{}) + errCh := make(chan error, 1) + go func() { + err := process.start() + errCh <- err + close(startDone) + }() + + // Wait a bit for the process to enter StateStarting + <-time.After(200 * time.Millisecond) + currentState := process.CurrentState() + assert.Equal(t, StateStarting, currentState, "Process should be in StateStarting") + + // Now call StopImmediately while in StateStarting + // This simulates a timeout firing during startup + stopDone := make(chan struct{}) + go func() { + process.StopImmediately() + close(stopDone) + }() + + // Verify StopImmediately completes without hanging + select { + case <-stopDone: + // Success - StopImmediately completed + case <-time.After(3 * time.Second): + t.Fatal("StopImmediately() hung when called during StateStarting") + } + + // Wait for start() to complete + select { + case <-startDone: + // Success + case <-time.After(2 * time.Second): + t.Fatal("start() did not complete after StopImmediately") + } + + // Verify start() returned an error due to StopImmediately interrupt + err := <-errCh + assert.Error(t, err) + + // Process should be in StateStopped or StateStopping + finalState := process.CurrentState() + assert.True(t, finalState == StateStopped || finalState == StateStopping, + "Expected StateStopped or StateStopping, got %s", finalState) +} diff --git a/proxy/process_timeout_test.go 
b/proxy/process_timeout_test.go new file mode 100644 index 000000000..dd109d238 --- /dev/null +++ b/proxy/process_timeout_test.go @@ -0,0 +1,126 @@ +package proxy + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/proxy/config" +) + +// TestProcess_RequestTimeout verifies that requestTimeout actually kills the process +func TestProcess_RequestTimeout(t *testing.T) { + // Create error channel to report handler errors from the mock server goroutine + srvErrCh := make(chan error, 1) + + // Create a mock server that simulates a long-running inference + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Mock server received request") + + // Simulate streaming response that takes 60 seconds + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + srvErrCh <- fmt.Errorf("Expected http.ResponseWriter to be an http.Flusher") + return + } + + // Stream data for 60 seconds + for i := 0; i < 60; i++ { + select { + case <-r.Context().Done(): + t.Logf("Mock server: client disconnected after %d seconds", i) + return + default: + fmt.Fprintf(w, "data: token %d\n\n", i) + flusher.Flush() + time.Sleep(1 * time.Second) + } + } + t.Logf("Mock server completed full 60 second response") + })) + defer mockServer.Close() + + // Setup process logger - use NewLogMonitor() to avoid race in test + processLogger := NewLogMonitor() + proxyLogger := NewLogMonitor() + + // Create process with 5 second request timeout + cfg := config.ModelConfig{ + Proxy: mockServer.URL, + CheckEndpoint: "none", // skip health check + RequestTimeout: 5, // 5 second timeout + } + + p := NewProcess("test-timeout", 30, cfg, processLogger, proxyLogger, context.Background()) + p.gracefulStopTimeout = 2 * time.Second // shorter for testing + + // Manually set state to ready (skip actual process start) + 
p.forceState(StateReady) + + // Make a request that should timeout + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + start := time.Now() + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + p.ProxyRequest(w, req) + }() + + // Wait for either completion or timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case err := <-srvErrCh: + // Handler error - fail the test immediately + t.Fatalf("Mock server handler error: %v", err) + + case <-done: + elapsed := time.Since(start) + t.Logf("Request completed after %v", elapsed) + + // Check for any deferred server errors + select { + case err := <-srvErrCh: + t.Fatalf("Mock server handler error: %v", err) + default: + // No server errors, continue with assertions + } + + // Request should complete within timeout + gracefulStopTimeout + some buffer + maxExpected := time.Duration(cfg.RequestTimeout+2)*time.Second + 3*time.Second + if elapsed > maxExpected { + t.Errorf("Request took %v, expected less than %v with 5s timeout", elapsed, maxExpected) + } else { + t.Logf("✓ Request was properly terminated by timeout") + } + + case <-time.After(15 * time.Second): + t.Fatalf("Test timed out after 15 seconds - request should have been killed by requestTimeout") + } +} + +// TestProcess_RequestTimeoutWithRealProcess tests with an actual process +func TestProcess_RequestTimeoutWithRealProcess(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test with real process in short mode") + } + + // This test would require a real llama.cpp server or similar + // For now, we can skip it or mock it + t.Skip("Requires real inference server") +} diff --git a/proxy/processgroup.go b/proxy/processgroup.go index b401d8a68..c920f302c 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "fmt" "net/http" "slices" @@ -24,9 +25,11 @@ type ProcessGroup struct { // map of 
current processes processes map[string]*Process lastUsedProcess string + + shutdownCtx context.Context } -func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { +func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor, shutdownCtx context.Context) *ProcessGroup { groupConfig, ok := config.Groups[id] if !ok { panic("Unable to find configuration for group id: " + id) @@ -41,13 +44,14 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u proxyLogger: proxyLogger, upstreamLogger: upstreamLogger, processes: make(map[string]*Process), + shutdownCtx: shutdownCtx, } // Create a Process for each member in the group for _, modelID := range groupConfig.Members { modelConfig, modelID, _ := pg.config.FindConfig(modelID) processLogger := NewLogMonitorWriter(upstreamLogger) - process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger) + process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger, shutdownCtx) pg.processes[modelID] = process } diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go index 6b90f4433..55e5276a8 100644 --- a/proxy/processgroup_test.go +++ b/proxy/processgroup_test.go @@ -2,6 +2,7 @@ package proxy import ( "bytes" + "context" "net/http" "net/http/httptest" "sync" @@ -35,12 +36,12 @@ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ }) func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) { - pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger, context.Background()) assert.True(t, pg.HasMember("model5")) } func TestProcessGroup_HasMember(t *testing.T) { - pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) + pg := 
NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger, context.Background()) assert.True(t, pg.HasMember("model1")) assert.True(t, pg.HasMember("model2")) assert.False(t, pg.HasMember("model3")) @@ -74,7 +75,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { }, }) - pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger, context.Background()) defer pg.StopProcesses(StopWaitForInflightRequest) tests := []string{"model1", "model2", "model3", "model4", "model5"} @@ -96,7 +97,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { } func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { - pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger, context.Background()) defer pg.StopProcesses(StopWaitForInflightRequest) tests := []string{"model3", "model4"} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index c5042bab8..a1942c608 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -88,6 +88,12 @@ type ProxyManager struct { commit string version string + // config file path for editing + configPath string + + // embedded example config + configExample []byte + // peer proxy see: #296, #433 peerProxy *PeerProxy } @@ -205,7 +211,7 @@ func New(proxyConfig config.Config) *ProxyManager { // create the process groups for groupID := range proxyConfig.Groups { - processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger) + processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger, shutdownCtx) pm.processGroups[groupID] = processGroup } @@ -532,6 +538,16 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) { continue } + // Filter models with unhealthy RPC endpoints + if processGroup := pm.findGroupByModelName(id); processGroup != nil { + if process, 
ok := processGroup.GetMember(id); ok { + if !process.IsRPCHealthy() { + pm.proxyLogger.Debugf("<%s> filtered from /v1/models (unhealthy RPC)", id) + continue + } + } + } + data = append(data, newRecord(id, modelConfig)) // Include aliases @@ -684,6 +700,15 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { return } + // Check RPC health before processing request + if process, ok := processGroup.GetMember(modelID); ok { + if !process.IsRPCHealthy() { + pm.sendErrorResponse(c, http.StatusServiceUnavailable, + fmt.Sprintf("model %s unavailable (RPC endpoints unhealthy)", modelID)) + return + } + } + // issue #69 allow custom model names to be sent to upstream useModelName := pm.config.Models[modelID].UseModelName if useModelName != "" { @@ -1083,3 +1108,15 @@ func (pm *ProxyManager) SetVersion(buildDate string, commit string, version stri pm.commit = commit pm.version = version } + +func (pm *ProxyManager) SetConfigPath(configPath string) { + pm.Lock() + defer pm.Unlock() + pm.configPath = configPath +} + +func (pm *ProxyManager) SetConfigExample(configExample []byte) { + pm.Lock() + defer pm.Unlock() + pm.configExample = configExample +} diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index 00897c650..a65818d0b 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" + "os" "sort" "strconv" "strings" @@ -33,6 +35,9 @@ func addApiHandlers(pm *ProxyManager) { apiGroup.GET("/events", pm.apiSendEvents) apiGroup.GET("/metrics", pm.apiGetMetrics) apiGroup.GET("/version", pm.apiGetVersion) + apiGroup.GET("/config/current", pm.apiGetCurrentConfig) + apiGroup.GET("/config/example", pm.apiGetExampleConfig) + apiGroup.POST("/config", pm.apiUpdateConfig) apiGroup.GET("/captures/:id", pm.apiGetCapture) } } @@ -276,6 +281,68 @@ func (pm *ProxyManager) apiGetVersion(c *gin.Context) { }) } +func (pm *ProxyManager) apiGetCurrentConfig(c *gin.Context) { + 
pm.Lock() + configPath := pm.configPath + pm.Unlock() + + if configPath == "" { + pm.sendErrorResponse(c, http.StatusNotFound, "Config file path not set") + return + } + + data, err := os.ReadFile(configPath) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("Failed to read config file: %v", err)) + return + } + + c.Data(http.StatusOK, "text/yaml; charset=utf-8", data) +} + +func (pm *ProxyManager) apiGetExampleConfig(c *gin.Context) { + pm.Lock() + data := pm.configExample + pm.Unlock() + + if data == nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, "Example config not available") + return + } + + c.Data(http.StatusOK, "text/yaml; charset=utf-8", data) +} + +func (pm *ProxyManager) apiUpdateConfig(c *gin.Context) { + pm.Lock() + configPath := pm.configPath + pm.Unlock() + + if configPath == "" { + pm.sendErrorResponse(c, http.StatusBadRequest, "Config file path not set") + return + } + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("Failed to read request body: %v", err)) + return + } + + // Write to config file + if err := os.WriteFile(configPath, body, 0644); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("Failed to write config file: %v", err)) + return + } + + // Trigger config reload event + event.Emit(ConfigFileChangedEvent{ + ReloadingState: ReloadingStateStart, + }) + + c.JSON(http.StatusOK, gin.H{"message": "Config updated successfully. 
Reloading..."}) +} + func (pm *ProxyManager) apiGetCapture(c *gin.Context) { idStr := c.Param("id") id, err := strconv.Atoi(idStr) diff --git a/ui-svelte/package-lock.json b/ui-svelte/package-lock.json index 8c86b603c..5366c17fd 100644 --- a/ui-svelte/package-lock.json +++ b/ui-svelte/package-lock.json @@ -8,7 +8,13 @@ "name": "ui-svelte", "version": "0.0.0", "dependencies": { + "@codemirror/lang-yaml": "^6.1.2", + "@codemirror/language": "^6.12.1", + "@codemirror/state": "^6.5.4", + "@codemirror/view": "^6.39.12", + "codemirror": "^6.0.2", "highlight.js": "^11.11.1", + "js-yaml": "^4.1.1", "katex": "^0.16.28", "lucide-svelte": "^0.563.0", "rehype-katex": "^7.0.1", @@ -36,6 +42,102 @@ "vitest": "^4.0.18" } }, + "node_modules/@codemirror/autocomplete": { + "version": "6.20.1", + "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.20.1.tgz", + "integrity": "sha512-1cvg3Vz1dSSToCNlJfRA2WSI4ht3K+WplO0UMOgmUYPivCyy2oueZY6Lx7M9wThm7SDUBViRmuT+OG/i8+ON9A==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.17.0", + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@codemirror/commands": { + "version": "6.10.2", + "resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.10.2.tgz", + "integrity": "sha512-vvX1fsih9HledO1c9zdotZYUZnE4xV0m6i3m25s5DIfXofuprk6cRcLUZvSk3CASUbwjQX21tOGbkY2BH8TpnQ==", + "license": "MIT", + "dependencies": { + "@codemirror/language": "^6.0.0", + "@codemirror/state": "^6.4.0", + "@codemirror/view": "^6.27.0", + "@lezer/common": "^1.1.0" + } + }, + "node_modules/@codemirror/lang-yaml": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/@codemirror/lang-yaml/-/lang-yaml-6.1.2.tgz", + "integrity": "sha512-dxrfG8w5Ce/QbT7YID7mWZFKhdhsaTNOYjOkSIMt1qmC4VQnXSDSYVHHHn8k6kJUfIhtLo8t1JJgltlxWdsITw==", + "license": "MIT", + "dependencies": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/language": 
"^6.0.0", + "@codemirror/state": "^6.0.0", + "@lezer/common": "^1.2.0", + "@lezer/highlight": "^1.2.0", + "@lezer/lr": "^1.0.0", + "@lezer/yaml": "^1.0.0" + } + }, + "node_modules/@codemirror/language": { + "version": "6.12.2", + "resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.12.2.tgz", + "integrity": "sha512-jEPmz2nGGDxhRTg3lTpzmIyGKxz3Gp3SJES4b0nAuE5SWQoKdT5GoQ69cwMmFd+wvFUhYirtDTr0/DRHpQAyWg==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.23.0", + "@lezer/common": "^1.5.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.0.0", + "style-mod": "^4.0.0" + } + }, + "node_modules/@codemirror/lint": { + "version": "6.9.5", + "resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.9.5.tgz", + "integrity": "sha512-GElsbU9G7QT9xXhpUg1zWGmftA/7jamh+7+ydKRuT0ORpWS3wOSP0yT1FOlIZa7mIJjpVPipErsyvVqB9cfTFA==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.35.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/search": { + "version": "6.6.0", + "resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.6.0.tgz", + "integrity": "sha512-koFuNXcDvyyotWcgOnZGmY7LZqEOXZaaxD/j6n18TCLx2/9HieZJ5H6hs1g8FiRxBD0DNfs0nXn17g872RmYdw==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.37.0", + "crelt": "^1.0.5" + } + }, + "node_modules/@codemirror/state": { + "version": "6.5.4", + "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.5.4.tgz", + "integrity": "sha512-8y7xqG/hpB53l25CIoit9/ngxdfoG+fx+V3SHBrinnhOtLvKHRyAJJuHzkWrR4YXXLX8eXBsejgAAxHUOdW1yw==", + "license": "MIT", + "dependencies": { + "@marijn/find-cluster-break": "^1.0.0" + } + }, + "node_modules/@codemirror/view": { + "version": "6.39.16", + "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.39.16.tgz", + "integrity": 
"sha512-m6S22fFpKtOWhq8HuhzsI1WzUP/hB9THbDj0Tl5KX4gbO6Y91hwBl7Yky33NdvB6IffuRFiBxf1R8kJMyXmA4Q==", + "license": "MIT", + "dependencies": { + "@codemirror/state": "^6.5.0", + "crelt": "^1.0.6", + "style-mod": "^4.1.0", + "w3c-keyname": "^2.2.4" + } + }, "node_modules/@esbuild/aix-ppc64": { "version": "0.25.12", "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz", @@ -523,6 +625,47 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@lezer/common": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.5.1.tgz", + "integrity": "sha512-6YRVG9vBkaY7p1IVxL4s44n5nUnaNnGM2/AckNgYOnxTG2kWh1vR8BMxPseWPjRNpb5VtXnMpeYAEAADoRV1Iw==", + "license": "MIT" + }, + "node_modules/@lezer/highlight": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.3.tgz", + "integrity": "sha512-qXdH7UqTvGfdVBINrgKhDsVTJTxactNNxLk7+UMwZhU13lMHaOBlJe9Vqp907ya56Y3+ed2tlqzys7jDkTmW0g==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.3.0" + } + }, + "node_modules/@lezer/lr": { + "version": "1.4.8", + "resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.8.tgz", + "integrity": "sha512-bPWa0Pgx69ylNlMlPvBPryqeLYQjyJjqPx+Aupm5zydLIF3NE+6MMLT8Yi23Bd9cif9VS00aUebn+6fDIGBcDA==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.0.0" + } + }, + "node_modules/@lezer/yaml": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@lezer/yaml/-/yaml-1.0.4.tgz", + "integrity": "sha512-2lrrHqxalACEbxIbsjhqGpSW8kWpUKuY6RHgnSAFZa6qK62wvnPxA8hGOwOoDbwHcOFs5M4o27mjGu+P7TvBmw==", + "license": "MIT", + "dependencies": { + "@lezer/common": "^1.2.0", + "@lezer/highlight": "^1.0.0", + "@lezer/lr": "^1.4.0" + } + }, + "node_modules/@marijn/find-cluster-break": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz", + "integrity": 
"sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==", + "license": "MIT" + }, "node_modules/@rollup/pluginutils": { "version": "5.3.0", "resolved": "https://registry.npmjs.org/@rollup/pluginutils/-/pluginutils-5.3.0.tgz", @@ -1446,6 +1589,12 @@ "node": ">=0.4.0" } }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "license": "Python-2.0" + }, "node_modules/aria-query": { "version": "5.3.2", "resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.2.tgz", @@ -1559,6 +1708,21 @@ "node": ">=6" } }, + "node_modules/codemirror": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.2.tgz", + "integrity": "sha512-VhydHotNW5w1UGK0Qj96BwSk/Zqbp9WbnyK2W/eVMv4QyF41INRGpjUhFJY7/uDNuudSc33a/PKr4iDqRduvHw==", + "license": "MIT", + "dependencies": { + "@codemirror/autocomplete": "^6.0.0", + "@codemirror/commands": "^6.0.0", + "@codemirror/language": "^6.0.0", + "@codemirror/lint": "^6.0.0", + "@codemirror/search": "^6.0.0", + "@codemirror/state": "^6.0.0", + "@codemirror/view": "^6.0.0" + } + }, "node_modules/comma-separated-tokens": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/comma-separated-tokens/-/comma-separated-tokens-2.0.3.tgz", @@ -1578,6 +1742,12 @@ "node": ">= 12" } }, + "node_modules/crelt": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", + "integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==", + "license": "MIT" + }, "node_modules/debug": { "version": "4.4.3", "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", @@ -2038,6 +2208,18 @@ "jiti": "lib/jiti-cli.mjs" } }, + "node_modules/js-yaml": { + "version": "4.1.1", + "resolved": 
"https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", + "license": "MIT", + "dependencies": { + "argparse": "^2.0.1" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + }, "node_modules/katex": { "version": "0.16.28", "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.28.tgz", @@ -3556,6 +3738,12 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/style-mod": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.3.tgz", + "integrity": "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ==", + "license": "MIT" + }, "node_modules/svelte": { "version": "5.48.5", "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.48.5.tgz", @@ -4072,6 +4260,12 @@ } } }, + "node_modules/w3c-keyname": { + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz", + "integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==", + "license": "MIT" + }, "node_modules/web-namespaces": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/web-namespaces/-/web-namespaces-2.0.1.tgz", diff --git a/ui-svelte/package.json b/ui-svelte/package.json index 7afbc757e..4a8299a62 100644 --- a/ui-svelte/package.json +++ b/ui-svelte/package.json @@ -26,7 +26,13 @@ "vitest": "^4.0.18" }, "dependencies": { + "@codemirror/lang-yaml": "^6.1.2", + "@codemirror/language": "^6.12.1", + "@codemirror/state": "^6.5.4", + "@codemirror/view": "^6.39.12", + "codemirror": "^6.0.2", "highlight.js": "^11.11.1", + "js-yaml": "^4.1.1", "katex": "^0.16.28", "lucide-svelte": "^0.563.0", "rehype-katex": "^7.0.1", diff --git a/ui-svelte/src/App.svelte b/ui-svelte/src/App.svelte index f3ca909bf..3d74b34ee 100644 --- a/ui-svelte/src/App.svelte +++ b/ui-svelte/src/App.svelte @@ -7,6 +7,7 @@ import Activity 
from "./routes/Activity.svelte"; import Playground from "./routes/Playground.svelte"; import PlaygroundStub from "./routes/PlaygroundStub.svelte"; + import Config from "./routes/Config.svelte"; import { enableAPIEvents } from "./stores/api"; import { initScreenWidth, isDarkMode, appTitle, connectionState } from "./stores/theme"; import { currentRoute } from "./stores/route"; @@ -16,6 +17,7 @@ "/models": Models, "/logs": LogViewer, "/activity": Activity, + "/config": Config, "*": PlaygroundStub, }; diff --git a/ui-svelte/src/components/Header.svelte b/ui-svelte/src/components/Header.svelte index c3cf4a8f5..6243c969b 100644 --- a/ui-svelte/src/components/Header.svelte +++ b/ui-svelte/src/components/Header.svelte @@ -78,6 +78,14 @@ > Logs + + Config + + + + + + + {#if validationError} +
+ Validation Error: {validationError} +
+ {/if} + + {#if error} +
+ {error} +
+ {/if} + + {#if loading} +
+
Loading configuration...
+
+ {:else} +
+ +
+

Current Config (Editable)

+
+
+ + +
+

Example Config (Reference)

+
+
+
+ {/if} +