diff --git a/README.md b/README.md index c26962351..8d372c108 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,8 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and - ✅ API Key support - define keys to restrict access to API endpoints - ✅ Customizable - Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - - Automatic unloading of models after timeout by setting a `ttl` + - Automatic unloading of models after idle timeout by setting a `ttl` + - Request timeout protection with `requestTimeout` to prevent runaway inference - Reliable Docker and Podman support using `cmd` and `cmdStop` together - Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235)) diff --git a/config-schema.json b/config-schema.json index 8baa0cc43..9b77344ad 100644 --- a/config-schema.json +++ b/config-schema.json @@ -216,6 +216,12 @@ "type": "boolean", "description": "Overrides the global sendLoadingState for this model. Ommitting this property will use the global setting." }, + "requestTimeout": { + "type": "integer", + "minimum": 0, + "default": 0, + "description": "Maximum time in seconds for a single request to complete before forcefully killing the model process. This prevents runaway inference processes from blocking the GPU indefinitely. 0 disables timeout (default). When exceeded, the process is terminated and must be restarted for the next request." + }, "unlisted": { "type": "boolean", "default": false, diff --git a/config.example.yaml b/config.example.yaml index d8282fc17..0ef80c02a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -249,6 +249,16 @@ models: # - recommended to be omitted and the default used concurrencyLimit: 0 + # requestTimeout: maximum time in seconds for a single request to complete + # - optional, default: 0 (no timeout) + # - useful for preventing runaway inference processes that never complete + # - when exceeded, the model process is forcefully stopped + # - protects against GPU overheating and blocking from stuck processes + # - the process must be restarted for the next request + # - set to 0 to disable timeout + # - recommended for models that may have infinite loops or excessive generation + requestTimeout: 0 # disabled by default, set to e.g., 300 for 5 minutes + # sendLoadingState: overrides the global sendLoadingState setting for this model # - optional, default: undefined (use global setting) sendLoadingState: false diff --git a/docs/configuration.md b/docs/configuration.md index 5aac2706c..32713d577 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -319,6 +319,16 @@ models: # - recommended to be omitted and the default used concurrencyLimit: 0 + # requestTimeout: maximum time in seconds for a single request to complete + # - optional, default: 0 (no timeout) + # - useful for preventing runaway inference processes that never complete + # - when exceeded, the model process is forcefully stopped + # - protects against GPU overheating and blocking from stuck processes + # - the process must be restarted for the next request + # - set to 0 to disable timeout + # - recommended for models that may have infinite loops or excessive generation + requestTimeout: 300 # 5 minutes + # sendLoadingState: overrides the global sendLoadingState setting for this model # - optional, default: undefined (use global setting) sendLoadingState: false diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go index 9dc37aea6..6b2ba742a 100644 --- a/proxy/config/model_config.go +++ b/proxy/config/model_config.go @@ -36,6 +36,10 @@ type ModelConfig struct { // override global setting SendLoadingState *bool `yaml:"sendLoadingState"` + + // Maximum time in seconds for a request to complete before killing the process + // 0 means no timeout (default) + RequestTimeout int `yaml:"requestTimeout"` } func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { @@ -53,6 +57,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { ConcurrencyLimit: 0, Name: "", Description: "", + RequestTimeout: 0, } // the default cmdStop to taskkill /f /t /pid ${PID} diff --git a/proxy/process.go b/proxy/process.go index 414270595..7e311d11c 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -381,13 +381,17 @@ func (p *Process) Stop() { // StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. // If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. func (p *Process) StopImmediately() { - if !isValidTransition(p.CurrentState(), StateStopping) { + currentState := p.CurrentState() + if !isValidTransition(currentState, StateStopping) { return } - p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState()) - if curState, err := p.swapState(StateReady, StateStopping); err != nil { - p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) + p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, currentState) + + // Try to transition from current state to StateStopping + // Process might be in StateReady or StateStarting when timeout fires + if _, err := p.swapState(currentState, StateStopping); err != nil { + p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v", p.ID, currentState, err) return } @@ -500,6 +504,34 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { p.inFlightRequests.Done() }() + // Start timeout monitoring if requestTimeout is configured + var timeoutCancel context.CancelFunc + var requestCtx context.Context = r.Context() + + if p.config.RequestTimeout > 0 { + timeoutDuration := time.Duration(p.config.RequestTimeout) * time.Second + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(r.Context(), timeoutDuration) + timeoutCancel = cancel + + go func() { + <-requestCtx.Done() + if requestCtx.Err() == context.DeadlineExceeded { + p.proxyLogger.Warnf("<%s> Request timeout exceeded (%v), force stopping process to prevent GPU blocking", p.ID, timeoutDuration) + // Force stop the process - this will kill the underlying inference process + p.StopImmediately() + } + }() + + // Ensure timeout is cancelled when request completes + defer timeoutCancel() + } + + // Create a new request with the timeout context + if requestCtx != r.Context() { + r = r.Clone(requestCtx) + } + // for #366 // - extract streaming param from request context, should have been set by proxymanager var srw *statusResponseWriter diff --git a/proxy/process_timeout_test.go b/proxy/process_timeout_test.go new file mode 100644 index 000000000..9f048d9e2 --- /dev/null +++ b/proxy/process_timeout_test.go @@ -0,0 +1,109 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/mostlygeek/llama-swap/proxy/config" +) + +// TestProcess_RequestTimeout verifies that requestTimeout actually kills the process +func TestProcess_RequestTimeout(t *testing.T) { + // Create a mock server that simulates a long-running inference + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Mock server received request") + + // Simulate streaming response that takes 60 seconds + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + t.Fatal("Expected http.ResponseWriter to be an http.Flusher") + } + + // Stream data for 60 seconds + for i := 0; i < 60; i++ { + select { + case <-r.Context().Done(): + t.Logf("Mock server: client disconnected after %d seconds", i) + return + default: + fmt.Fprintf(w, "data: token %d\n\n", i) + flusher.Flush() + time.Sleep(1 * time.Second) + } + } + t.Logf("Mock server completed full 60 second response") + })) + defer mockServer.Close() + + // Setup process logger - use NewLogMonitor() to avoid race in test + processLogger := NewLogMonitor() + proxyLogger := NewLogMonitor() + + // Create process with 5 second request timeout + cfg := config.ModelConfig{ + Proxy: mockServer.URL, + CheckEndpoint: "none", // skip health check + RequestTimeout: 5, // 5 second timeout + } + + p := NewProcess("test-timeout", 30, cfg, processLogger, proxyLogger) + p.gracefulStopTimeout = 2 * time.Second // shorter for testing + + // Manually set state to ready (skip actual process start) + p.forceState(StateReady) + + // Make a request that should timeout + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) + w := httptest.NewRecorder() + + start := time.Now() + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + p.ProxyRequest(w, req) + }() + + // Wait for either completion or timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + elapsed := time.Since(start) + t.Logf("Request completed after %v", elapsed) + + // Request should complete within timeout + gracefulStopTimeout + some buffer + maxExpected := time.Duration(cfg.RequestTimeout+2)*time.Second + 3*time.Second + if elapsed > maxExpected { + t.Errorf("Request took %v, expected less than %v with 5s timeout", elapsed, maxExpected) + } else { + t.Logf("✓ Request was properly terminated by timeout") + } + + case <-time.After(15 * time.Second): + t.Fatalf("Test timed out after 15 seconds - request should have been killed by requestTimeout") + } +} + +// TestProcess_RequestTimeoutWithRealProcess tests with an actual process +func TestProcess_RequestTimeoutWithRealProcess(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test with real process in short mode") + } + + // This test would require a real llama.cpp server or similar + // For now, we can skip it or mock it + t.Skip("Requires real inference server") +}