Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,21 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}

# afterHealthy: a command to run once after the model passes its health check
# - optional, default: ""
# - runs as a one-shot command as soon as the server reports healthy
# - blocks the model from becoming ready until it completes
# - a failure is logged as a warning but does not prevent the model from starting
# - useful for loading saved prompt cache slots in llama.cpp
# afterHealthy: "curl -X POST 'http://localhost:${PORT}/slots/0?action=restore' -H 'Content-Type: application/json' -d '{\"filename\": \"slot0.bin\"}'"

# beforeStop: a command to run right before the model process is killed
# - optional, default: ""
# - blocks the model shutdown until it completes (or fails)
# - the model will be stopped regardless of whether this command succeeds
# - useful for saving prompt cache slots in llama.cpp before unloading
# beforeStop: "curl -X POST 'http://localhost:${PORT}/slots/0?action=save' -H 'Content-Type: application/json' -d '{\"filename\": \"slot0.bin\"}'"

# groups: a dictionary of group settings
# - optional, default: empty dictionary
# - provides advanced controls over model swapping behaviour
Expand Down
12 changes: 10 additions & 2 deletions proxy/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {

modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.AfterHealthy = strings.ReplaceAll(modelConfig.AfterHealthy, macroSlug, macroStr)
modelConfig.BeforeStop = strings.ReplaceAll(modelConfig.BeforeStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
Expand Down Expand Up @@ -339,10 +341,12 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}

// Handle PORT macro - only allocate if cmd uses it
// Handle PORT macro - only allocate if cmd, afterHealthy, or beforeStop uses it
cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}")
afterHealthyHasPort := strings.Contains(modelConfig.AfterHealthy, "${PORT}")
beforeStopHasPort := strings.Contains(modelConfig.BeforeStop, "${PORT}")
proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}")
if cmdHasPort || proxyHasPort {
if cmdHasPort || afterHealthyHasPort || beforeStopHasPort || proxyHasPort {
if !cmdHasPort && proxyHasPort {
return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId)
}
Expand All @@ -352,6 +356,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {

modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr)
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.AfterHealthy = strings.ReplaceAll(modelConfig.AfterHealthy, macroSlug, macroStr)
modelConfig.BeforeStop = strings.ReplaceAll(modelConfig.BeforeStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr)
modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr)
Expand All @@ -371,6 +377,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
fieldMap := map[string]string{
"cmd": modelConfig.Cmd,
"cmdStop": modelConfig.CmdStop,
"afterHealthy": modelConfig.AfterHealthy,
"beforeStop": modelConfig.BeforeStop,
"proxy": modelConfig.Proxy,
"checkEndpoint": modelConfig.CheckEndpoint,
"filters.stripParams": modelConfig.Filters.StripParams,
Expand Down
2 changes: 2 additions & 0 deletions proxy/config/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ const (
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
AfterHealthy string `yaml:"afterHealthy"`
BeforeStop string `yaml:"beforeStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
Expand Down
46 changes: 46 additions & 0 deletions proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,13 @@ func (p *Process) start() error {
}
}

if p.config.AfterHealthy != "" {
p.proxyLogger.Debugf("<%s> Running afterHealthy hook: %s", p.ID, p.config.AfterHealthy)
if err := p.runHookCommand(p.config.AfterHealthy); err != nil {
p.proxyLogger.Warnf("<%s> afterHealthy hook failed: %v", p.ID, err)
}
}

if p.config.UnloadAfter > 0 {
// start a goroutine to check every second if
// the process should be stopped
Expand Down Expand Up @@ -429,6 +436,13 @@ func (p *Process) stopCommand() {
return
}

if p.config.BeforeStop != "" {
p.proxyLogger.Debugf("<%s> Running beforeStop hook: %s", p.ID, p.config.BeforeStop)
if err := p.runHookCommand(p.config.BeforeStop); err != nil {
p.proxyLogger.Warnf("<%s> beforeStop hook failed: %v", p.ID, err)
}
}

cancelUpstream()
<-cmdWaitChan
}
Expand Down Expand Up @@ -654,6 +668,38 @@ func (p *Process) Logger() *LogMonitor {
return p.processLogger
}

var hookCommandTimeout = 30 * time.Second

// runHookCommand executes a hook command, logging its output through the
// process logger. The command inherits the environment of the upstream process.
func (p *Process) runHookCommand(hookCmd string) error {
args, err := config.SanitizeCommand(hookCmd)
if err != nil {
return fmt.Errorf("failed to sanitize hook command %q: %v", hookCmd, err)
}

ctx, cancel := context.WithTimeout(context.Background(), hookCommandTimeout)
defer cancel()

cmd := exec.CommandContext(ctx, args[0], args[1:]...)
cmd.Stdout = p.processLogger
cmd.Stderr = p.processLogger
if p.cmd != nil {
cmd.Env = p.cmd.Env
}
setProcAttributes(cmd)

err = cmd.Run()
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return fmt.Errorf("hook command timed out after %v: %w", hookCommandTimeout, err)
}
return fmt.Errorf("hook command failed: %w", err)
}

return nil
}

var loadingRemarks = []string{
"Still faster than your last standup meeting...",
"Reticulating splines...",
Expand Down
48 changes: 48 additions & 0 deletions proxy/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,54 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
<-waitChan
}

func TestProcess_AfterHealthyHook(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping AfterHealthy hook test on Windows")
}

tmpFile := t.TempDir() + "/after_healthy_ran"
conf := getTestSimpleResponderConfig("after_healthy_test")
conf.AfterHealthy = fmt.Sprintf("touch %s", tmpFile)

process := NewProcess("test-after-healthy", 5, conf, debugLogger, debugLogger)
defer process.Stop()

err := process.start()
assert.Nil(t, err)
assert.Equal(t, StateReady, process.CurrentState())

// The hook should have created the temp file before the process became ready
_, statErr := os.Stat(tmpFile)
assert.Nil(t, statErr, "afterHealthy hook should have created temp file")
}

func TestProcess_BeforeStopHook(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping BeforeStop hook test on Windows")
}

tmpFile := t.TempDir() + "/before_stop_ran"
conf := getTestSimpleResponderConfig("before_stop_test")
conf.BeforeStop = fmt.Sprintf("touch %s", tmpFile)

process := NewProcess("test-before-stop", 5, conf, debugLogger, debugLogger)

err := process.start()
assert.Nil(t, err)
assert.Equal(t, StateReady, process.CurrentState())

// Verify hook hasn't run yet
_, statErr := os.Stat(tmpFile)
assert.True(t, os.IsNotExist(statErr), "beforeStop hook should not have run yet")

process.Stop()
assert.Equal(t, StateStopped, process.CurrentState())

// The hook should have created the temp file before the process was killed
_, statErr = os.Stat(tmpFile)
assert.Nil(t, statErr, "beforeStop hook should have created temp file")
}

func TestProcess_StopCmd(t *testing.T) {
conf := getTestSimpleResponderConfig("test_stop_cmd")

Expand Down
Loading