Skip to content
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- ✅ API Key support - define keys to restrict access to API endpoints
- ✅ Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- Automatic unloading of models after timeout by setting a `ttl`
- Automatic unloading of models after idle timeout by setting a `ttl`
- Request timeout protection with `requestTimeout` to prevent runaway inference
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))

Expand Down
6 changes: 6 additions & 0 deletions config-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@
"type": "boolean",
"description": "Overrides the global sendLoadingState for this model. Omitting this property will use the global setting."
},
"requestTimeout": {
"type": "integer",
"minimum": 0,
"default": 0,
"description": "Maximum time in seconds for a single request to complete before forcefully killing the model process. This prevents runaway inference processes from blocking the GPU indefinitely. 0 disables timeout (default). When exceeded, the process is terminated and must be restarted for the next request."
},
"unlisted": {
"type": "boolean",
"default": false,
Expand Down
10 changes: 10 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0

# requestTimeout: maximum time in seconds for a single request to complete
# - optional, default: 0 (no timeout)
# - useful for preventing runaway inference processes that never complete
# - when exceeded, the model process is forcefully stopped
# - prevents stuck processes from overheating or blocking the GPU
# - the process must be restarted for the next request
# - set to 0 to disable timeout
# - recommended for models that may have infinite loops or excessive generation
requestTimeout: 0 # disabled by default; set to e.g. 300 for a 5-minute limit

# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
Expand Down
10 changes: 10 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,16 @@ models:
# - recommended to be omitted and the default used
concurrencyLimit: 0

# requestTimeout: maximum time in seconds for a single request to complete
# - optional, default: 0 (no timeout)
# - useful for preventing runaway inference processes that never complete
# - when exceeded, the model process is forcefully stopped
# - prevents stuck processes from overheating or blocking the GPU
# - the process must be restarted for the next request
# - set to 0 to disable timeout
# - recommended for models that may have infinite loops or excessive generation
requestTimeout: 300 # 5 minutes

# sendLoadingState: overrides the global sendLoadingState setting for this model
# - optional, default: undefined (use global setting)
sendLoadingState: false
Expand Down
5 changes: 5 additions & 0 deletions proxy/config/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type ModelConfig struct {

// override global setting
SendLoadingState *bool `yaml:"sendLoadingState"`

// Maximum time in seconds for a request to complete before killing the process
// 0 means no timeout (default)
RequestTimeout int `yaml:"requestTimeout"`
}

func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
Expand All @@ -53,6 +57,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
ConcurrencyLimit: 0,
Name: "",
Description: "",
RequestTimeout: 0,
}

