diff --git a/README.md b/README.md index 8d372c108..40e7182a1 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and - Request timeout protection with `requestTimeout` to prevent runaway inference - Reliable Docker and Podman support using `cmd` and `cmdStop` together - Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235)) + - RPC health checking for distributed inference - conditionally expose models based on RPC server availability ### Web UI @@ -175,6 +176,7 @@ Almost all configuration settings are optional and can be added one step at a ti - `useModelName` to override model names sent to upstream servers - `${PORT}` automatic port variables for dynamic port assignment - `filters` rewrite parts of requests before sending to the upstream server + - `rpcHealthCheck` monitor RPC server health for distributed inference models See the [configuration documentation](docs/configuration.md) for all options. diff --git a/config-schema.json b/config-schema.json index 9b77344ad..d0bfd2f5e 100644 --- a/config-schema.json +++ b/config-schema.json @@ -226,6 +226,11 @@ "type": "boolean", "default": false, "description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests." + }, + "rpcHealthCheck": { + "type": "boolean", + "default": false, + "description": "Enable TCP health checks for RPC endpoints specified in cmd. When enabled, parses --rpc host:port[,host:port,...] from cmd and performs health checks every 30 seconds. Models with unhealthy RPC endpoints are filtered from /v1/models and return 503 on inference requests." } } } diff --git a/config.example.yaml b/config.example.yaml index 0ef80c02a..defb8e55c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -272,6 +272,24 @@ models: unlisted: true cmd: llama-server --port ${PORT} -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 0 + # RPC health check example for distributed inference: + "qwen-distributed": + # rpcHealthCheck: enable TCP health checks for RPC endpoints + # - optional, default: false + # - when enabled, parses --rpc host:port[,host:port,...] from cmd + # - performs TCP connectivity checks every 30 seconds + # - model is only listed in /v1/models when ALL RPC endpoints are healthy + # - inference requests to unhealthy models return HTTP 503 + # - useful for distributed inference with llama.cpp's rpc-server + rpcHealthCheck: true + cmd: | + llama-server --port ${PORT} + --rpc 192.168.1.10:50051,192.168.1.11:50051 + -m Qwen2.5-32B-Instruct-Q4_K_M.gguf + -ngl 99 + name: "Qwen 32B (Distributed)" + description: "Large model using distributed RPC inference" + # Docker example: # container runtimes like Docker and Podman can be used reliably with # a combination of cmd, cmdStop, and ${MODEL_ID} diff --git a/docs/configuration.md b/docs/configuration.md index 32713d577..13b747d22 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -72,16 +72,17 @@ models: llama-swap supports many more features to customize how you want to manage your environment. -| Feature | Description | -| --------- | ---------------------------------------------- | -| `ttl` | automatic unloading of models after a timeout | -| `macros` | reusable snippets to use in configurations | -| `groups` | run multiple models at a time | -| `hooks` | event driven functionality | -| `env` | define environment variables per model | -| `aliases` | serve a model with different names | -| `filters` | modify requests before sending to the upstream | -| `...` | And many more tweaks | +| Feature | Description | +| ----------------- | ------------------------------------------------------- | +| `ttl` | automatic unloading of models after a timeout | +| `macros` | reusable snippets to use in configurations | +| `groups` | run multiple models at a time | +| `hooks` | event driven functionality | +| `env` | define environment variables per model | +| `aliases` | serve a model with different names | +| `filters` | modify requests before sending to the upstream | +| `rpcHealthCheck` | monitor RPC server health for distributed inference | +| `...` | And many more tweaks | ## Full Configuration Example diff --git a/proxy/config/config.go b/proxy/config/config.go index c4387f40a..1866d03f9 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "io" + "net" "net/url" "os" "regexp" @@ -533,6 +534,63 @@ func SanitizeCommand(cmdStr string) ([]string, error) { return args, nil } +// ParseRPCEndpoints extracts RPC endpoints from command string +// Handles: --rpc host:port,host2:port2 or --rpc=host:port or -rpc host:port +func ParseRPCEndpoints(cmdStr string) ([]string, error) { + args, err := SanitizeCommand(cmdStr) + if err != nil { + return nil, err + } + + var endpoints []string + for i, arg := range args { + if arg == "--rpc" || arg == "-rpc" { + if i+1 < len(args) { + endpoints = parseEndpointList(args[i+1]) + } + } else if strings.HasPrefix(arg, "--rpc=") { + endpoints = parseEndpointList(strings.TrimPrefix(arg, "--rpc=")) + } else if strings.HasPrefix(arg, "-rpc=") { + endpoints = parseEndpointList(strings.TrimPrefix(arg, "-rpc=")) + } + } + + // Validate each endpoint + for _, ep := range endpoints { + if _, _, err := net.SplitHostPort(ep); err != nil { + return nil, fmt.Errorf("invalid RPC endpoint %q: %w", ep, err) + } + } + + return endpoints, nil +} + +func parseEndpointList(s string) []string { + s = strings.TrimSpace(s) + + // Strip surrounding quotes (both single and double) from the whole string + // if they match. This handles cases like: "host:port,host2:port2" + if len(s) >= 2 { + if (s[0] == '\'' && s[len(s)-1] == '\'') || (s[0] == '"' && s[len(s)-1] == '"') { + s = s[1 : len(s)-1] + } + } + + parts := strings.Split(s, ",") + var result []string + for _, p := range parts { + p = strings.TrimSpace(p) + // Strip any remaining leading/trailing quotes from individual parts + // This handles Windows where shlex doesn't handle single quotes and + // may split 'host:port, host2:port' into "'host:port," and "host2:port'" + p = strings.Trim(p, "'\"") + if p != "" { + result = append(result, p) + } + } + return result +} + func StripComments(cmdStr string) string { var cleanedLines []string for _, line := range strings.Split(cmdStr, "\n") { diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go index a19cbb56a..11552f9d0 100644 --- a/proxy/config/config_test.go +++ b/proxy/config/config_test.go @@ -1309,3 +1309,108 @@ peers: assert.Contains(t, err.Error(), "unknown macro") }) } + +func TestParseRPCEndpoints_ValidFormats(t *testing.T) { + tests := []struct { + name string + cmd string + expected []string + }{ + { + name: "single endpoint with --rpc", + cmd: "llama-server --rpc localhost:50051 -ngl 99", + expected: []string{"localhost:50051"}, + }, + { + name: "single endpoint with --rpc=", + cmd: "llama-server --rpc=192.168.1.100:50051 -ngl 99", + expected: []string{"192.168.1.100:50051"}, + }, + { + name: "single endpoint with -rpc", + cmd: "llama-server -rpc localhost:50051 -ngl 99", + expected: []string{"localhost:50051"}, + }, + { + name: "single endpoint with -rpc=", + cmd: "llama-server -rpc=localhost:50051 -ngl 99", + expected: []string{"localhost:50051"}, + }, + { + name: "multiple endpoints comma-separated", + cmd: "llama-server --rpc 192.168.1.10:50051,192.168.1.11:50051 -ngl 99", + expected: []string{"192.168.1.10:50051", "192.168.1.11:50051"}, + }, + { + name: "multiple endpoints with spaces trimmed", + cmd: "llama-server --rpc '192.168.1.10:50051, 192.168.1.11:50051' -ngl 99", + expected: []string{"192.168.1.10:50051", "192.168.1.11:50051"}, + }, + { + name: "IPv6 endpoint", + cmd: "llama-server --rpc [::1]:50051 -ngl 99", + expected: []string{"[::1]:50051"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoints, err := ParseRPCEndpoints(tt.cmd) + assert.NoError(t, err) + assert.Equal(t, tt.expected, endpoints) + }) + } +} + +func TestParseRPCEndpoints_NoRPCFlag(t *testing.T) { + cmd := "llama-server -ngl 99 -m model.gguf" + endpoints, err := ParseRPCEndpoints(cmd) + assert.NoError(t, err) + assert.Empty(t, endpoints) +} + +func TestParseRPCEndpoints_InvalidFormats(t *testing.T) { + tests := []struct { + name string + cmd string + wantErr string + }{ + { + name: "missing port", + cmd: "llama-server --rpc localhost -ngl 99", + wantErr: "invalid RPC endpoint", + }, + { + name: "invalid host:port format", + cmd: "llama-server --rpc not-a-valid-endpoint -ngl 99", + wantErr: "invalid RPC endpoint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseRPCEndpoints(tt.cmd) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestParseRPCEndpoints_EmptyEndpointsFiltered(t *testing.T) { + // Empty strings after commas are filtered out + cmd := "llama-server --rpc 'localhost:50051,,' -ngl 99" + endpoints, err := ParseRPCEndpoints(cmd) + assert.NoError(t, err) + assert.Equal(t, []string{"localhost:50051"}, endpoints) +} + +func TestParseRPCEndpoints_MultilineCommand(t *testing.T) { + cmd := `llama-server \ + --rpc localhost:50051 \ + -ngl 99 \ + -m model.gguf` + + endpoints, err := ParseRPCEndpoints(cmd) + assert.NoError(t, err) + assert.Equal(t, []string{"localhost:50051"}, endpoints) +} diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go index 6b2ba742a..92bed341f 100644 --- a/proxy/config/model_config.go +++ b/proxy/config/model_config.go @@ -37,6 +37,8 @@ type ModelConfig struct { // override global setting SendLoadingState *bool `yaml:"sendLoadingState"` + // RPC health checking + RPCHealthCheck bool `yaml:"rpcHealthCheck"` // Maximum time in seconds for a request to complete before killing the process // 0 means no timeout (default) RequestTimeout int `yaml:"requestTimeout"` @@ -57,6 +59,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { ConcurrencyLimit: 0, Name: "", Description: "", + RPCHealthCheck: false, RequestTimeout: 0, } diff --git a/proxy/process.go b/proxy/process.go index 7e311d11c..a464980eb 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -79,18 +79,25 @@ type Process struct { // track the number of failed starts failedStartCount int + + // RPC health checking + rpcEndpoints []string + rpcHealthy atomic.Bool + rpcHealthTicker *time.Ticker + rpcHealthCancel context.CancelFunc + shutdownCtx context.Context // from ProxyManager for graceful shutdown } -func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { +func NewProcess(ID string, healthCheckTimeout int, modelConfig config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor, shutdownCtx context.Context) *Process { concurrentLimit := 10 - if config.ConcurrencyLimit > 0 { - concurrentLimit = config.ConcurrencyLimit + if modelConfig.ConcurrencyLimit > 0 { + concurrentLimit = modelConfig.ConcurrencyLimit } // Setup the reverse proxy. - proxyURL, err := url.Parse(config.Proxy) + proxyURL, err := url.Parse(modelConfig.Proxy) if err != nil { - proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err) + proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, modelConfig.Proxy, err) } var reverseProxy *httputil.ReverseProxy @@ -105,9 +112,9 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr } } - return &Process{ + p := &Process{ ID: ID, - config: config, + config: modelConfig, cmd: nil, reverseProxy: reverseProxy, cancelUpstream: nil, @@ -124,7 +131,25 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr // stop timeout gracefulStopTimeout: 10 * time.Second, cmdWaitChan: make(chan struct{}), + shutdownCtx: shutdownCtx, } + + // Parse RPC endpoints if health checking enabled + if modelConfig.RPCHealthCheck { + endpoints, err := config.ParseRPCEndpoints(modelConfig.Cmd) + if err != nil { + proxyLogger.Errorf("<%s> failed to parse RPC endpoints: %v", ID, err) + } else if len(endpoints) == 0 { + proxyLogger.Warnf("<%s> rpcHealthCheck enabled but no --rpc flag found in cmd", ID) + } else { + p.rpcEndpoints = endpoints + p.rpcHealthy.Store(false) // start unhealthy until first check passes + // Start health checker immediately - runs independent of process state + p.startRPCHealthChecker() + } + } + + return p } // LogMonitor returns the log monitor associated with the process. @@ -909,3 +934,67 @@ func (s *statusResponseWriter) Flush() { flusher.Flush() } } + +// startRPCHealthChecker launches background goroutine for RPC health monitoring. +// Runs independently of process state - checks RPC endpoints regardless of whether +// the model is loaded, starting, stopped, etc. +func (p *Process) startRPCHealthChecker() { + if !p.config.RPCHealthCheck || len(p.rpcEndpoints) == 0 { + return + } + + ctx, cancel := context.WithCancel(p.shutdownCtx) + p.rpcHealthCancel = cancel + p.rpcHealthTicker = time.NewTicker(30 * time.Second) + + go func() { + defer p.rpcHealthTicker.Stop() + + // Run initial check immediately + p.checkRPCHealth() + + for { + select { + case <-ctx.Done(): + p.proxyLogger.Debugf("<%s> RPC health checker shutting down", p.ID) + return + case <-p.rpcHealthTicker.C: + // Check regardless of process state + p.checkRPCHealth() + } + } + }() +} + +func (p *Process) checkRPCHealth() { + allHealthy := true + + for _, endpoint := range p.rpcEndpoints { + dialer := net.Dialer{Timeout: 500 * time.Millisecond} + conn, err := dialer.Dial("tcp", endpoint) + if err != nil { + p.proxyLogger.Warnf("<%s> RPC endpoint %s unhealthy: %v", p.ID, endpoint, err) + allHealthy = false + break + } + conn.Close() + } + + wasHealthy := p.rpcHealthy.Load() + p.rpcHealthy.Store(allHealthy) + + // Log state changes + if wasHealthy && !allHealthy { + p.proxyLogger.Infof("<%s> RPC endpoints now UNHEALTHY", p.ID) + } else if !wasHealthy && allHealthy { + p.proxyLogger.Infof("<%s> RPC endpoints now HEALTHY", p.ID) + } +} + +// IsRPCHealthy returns true if RPC health checking is disabled or all endpoints healthy +func (p *Process) IsRPCHealthy() bool { + if !p.config.RPCHealthCheck || len(p.rpcEndpoints) == 0 { + return true // not using RPC health checks + } + return p.rpcHealthy.Load() +} diff --git a/proxy/process_rpc_health_test.go b/proxy/process_rpc_health_test.go new file mode 100644 index 000000000..cb9d1d259 --- /dev/null +++ b/proxy/process_rpc_health_test.go @@ -0,0 +1,84 @@ +package proxy + +import ( + "context" + "io" + "testing" + + "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/stretchr/testify/assert" +) + +func TestProcess_RPCHealthIndependentOfState(t *testing.T) { + testLogger := NewLogMonitorWriter(io.Discard) + proxyLogger := NewLogMonitorWriter(io.Discard) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + modelConfig := config.ModelConfig{ + Cmd: "llama-server --rpc 127.0.0.1:50051", + Proxy: "http://localhost:8080", + RPCHealthCheck: true, + } + + process := NewProcess("test-model", 5, modelConfig, testLogger, proxyLogger, ctx) + + // Verify endpoints were parsed + assert.NotEmpty(t, process.rpcEndpoints, "RPC endpoints should be parsed from cmd") + assert.Equal(t, []string{"127.0.0.1:50051"}, process.rpcEndpoints) + + // Initially should be unhealthy (false) until first check + assert.False(t, process.rpcHealthy.Load(), "RPC health should start as false") + + // Health checker should be running regardless of process state + assert.NotNil(t, process.rpcHealthTicker, "Health checker ticker should be running") + assert.NotNil(t, process.rpcHealthCancel, "Health checker should have cancel func") + + // Process state should not affect health checking + assert.Equal(t, StateStopped, process.CurrentState(), "Process should be in stopped state") + + // Health check runs independently - simulate RPC becoming healthy + process.rpcHealthy.Store(true) + assert.True(t, process.IsRPCHealthy(), "Process should report healthy regardless of state") +} + +func TestProcess_RPCHealthCheckDisabled(t *testing.T) { + testLogger := NewLogMonitorWriter(io.Discard) + proxyLogger := NewLogMonitorWriter(io.Discard) + ctx := context.Background() + + modelConfig := config.ModelConfig{ + Cmd: "llama-server --rpc 127.0.0.1:50051", + Proxy: "http://localhost:8080", + RPCHealthCheck: false, // Disabled + } + + process := NewProcess("test-model", 5, modelConfig, testLogger, proxyLogger, ctx) + + // Should always return healthy when disabled + assert.True(t, process.IsRPCHealthy(), "Should return true when RPC health check is disabled") +} + +func TestProcess_RPCHealthCheckNoEndpoints(t *testing.T) { + testLogger := NewLogMonitorWriter(io.Discard) + proxyLogger := NewLogMonitorWriter(io.Discard) + ctx := context.Background() + + modelConfig := config.ModelConfig{ + Cmd: "llama-server --port 8080", // No --rpc flag + Proxy: "http://localhost:8080", + RPCHealthCheck: true, // Enabled but no endpoints + } + + process := NewProcess("test-model", 5, modelConfig, testLogger, proxyLogger, ctx) + + // Should have no endpoints + assert.Empty(t, process.rpcEndpoints, "Should have no RPC endpoints when --rpc flag is missing") + + // Should return healthy when no endpoints configured (treat as not using RPC) + assert.True(t, process.IsRPCHealthy(), "Should return true when no RPC endpoints found") + + // Health checker should NOT start when no endpoints + assert.Nil(t, process.rpcHealthTicker, "Health checker should not run without endpoints") + assert.Nil(t, process.rpcHealthCancel, "Health checker cancel should be nil") +} diff --git a/proxy/process_test.go b/proxy/process_test.go index 3881c3dde..87e31d6dd 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -33,7 +34,7 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { config := getTestSimpleResponderConfig(expectedMessage) // Create a process - process := NewProcess("test-process", 5, config, debugLogger, debugLogger) + process := NewProcess("test-process", 5, config, debugLogger, debugLogger, context.Background()) defer process.Stop() req := httptest.NewRequest("GET", "/test", nil) @@ -69,7 +70,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) { expectedMessage := "testing91931" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("test-process", 5, config, debugLogger, debugLogger) + process := NewProcess("test-process", 5, config, debugLogger, debugLogger, context.Background()) defer process.Stop() var wg sync.WaitGroup @@ -97,7 +98,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("broken", 1, config, debugLogger, debugLogger) + process := NewProcess("broken", 1, config, debugLogger, debugLogger, context.Background()) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -122,7 +123,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) { config.UnloadAfter = 3 // seconds assert.Equal(t, 3, config.UnloadAfter) - process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) + process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger, context.Background()) defer process.Stop() // this should take 4 seconds @@ -164,7 +165,7 @@ func TestProcess_LowTTLValue(t *testing.T) { config.UnloadAfter = 1 // second assert.Equal(t, 1, config.UnloadAfter) - process := NewProcess("ttl", 2, config, debugLogger, debugLogger) + process := NewProcess("ttl", 2, config, debugLogger, debugLogger, context.Background()) defer process.Stop() for i := 0; i < 100; i++ { @@ -191,7 +192,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { expectedMessage := "12345" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("t", 10, config, debugLogger, debugLogger) + process := NewProcess("t", 10, config, debugLogger, debugLogger, context.Background()) defer process.Stop() results := map[string]string{ @@ -264,7 +265,7 @@ func TestProcess_SwapState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) + p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger, context.Background()) p.state = test.currentState resultState, err := p.swapState(test.expectedState, test.newState) @@ -297,7 +298,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { config.Proxy = "http://localhost:9998/test" healthCheckTTLSeconds := 30 - process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) + process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger, context.Background()) // make it a lot faster process.healthCheckLoopInterval = time.Second @@ -332,7 +333,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger) + process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger, context.Background()) process.healthCheckLoopInterval = time.Second // make it faster err := process.start() assert.Equal(t, "upstream command exited prematurely but successfully", err.Error()) @@ -350,7 +351,7 @@ func TestProcess_ConcurrencyLimit(t *testing.T) { // only allow 1 concurrent request at a time config.ConcurrencyLimit = 1 - process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) + process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger, context.Background()) assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore)) defer process.Stop() @@ -375,7 +376,7 @@ func TestProcess_StopImmediately(t *testing.T) { expectedMessage := "test_stop_immediate" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) + process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger, context.Background()) defer process.Stop() err := process.start() @@ -415,7 +416,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger) + process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger, context.Background()) defer process.Stop() // reduce to make testing go faster @@ -465,7 +466,7 @@ func TestProcess_StopCmd(t *testing.T) { conf.CmdStop = "kill -TERM ${PID}" } - process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger) + process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger, context.Background()) defer process.Stop() err := process.start() @@ -485,8 +486,8 @@ func TestProcess_EnvironmentSetCorrectly(t *testing.T) { // ensure the additiona variables are appended to the process' environment configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2") - process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger) - process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger) + process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger, context.Background()) + process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger, context.Background()) process1.start() defer process1.Stop() @@ -521,7 +522,7 @@ func TestProcess_ReverseProxyPanicIsHandled(t *testing.T) { expectedMessage := "panic_test" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("panic-test", 5, config, debugLogger, debugLogger) + process := NewProcess("panic-test", 5, config, debugLogger, debugLogger, context.Background()) defer process.Stop() // Start the process diff --git a/proxy/processgroup.go b/proxy/processgroup.go index b401d8a68..c920f302c 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "fmt" "net/http" "slices" @@ -24,9 +25,11 @@ type ProcessGroup struct { // map of current processes processes map[string]*Process lastUsedProcess string + + shutdownCtx context.Context } -func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { +func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor, shutdownCtx context.Context) *ProcessGroup { groupConfig, ok := config.Groups[id] if !ok { panic("Unable to find configuration for group id: " + id) @@ -41,13 +44,14 @@ func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, u proxyLogger: proxyLogger, upstreamLogger: upstreamLogger, processes: make(map[string]*Process), + shutdownCtx: shutdownCtx, } // Create a Process for each member in the group for _, modelID := range groupConfig.Members { modelConfig, modelID, _ := pg.config.FindConfig(modelID) processLogger := NewLogMonitorWriter(upstreamLogger) - process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger) + process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, processLogger, pg.proxyLogger, shutdownCtx) pg.processes[modelID] = process } diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go index 6b90f4433..55e5276a8 100644 --- a/proxy/processgroup_test.go +++ b/proxy/processgroup_test.go @@ -2,6 +2,7 @@ package proxy import ( "bytes" + "context" "net/http" "net/http/httptest" "sync" @@ -35,12 +36,12 @@ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{ }) func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) { - pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger, context.Background()) assert.True(t, pg.HasMember("model5")) } func TestProcessGroup_HasMember(t *testing.T) { - pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger, context.Background()) assert.True(t, pg.HasMember("model1")) assert.True(t, pg.HasMember("model2")) assert.False(t, pg.HasMember("model3")) @@ -74,7 +75,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { }, }) - pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger, context.Background()) defer pg.StopProcesses(StopWaitForInflightRequest) tests := []string{"model1", "model2", "model3", "model4", "model5"} @@ -96,7 +97,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { } func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { - pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) + pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger, context.Background()) defer pg.StopProcesses(StopWaitForInflightRequest) tests := []string{"model3", "model4"} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index c33c9f960..3bb33b73b 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -173,7 +173,7 @@ func New(proxyConfig config.Config) *ProxyManager { // create the process groups for groupID := range proxyConfig.Groups { - processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger) + processGroup := NewProcessGroup(groupID, proxyConfig, proxyLogger, upstreamLogger, shutdownCtx) pm.processGroups[groupID] = processGroup } @@ -481,6 +481,16 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) { continue } + // Filter models with unhealthy RPC endpoints + if processGroup := pm.findGroupByModelName(id); processGroup != nil { + if process, ok := processGroup.GetMember(id); ok { + if !process.IsRPCHealthy() { + pm.proxyLogger.Debugf("<%s> filtered from /v1/models (unhealthy RPC)", id) + continue + } + } + } + data = append(data, newRecord(id, modelConfig)) // Include aliases @@ -633,6 +643,15 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { return } + // Check RPC health before processing request + if process, ok := processGroup.GetMember(modelID); ok { + if !process.IsRPCHealthy() { + pm.sendErrorResponse(c, http.StatusServiceUnavailable, + fmt.Sprintf("model %s unavailable (RPC endpoints unhealthy)", modelID)) + return + } + } + // issue #69 allow custom model names to be sent to upstream useModelName := pm.config.Models[modelID].UseModelName if useModelName != "" {