diff --git a/README.md b/README.md index c2696235..8d372c10 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,8 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and - ✅ API Key support - define keys to restrict access to API endpoints - ✅ Customizable - Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - - Automatic unloading of models after timeout by setting a `ttl` + - Automatic unloading of models after idle timeout by setting a `ttl` + - Request timeout protection with `requestTimeout` to prevent runaway inference - Reliable Docker and Podman support using `cmd` and `cmdStop` together - Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235)) diff --git a/config-schema.json b/config-schema.json index 8baa0cc4..9b77344a 100644 --- a/config-schema.json +++ b/config-schema.json @@ -216,6 +216,12 @@ "type": "boolean", "description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting." }, + "requestTimeout": { + "type": "integer", + "minimum": 0, + "default": 0, + "description": "Maximum time in seconds for a single request to complete before forcefully killing the model process. This prevents runaway inference processes from blocking the GPU indefinitely. 0 disables timeout (default). When exceeded, the process is terminated and must be restarted for the next request." 
+ }, "unlisted": { "type": "boolean", "default": false, diff --git a/config.example.yaml b/config.example.yaml index d8282fc1..0ef80c02 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -249,6 +249,16 @@ models: # - recommended to be omitted and the default used concurrencyLimit: 0 + # requestTimeout: maximum time in seconds for a single request to complete + # - optional, default: 0 (no timeout) + # - useful for preventing runaway inference processes that never complete + # - when exceeded, the model process is forcefully stopped + # - protects against GPU overheating and against stuck processes blocking the GPU + # - the process must be restarted for the next request + # - set to 0 to disable timeout + # - recommended for models that may have infinite loops or excessive generation + requestTimeout: 0 # disabled by default; set to e.g. 300 for a 5 minute limit + # sendLoadingState: overrides the global sendLoadingState setting for this model # - optional, default: undefined (use global setting) sendLoadingState: false diff --git a/docs/configuration.md b/docs/configuration.md index 5aac2706..32713d57 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -319,6 +319,16 @@ models: # - recommended to be omitted and the default used concurrencyLimit: 0 + # requestTimeout: maximum time in seconds for a single request to complete + # - optional, default: 0 (no timeout) + # - useful for preventing runaway inference processes that never complete + # - when exceeded, the model process is forcefully stopped + # - protects against GPU overheating and against stuck processes blocking the GPU + # - the process must be restarted for the next request + # - set to 0 to disable timeout + # - recommended for models that may have infinite loops or excessive generation + requestTimeout: 300 # 5 minutes + # sendLoadingState: overrides the global sendLoadingState setting for this model # - optional, default: undefined (use global setting) sendLoadingState: false diff --git
a/proxy/config/model_config.go b/proxy/config/model_config.go index 9dc37aea..6b2ba742 100644 --- a/proxy/config/model_config.go +++ b/proxy/config/model_config.go @@ -36,6 +36,10 @@ type ModelConfig struct { // override global setting SendLoadingState *bool `yaml:"sendLoadingState"` + + // Maximum time in seconds for a request to complete before killing the process + // 0 means no timeout (default) + RequestTimeout int `yaml:"requestTimeout"` } func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { @@ -53,6 +57,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { ConcurrencyLimit: 0, Name: "", Description: "", + RequestTimeout: 0, } // the default cmdStop to taskkill /f /t /pid ${PID} diff --git a/proxy/process.go b/proxy/process.go index 41427059..fce7a705 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -49,6 +49,7 @@ type Process struct { // PR #155 called to cancel the upstream process cmdMutex sync.RWMutex cancelUpstream context.CancelFunc + cmdStarted bool // tracks if cmd.Start() completed successfully // closed when command exits cmdWaitChan chan struct{} @@ -250,26 +251,43 @@ func (p *Process) start() error { defer p.waitStarting.Done() cmdContext, ctxCancelUpstream := context.WithCancel(context.Background()) - p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...) - p.cmd.Stdout = p.processLogger - p.cmd.Stderr = p.processLogger - p.cmd.Env = append(p.cmd.Environ(), p.config.Env...) - p.cmd.Cancel = p.cmdStopUpstreamProcess - p.cmd.WaitDelay = p.gracefulStopTimeout - setProcAttributes(p.cmd) + cmd := exec.CommandContext(cmdContext, args[0], args[1:]...) + cmd.Stdout = p.processLogger + cmd.Stderr = p.processLogger + cmd.Env = append(cmd.Environ(), p.config.Env...) 
+ cmd.Cancel = p.cmdStopUpstreamProcess + cmd.WaitDelay = p.gracefulStopTimeout + setProcAttributes(cmd) + p.failedStartCount++ // this will be reset to zero when the process has successfully started + + // Initialize cancelUpstream and cmdWaitChan before Start() so stopCommand() can use them + // if called during startup (e.g., due to timeout) p.cmdMutex.Lock() + p.cmd = cmd p.cancelUpstream = ctxCancelUpstream p.cmdWaitChan = make(chan struct{}) p.cmdMutex.Unlock() - p.failedStartCount++ // this will be reset to zero when the process has successfully started - p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", ")) err = p.cmd.Start() + // Set cmdStarted flag under lock after Start() completes + // This prevents data race with stopCommand() which checks cmdStarted instead of cmd.Process + p.cmdMutex.Lock() + if err == nil { + p.cmdStarted = true + } + p.cmdMutex.Unlock() + // Set process state to failed if err != nil { + // Close cmdWaitChan to prevent stopCommand() from hanging if a timeout + // transitions StateStarting -> StateStopping before Start() completes + p.cmdMutex.Lock() + close(p.cmdWaitChan) + p.cmdMutex.Unlock() + if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil { p.forceState(StateStopped) // force it into a stopped state return fmt.Errorf( @@ -381,14 +399,28 @@ func (p *Process) Stop() { // StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. // If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. 
func (p *Process) StopImmediately() { - if !isValidTransition(p.CurrentState(), StateStopping) { - return - } + // Try to transition from current state to StateStopping + // Process might be in StateReady or StateStarting when timeout fires + // Retry on ErrExpectedStateMismatch to handle transient state changes + for { + currentState := p.CurrentState() + if !isValidTransition(currentState, StateStopping) { + return + } - p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState()) - if curState, err := p.swapState(StateReady, StateStopping); err != nil { - p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) - return + p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, currentState) + + if _, err := p.swapState(currentState, StateStopping); err != nil { + if err == ErrExpectedStateMismatch { + // State changed between CurrentState() and swapState(), retry + continue + } + p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v", p.ID, currentState, err) + return + } + + // Successfully transitioned to StateStopping + break } p.stopCommand() @@ -422,14 +454,28 @@ func (p *Process) stopCommand() { p.cmdMutex.RLock() cancelUpstream := p.cancelUpstream cmdWaitChan := p.cmdWaitChan + cmdStarted := p.cmdStarted p.cmdMutex.RUnlock() + // If cancelUpstream is nil, the process was never actually started + // (e.g., forceState() was used in tests). Just return silently. if cancelUpstream == nil { - p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID) + p.proxyLogger.Debugf("<%s> stopCommand: cancelUpstream is nil, process was never started", p.ID) return } + // Always cancel the context to stop the command cancelUpstream() + + // If cmdStarted is false, the process never actually started + // (cmd.Start() was never called or failed), so skip waiting on cmdWaitChan + // to avoid hanging. 
This can happen if a timeout transitions StateStarting + // to StateStopping before cmd.Start() completes. + if !cmdStarted { + p.proxyLogger.Debugf("<%s> stopCommand: process never started (cmdStarted is false), skipping wait", p.ID) + return + } + <-cmdWaitChan } @@ -500,6 +546,44 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { p.inFlightRequests.Done() }() + // Start timeout monitoring if requestTimeout is configured + var requestCancel context.CancelFunc + var timeoutCancel context.CancelFunc + + if p.config.RequestTimeout > 0 { + timeoutDuration := time.Duration(p.config.RequestTimeout) * time.Second + + // Add timeout to request context to cancel the request when exceeded + requestCtx, cancel := context.WithTimeout(r.Context(), timeoutDuration) + requestCancel = cancel + r = r.Clone(requestCtx) + + // Create a separate timeout context for monitoring only + // Use context.Background() to ensure we detect our configured timeout, + // not parent-imposed deadlines that would cause misattribution + timeoutCtx, cancel := context.WithTimeout(context.Background(), timeoutDuration) + timeoutCancel = cancel + + go func() { + <-timeoutCtx.Done() + if timeoutCtx.Err() == context.DeadlineExceeded { + p.proxyLogger.Warnf("<%s> Request timeout exceeded (%v), force stopping process to prevent GPU blocking", p.ID, timeoutDuration) + // Force stop the process - this will kill the underlying inference process + p.StopImmediately() + } + }() + + // Ensure both timeouts are cancelled when request completes + defer func() { + if requestCancel != nil { + requestCancel() + } + if timeoutCancel != nil { + timeoutCancel() + } + }() + } + // for #366 // - extract streaming param from request context, should have been set by proxymanager var srw *statusResponseWriter @@ -606,6 +690,7 @@ func (p *Process) waitForCmd() { p.cmdMutex.Lock() close(p.cmdWaitChan) + p.cmdStarted = false p.cmdMutex.Unlock() } diff --git a/proxy/process_test.go b/proxy/process_test.go 
index 3881c3dd..eac4873d 100644
--- a/proxy/process_test.go
+++ b/proxy/process_test.go
@@ -18,6 +18,15 @@ var (
 	debugLogger = NewLogMonitorWriter(os.Stdout)
 )
 
+// getSleepCommand returns a platform-appropriate command to sleep for the given number of seconds
+func getSleepCommand(seconds int) string {
+	if runtime.GOOS == "windows" {
+		// Use full path to avoid conflict with GNU coreutils timeout in Git Bash
+		return fmt.Sprintf("C:\\Windows\\System32\\timeout.exe /t %d /nobreak", seconds)
+	}
+	return fmt.Sprintf("sleep %d", seconds)
+}
+
 func init() {
 	// flip to help with debugging tests
 	if false {
@@ -569,3 +578,106 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
 	}
 	return w.ResponseRecorder.Write(b)
 }
+
+// TestProcess_StopCommandDoesNotHangWhenStartFails verifies that stopCommand()
+// does not hang when cmd.Start() fails or hasn't completed yet. This can happen
+// when a timeout transitions StateStarting -> StateStopping before cmd.Start()
+// completes.
+func TestProcess_StopCommandDoesNotHangWhenStartFails(t *testing.T) {
+	// Create a process with a command that will fail to start
+	cfg := config.ModelConfig{
+		Cmd:           "nonexistent-command-that-will-fail",
+		Proxy:         "http://127.0.0.1:9999",
+		CheckEndpoint: "/health",
+	}
+
+	process := NewProcess("fail-test", 1, cfg, debugLogger, debugLogger)
+
+	// Try to start the process - this will fail
+	err := process.start()
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "start() failed for command")
+	assert.Equal(t, StateStopped, process.CurrentState())
+
+	// Now try to stop the process - this should not hang
+	// Create a channel to track if stopCommand completes
+	done := make(chan struct{})
+	go func() {
+		process.stopCommand()
+		close(done)
+	}()
+
+	// Wait for stopCommand to complete with a timeout
+	select {
+	case <-done:
+		// Success - stopCommand completed without hanging
+	case <-time.After(2 * time.Second):
+		t.Fatal("stopCommand() hung when process never started")
+	}
+}
+
+// TestProcess_StopImmediatelyDuringStartup verifies that StopImmediately()
+// can safely interrupt a process during StateStarting without hanging.
+func TestProcess_StopImmediatelyDuringStartup(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping slow test")
+	}
+
+	// Use a platform-appropriate long-running command that never answers the health check, so start() stays in StateStarting
+	cmd := getSleepCommand(10)
+	cfg := config.ModelConfig{
+		Cmd:           cmd,
+		Proxy:         "http://127.0.0.1:9999",
+		CheckEndpoint: "/health",
+	}
+
+	process := NewProcess("interrupt-test", 20, cfg, debugLogger, debugLogger)
+	process.healthCheckLoopInterval = 100 * time.Millisecond
+
+	// Start the process in a goroutine (it will be in StateStarting)
+	startDone := make(chan struct{})
+	errCh := make(chan error, 1)
+	go func() {
+		err := process.start()
+		errCh <- err
+		close(startDone)
+	}()
+
+	// Wait a bit for the process to enter StateStarting
+	time.Sleep(200 * time.Millisecond)
+	currentState := process.CurrentState()
+	assert.Equal(t, StateStarting, currentState, "Process should be in StateStarting")
+
+	// Now call StopImmediately while in StateStarting
+	// This simulates a timeout firing during startup
+	stopDone := make(chan struct{})
+	go func() {
+		process.StopImmediately()
+		close(stopDone)
+	}()
+
+	// Verify StopImmediately completes without hanging
+	select {
+	case <-stopDone:
+		// Success - StopImmediately completed
+	case <-time.After(3 * time.Second):
+		t.Fatal("StopImmediately() hung when called during StateStarting")
+	}
+
+	// Wait for start() to complete
+	select {
+	case <-startDone:
+		// Success
+	case <-time.After(2 * time.Second):
+		t.Fatal("start() did not complete after StopImmediately")
+	}
+
+	// Verify start() returned an error due to StopImmediately interrupt
+	err := <-errCh
+	assert.Error(t, err)
+
+	// Process should be in StateStopped or StateStopping
+	finalState := process.CurrentState()
+	assert.True(t, finalState == StateStopped || finalState == StateStopping,
+		"Expected StateStopped or StateStopping, got %s", finalState)
+}
diff --git a/proxy/process_timeout_test.go b/proxy/process_timeout_test.go
new file mode 100644
index 00000000..0f889e6a
--- /dev/null
+++ b/proxy/process_timeout_test.go
@@ -0,0 +1,120 @@
+package proxy
+
+import (
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/mostlygeek/llama-swap/proxy/config"
+)
+
+// TestProcess_RequestTimeout verifies that a request running longer than
+// requestTimeout is aborted. The process state is forced to ready without
+// starting an upstream command, so the kill itself is a no-op here and only
+// the request cancellation is observed.
+func TestProcess_RequestTimeout(t *testing.T) {
+	// Create error channel to report handler errors from the mock server goroutine
+	srvErrCh := make(chan error, 1)
+
+	// Create a mock server that simulates a long-running inference
+	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		t.Logf("Mock server received request")
+
+		// Simulate streaming response that takes 60 seconds
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.WriteHeader(http.StatusOK)
+
+		flusher, ok := w.(http.Flusher)
+		if !ok {
+			srvErrCh <- fmt.Errorf("expected http.ResponseWriter to be an http.Flusher")
+			return
+		}
+
+		// Stream data for 60 seconds
+		for i := 0; i < 60; i++ {
+			select {
+			case <-r.Context().Done():
+				t.Logf("Mock server: client disconnected after %d seconds", i)
+				return
+			default:
+				fmt.Fprintf(w, "data: token %d\n\n", i)
+				flusher.Flush()
+				time.Sleep(1 * time.Second)
+			}
+		}
+		t.Logf("Mock server completed full 60 second response")
+	}))
+	defer mockServer.Close()
+
+	// Set up process logger - use NewLogMonitor() to avoid race in test
+	processLogger := NewLogMonitor()
+	proxyLogger := NewLogMonitor()
+
+	// Create process with 5 second request timeout
+	cfg := config.ModelConfig{
+		Proxy:          mockServer.URL,
+		CheckEndpoint:  "none", // skip health check
+		RequestTimeout: 5,      // 5 second timeout
+	}
+
+	p := NewProcess("test-timeout", 30, cfg, processLogger, proxyLogger)
+	p.gracefulStopTimeout = 2 * time.Second // shorter for testing
+
+	// Manually set state to ready (skip actual process start)
+	p.forceState(StateReady)
+
+	// Make a request that should timeout
+	req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
+	w := httptest.NewRecorder()
+
+	// Run the request in the background; done is closed when ProxyRequest returns
+	start := time.Now()
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		p.ProxyRequest(w, req)
+	}()
+
+	// Wait for either completion or timeout
+	select {
+	case err := <-srvErrCh:
+		// Handler error - fail the test immediately
+		t.Fatalf("Mock server handler error: %v", err)
+
+	case <-done:
+		elapsed := time.Since(start)
+		t.Logf("Request completed after %v", elapsed)
+
+		// Check for any deferred server errors
+		select {
+		case err := <-srvErrCh:
+			t.Fatalf("Mock server handler error: %v", err)
+		default:
+			// No server errors, continue with assertions
+		}
+
+		// Request should complete within timeout + gracefulStopTimeout + some buffer
+		maxExpected := time.Duration(cfg.RequestTimeout)*time.Second + p.gracefulStopTimeout + 3*time.Second
+		if elapsed > maxExpected {
+			t.Errorf("Request took %v, expected less than %v with %ds timeout", elapsed, maxExpected, cfg.RequestTimeout)
+		} else {
+			t.Logf("✓ Request was properly terminated by timeout")
+		}
+
+	case <-time.After(15 * time.Second):
+		t.Fatalf("Test timed out after 15 seconds - request should have been killed by requestTimeout")
+	}
+}
+
+// TestProcess_RequestTimeoutWithRealProcess tests with an actual process
+func TestProcess_RequestTimeoutWithRealProcess(t *testing.T) {
+	if testing.Short() {
+		t.Skip("Skipping test with real process in short mode")
+	}
+
+	// This test would require a real llama.cpp server or similar
+	// For now, we can skip it or mock it
+	t.Skip("Requires real inference server")
+}