Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- Automatic unloading of models after timeout by setting a `ttl`
- Automatic unloading of models after idle timeout by setting a `ttl`
- Request timeout protection with `requestTimeout` to prevent runaway inference
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))

Expand Down
6 changes: 6 additions & 0 deletions config-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@
"type": "boolean",
"description": "Overrides the global sendLoadingState for this model. Omitting this property will use the global setting."
},
"requestTimeout": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Maximum time in seconds for a single request to complete before forcefully killing the model process. This prevents runaway inference processes from blocking the GPU indefinitely. 0 disables timeout (default). When exceeded, the process is terminated and must be restarted for the next request."
},
"unlisted": {
"type": "boolean",
"default": false,
Expand Down
10 changes: 10 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0

# requestTimeout: maximum time in seconds for a single request to complete
# - optional, default: 0 (no timeout)
# - useful for preventing runaway inference processes that never complete
# - when exceeded, the model process is forcefully stopped
# - protects against GPU overheating and blocking from stuck processes
# - the process must be restarted for the next request
# - set to 0 to disable timeout
# - recommended for models that may have infinite loops or excessive generation
requestTimeout: 0 # disabled by default; set to e.g. 300 for a 5 minute limit

# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
Expand Down
10 changes: 10 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,16 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0

# requestTimeout: maximum time in seconds for a single request to complete
# - optional, default: 0 (no timeout)
# - useful for preventing runaway inference processes that never complete
# - when exceeded, the model process is forcefully stopped
# - protects against GPU overheating and blocking from stuck processes
# - the process must be restarted for the next request
# - set to 0 to disable timeout
# - recommended for models that may have infinite loops or excessive generation
requestTimeout: 300 # 5 minutes

# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
Expand Down
5 changes: 5 additions & 0 deletions proxy/config/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type ModelConfig struct {

// override global setting
SendLoadingState *bool `yaml:"sendLoadingState"`

// Maximum time in seconds for a request to complete before killing the process
// 0 means no timeout (default)
RequestTimeout int `yaml:"requestTimeout"`
}

func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
Expand All @@ -53,6 +57,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
ConcurrencyLimit: 0,
Name: "",
Description: "",
RequestTimeout: 0,
}

// the default cmdStop to taskkill /f /t /pid ${PID}
Expand Down
40 changes: 36 additions & 4 deletions proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,17 @@ func (p *Process) Stop() {
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
currentState := p.CurrentState()
if !isValidTransition(currentState, StateStopping) {
return
}

p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, currentState)

// Try to transition from current state to StateStopping
// Process might be in StateReady or StateStarting when timeout fires
if _, err := p.swapState(currentState, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v", p.ID, currentState, err)
return
}

Expand Down Expand Up @@ -500,6 +504,34 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Done()
}()

// Start timeout monitoring if requestTimeout is configured
var timeoutCancel context.CancelFunc
var requestCtx context.Context = r.Context()

if p.config.RequestTimeout > 0 {
timeoutDuration := time.Duration(p.config.RequestTimeout) * time.Second
var cancel context.CancelFunc
requestCtx, cancel = context.WithTimeout(r.Context(), timeoutDuration)
timeoutCancel = cancel

go func() {
<-requestCtx.Done()
if requestCtx.Err() == context.DeadlineExceeded {
p.proxyLogger.Warnf("<%s> Request timeout exceeded (%v), force stopping process to prevent GPU blocking", p.ID, timeoutDuration)
// Force stop the process - this will kill the underlying inference process
p.StopImmediately()
}
}()

// Ensure timeout is cancelled when request completes
defer timeoutCancel()
}

// Create a new request with the timeout context
if requestCtx != r.Context() {
r = r.Clone(requestCtx)
}

// for #366
// - extract streaming param from request context, should have been set by proxymanager
var srw *statusResponseWriter
Expand Down
109 changes: 109 additions & 0 deletions proxy/process_timeout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package proxy

import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/mostlygeek/llama-swap/proxy/config"
)

// TestProcess_RequestTimeout verifies that a configured requestTimeout
// actually terminates an in-flight request by force-stopping the process.
//
// It stands up a mock upstream that streams SSE tokens for up to 60 seconds,
// configures the Process with a 5 second requestTimeout, and asserts the
// proxied request is cut off well before the stream would finish on its own.
func TestProcess_RequestTimeout(t *testing.T) {
	// Mock upstream that simulates a long-running (runaway) inference stream.
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Logf("Mock server received request")

		// Simulate a streaming response that would take 60 seconds to finish.
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)

		flusher, ok := w.(http.Flusher)
		if !ok {
			// NOTE: t.Fatal/FailNow must only be called from the test
			// goroutine; this handler runs on the server's goroutine, so
			// report the failure and bail out instead.
			t.Error("Expected http.ResponseWriter to be an http.Flusher")
			return
		}

		// Emit one token per second; selecting on the request context lets
		// the handler notice a client disconnect immediately instead of
		// only between 1-second sleeps.
		for i := 0; i < 60; i++ {
			fmt.Fprintf(w, "data: token %d\n\n", i)
			flusher.Flush()
			select {
			case <-r.Context().Done():
				t.Logf("Mock server: client disconnected after %d seconds", i)
				return
			case <-time.After(1 * time.Second):
			}
		}
		t.Logf("Mock server completed full 60 second response")
	}))
	defer mockServer.Close()

	// Setup process logger - use NewLogMonitor() to avoid race in test
	processLogger := NewLogMonitor()
	proxyLogger := NewLogMonitor()

	// Create process with 5 second request timeout
	cfg := config.ModelConfig{
		Proxy:          mockServer.URL,
		CheckEndpoint:  "none", // skip health check
		RequestTimeout: 5,      // 5 second timeout
	}

	p := NewProcess("test-timeout", 30, cfg, processLogger, proxyLogger)
	p.gracefulStopTimeout = 2 * time.Second // shorter for testing

	// Manually set state to ready (skip actual process start)
	p.forceState(StateReady)

	// Make a request that should be killed by the request timeout.
	req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
	w := httptest.NewRecorder()

	start := time.Now()
	var wg sync.WaitGroup
	wg.Add(1)

	go func() {
		defer wg.Done()
		p.ProxyRequest(w, req)
	}()

	// Signal when the proxied request returns.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		elapsed := time.Since(start)
		t.Logf("Request completed after %v", elapsed)

		// Request should complete within requestTimeout + the graceful stop
		// window + some scheduling buffer. Derive the bound from the actual
		// configured values rather than hard-coding them.
		maxExpected := time.Duration(cfg.RequestTimeout)*time.Second + p.gracefulStopTimeout + 3*time.Second
		if elapsed > maxExpected {
			t.Errorf("Request took %v, expected less than %v with %ds timeout", elapsed, maxExpected, cfg.RequestTimeout)
		} else {
			t.Logf("✓ Request was properly terminated by timeout")
		}

	case <-time.After(15 * time.Second):
		t.Fatalf("Test timed out after 15 seconds - request should have been killed by requestTimeout")
	}
}

// TestProcess_RequestTimeoutWithRealProcess is a placeholder for an
// end-to-end timeout test against a real inference backend. It is always
// skipped: in -short mode via the explicit guard, and otherwise because no
// real llama.cpp (or similar) server is available in the test environment.
func TestProcess_RequestTimeoutWithRealProcess(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping test with real process in short mode")
	}
	t.Skip("Requires real inference server")
}