// the default cmdStop to taskkill /f /t /pid ${PID}
Expand Down
119 changes: 102 additions & 17 deletions proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type Process struct {
// PR #155 called to cancel the upstream process
cmdMutex sync.RWMutex
cancelUpstream context.CancelFunc
cmdStarted bool // tracks if cmd.Start() completed successfully

// closed when command exits
cmdWaitChan chan struct{}
Expand Down Expand Up @@ -250,26 +251,43 @@ func (p *Process) start() error {
defer p.waitStarting.Done()
cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())

p.cmd = exec.CommandContext(cmdContext, args[0], args[1:]...)
p.cmd.Stdout = p.processLogger
p.cmd.Stderr = p.processLogger
p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
p.cmd.Cancel = p.cmdStopUpstreamProcess
p.cmd.WaitDelay = p.gracefulStopTimeout
setProcAttributes(p.cmd)
cmd := exec.CommandContext(cmdContext, args[0], args[1:]...)
cmd.Stdout = p.processLogger
cmd.Stderr = p.processLogger
cmd.Env = append(cmd.Environ(), p.config.Env...)
cmd.Cancel = p.cmdStopUpstreamProcess
cmd.WaitDelay = p.gracefulStopTimeout
setProcAttributes(cmd)

p.failedStartCount++ // this will be reset to zero when the process has successfully started

// Initialize cancelUpstream and cmdWaitChan before Start() so stopCommand() can use them
// if called during startup (e.g., due to timeout)
p.cmdMutex.Lock()
p.cmd = cmd
p.cancelUpstream = ctxCancelUpstream
p.cmdWaitChan = make(chan struct{})
p.cmdMutex.Unlock()

p.failedStartCount++ // this will be reset to zero when the process has successfully started

p.proxyLogger.Debugf("<%s> Executing start command: %s, env: %s", p.ID, strings.Join(args, " "), strings.Join(p.config.Env, ", "))
err = p.cmd.Start()

// Set cmdStarted flag under lock after Start() completes
// This prevents data race with stopCommand() which checks cmdStarted instead of cmd.Process
p.cmdMutex.Lock()
if err == nil {
p.cmdStarted = true
}
p.cmdMutex.Unlock()

// Set process state to failed
if err != nil {
// Close cmdWaitChan to prevent stopCommand() from hanging if a timeout
// transitions StateStarting -> StateStopping before Start() completes
p.cmdMutex.Lock()
close(p.cmdWaitChan)
p.cmdMutex.Unlock()

if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
p.forceState(StateStopped) // force it into a stopped state
return fmt.Errorf(
Expand Down Expand Up @@ -381,14 +399,28 @@ func (p *Process) Stop() {
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
if !isValidTransition(p.CurrentState(), StateStopping) {
return
}
// Try to transition from current state to StateStopping
// Process might be in StateReady or StateStarting when timeout fires
// Retry on ErrExpectedStateMismatch to handle transient state changes
for {
currentState := p.CurrentState()
if !isValidTransition(currentState, StateStopping) {
return
}

p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
return
p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, currentState)

if _, err := p.swapState(currentState, StateStopping); err != nil {
if err == ErrExpectedStateMismatch {
// State changed between CurrentState() and swapState(), retry
continue
}
p.proxyLogger.Infof("<%s> Stop() %s -> StateStopping err: %v", p.ID, currentState, err)
return
}

// Successfully transitioned to StateStopping
break
}

p.stopCommand()
Expand Down Expand Up @@ -422,14 +454,28 @@ func (p *Process) stopCommand() {
p.cmdMutex.RLock()
cancelUpstream := p.cancelUpstream
cmdWaitChan := p.cmdWaitChan
cmdStarted := p.cmdStarted
p.cmdMutex.RUnlock()

// If cancelUpstream is nil, the process was never actually started
// (e.g., forceState() was used in tests). Just return silently.
if cancelUpstream == nil {
p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
p.proxyLogger.Debugf("<%s> stopCommand: cancelUpstream is nil, process was never started", p.ID)
return
}

// Always cancel the context to stop the command
cancelUpstream()

// If cmdStarted is false, the process never actually started
// (cmd.Start() was never called or failed), so skip waiting on cmdWaitChan
// to avoid hanging. This can happen if a timeout transitions StateStarting
// to StateStopping before cmd.Start() completes.
if !cmdStarted {
p.proxyLogger.Debugf("<%s> stopCommand: process never started (cmdStarted is false), skipping wait", p.ID)
return
}

<-cmdWaitChan
}

Expand Down Expand Up @@ -500,6 +546,44 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Done()
}()

// Start timeout monitoring if requestTimeout is configured
var requestCancel context.CancelFunc
var timeoutCancel context.CancelFunc

if p.config.RequestTimeout > 0 {
timeoutDuration := time.Duration(p.config.RequestTimeout) * time.Second

// Add timeout to request context to cancel the request when exceeded
requestCtx, cancel := context.WithTimeout(r.Context(), timeoutDuration)
requestCancel = cancel
r = r.Clone(requestCtx)

// Create a separate timeout context for monitoring only
// Use context.Background() to ensure we detect our configured timeout,
// not parent-imposed deadlines that would cause misattribution
timeoutCtx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
timeoutCancel = cancel

go func() {
<-timeoutCtx.Done()
if timeoutCtx.Err() == context.DeadlineExceeded {
p.proxyLogger.Warnf("<%s> Request timeout exceeded (%v), force stopping process to prevent GPU blocking", p.ID, timeoutDuration)
// Force stop the process - this will kill the underlying inference process
p.StopImmediately()
}
}()

// Ensure both timeouts are cancelled when request completes
defer func() {
if requestCancel != nil {
requestCancel()
}
if timeoutCancel != nil {
timeoutCancel()
}
}()
}

// for #366
// - extract streaming param from request context, should have been set by proxymanager
var srw *statusResponseWriter
Expand Down Expand Up @@ -606,6 +690,7 @@ func (p *Process) waitForCmd() {

p.cmdMutex.Lock()
close(p.cmdWaitChan)
p.cmdStarted = false
p.cmdMutex.Unlock()
}

Expand Down
112 changes: 112 additions & 0 deletions proxy/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ var (
debugLogger = NewLogMonitorWriter(os.Stdout)
)

// getSleepCommand builds the command line used by the tests to launch a
// process that idles for the requested number of seconds. The command is
// chosen according to the operating system the tests are running on.
func getSleepCommand(seconds int) string {
	switch runtime.GOOS {
	case "windows":
		// The absolute System32 path is spelled out so that the GNU coreutils
		// `timeout` found in Git Bash is not picked up instead.
		return fmt.Sprintf(`C:\Windows\System32\timeout.exe /t %d /nobreak`, seconds)
	default:
		return fmt.Sprintf("sleep %d", seconds)
	}
}

func init() {
// flip to help with debugging tests
if false {
Expand Down Expand Up @@ -569,3 +578,106 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
}
return w.ResponseRecorder.Write(b)
}

// TestProcess_StopCommandDoesNotHangWhenStartFails verifies that stopCommand()
// does not hang when cmd.Start() fails or hasn't completed yet. This can happen
// when a timeout transitions StateStarting -> StateStopping before cmd.Start()
// completes.
func TestProcess_StopCommandDoesNotHangWhenStartFails(t *testing.T) {
// Create a process with a command that will fail to start
config := config.ModelConfig{
Cmd: "nonexistent-command-that-will-fail",
Proxy: "http://127.0.0.1:9999",
CheckEndpoint: "/health",
}

process := NewProcess("fail-test", 1, config, debugLogger, debugLogger)

// Try to start the process - this will fail
err := process.start()
assert.Error(t, err)
assert.Contains(t, err.Error(), "start() failed for command")
assert.Equal(t, StateStopped, process.CurrentState())

// Now try to stop the process - this should not hang
// Create a channel to track if stopCommand completes
done := make(chan struct{})
go func() {
process.stopCommand()
close(done)
}()

// Wait for stopCommand to complete with a timeout
select {
case <-done:
// Success - stopCommand completed without hanging
case <-time.After(2 * time.Second):
t.Fatal("stopCommand() hung when process never started")
}
}

// TestProcess_StopImmediatelyDuringStartup verifies that StopImmediately()
// can safely interrupt a process during StateStarting without hanging.
func TestProcess_StopImmediatelyDuringStartup(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}

// Use a platform-appropriate command that takes a while to start but won't respond to health checks
cmd := getSleepCommand(10)
config := config.ModelConfig{
Cmd: cmd,
Proxy: "http://127.0.0.1:9999",
CheckEndpoint: "/health",
}

process := NewProcess("interrupt-test", 20, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = 100 * time.Millisecond

// Start the process in a goroutine (it will be in StateStarting)
startDone := make(chan struct{})
errCh := make(chan error, 1)
go func() {
err := process.start()
errCh <- err
close(startDone)
}()

// Wait a bit for the process to enter StateStarting
<-time.After(200 * time.Millisecond)
currentState := process.CurrentState()
assert.Equal(t, StateStarting, currentState, "Process should be in StateStarting")

// Now call StopImmediately while in StateStarting
// This simulates a timeout firing during startup
stopDone := make(chan struct{})
go func() {
process.StopImmediately()
close(stopDone)
}()

// Verify StopImmediately completes without hanging
select {
case <-stopDone:
// Success - StopImmediately completed
case <-time.After(3 * time.Second):
t.Fatal("StopImmediately() hung when called during StateStarting")
}

// Wait for start() to complete
select {
case <-startDone:
// Success
case <-time.After(2 * time.Second):
t.Fatal("start() did not complete after StopImmediately")
}

// Verify start() returned an error due to StopImmediately interrupt
err := <-errCh
assert.Error(t, err)

// Process should be in StateStopped or StateStopping
finalState := process.CurrentState()
assert.True(t, finalState == StateStopped || finalState == StateStopping,
"Expected StateStopped or StateStopping, got %s", finalState)
}
Loading
Loading