diff --git a/config.example.yaml b/config.example.yaml index 35f74c12..da173eef 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -315,6 +315,21 @@ models: # - processes have 5 seconds to shutdown until forceful termination is attempted cmdStop: docker stop ${MODEL_ID} + # afterHealthy: a command to run once after the model passes its health check + # - optional, default: "" + # - runs as a one-shot command as soon as the server reports healthy + # - blocks the model from becoming ready until it completes + # - a failure is logged as a warning but does not prevent the model from starting + # - useful for loading saved prompt cache slots in llama.cpp + # afterHealthy: "curl -X POST 'http://localhost:${PORT}/slots/0?action=restore' -H 'Content-Type: application/json' -d '{\"filename\": \"slot0.bin\"}'" + + # beforeStop: a command to run right before the model process is killed + # - optional, default: "" + # - blocks the model shutdown until it completes (or fails) + # - the model will be stopped regardless of whether this command succeeds + # - useful for saving prompt cache slots in llama.cpp before unloading + # beforeStop: "curl -X POST 'http://localhost:${PORT}/slots/0?action=save' -H 'Content-Type: application/json' -d '{\"filename\": \"slot0.bin\"}'" + # groups: a dictionary of group settings # - optional, default: empty dictionary # - provides advanced controls over model swapping behaviour diff --git a/proxy/config/config.go b/proxy/config/config.go index 00f44970..91696471 100644 --- a/proxy/config/config.go +++ b/proxy/config/config.go @@ -305,6 +305,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr) modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) + modelConfig.AfterHealthy = strings.ReplaceAll(modelConfig.AfterHealthy, macroSlug, macroStr) + modelConfig.BeforeStop = strings.ReplaceAll(modelConfig.BeforeStop, macroSlug, macroStr) modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr) modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr) @@ -339,10 +341,12 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } - // Handle PORT macro - only allocate if cmd uses it + // Handle PORT macro - only allocate if cmd, afterHealthy, or beforeStop uses it cmdHasPort := strings.Contains(modelConfig.Cmd, "${PORT}") + afterHealthyHasPort := strings.Contains(modelConfig.AfterHealthy, "${PORT}") + beforeStopHasPort := strings.Contains(modelConfig.BeforeStop, "${PORT}") proxyHasPort := strings.Contains(modelConfig.Proxy, "${PORT}") - if cmdHasPort || proxyHasPort { + if cmdHasPort || afterHealthyHasPort || beforeStopHasPort || proxyHasPort { if !cmdHasPort && proxyHasPort { return Config{}, fmt.Errorf("model %s: proxy uses ${PORT} but cmd does not - ${PORT} is only available when used in cmd", modelId) } @@ -352,6 +356,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroStr) modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr) + modelConfig.AfterHealthy = strings.ReplaceAll(modelConfig.AfterHealthy, macroSlug, macroStr) + modelConfig.BeforeStop = strings.ReplaceAll(modelConfig.BeforeStop, macroSlug, macroStr) modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr) modelConfig.Name = strings.ReplaceAll(modelConfig.Name, macroSlug, macroStr) modelConfig.Description = strings.ReplaceAll(modelConfig.Description, macroSlug, macroStr) @@ -371,6 +377,8 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { fieldMap := map[string]string{ "cmd": modelConfig.Cmd, "cmdStop": modelConfig.CmdStop, + "afterHealthy": modelConfig.AfterHealthy, + "beforeStop": modelConfig.BeforeStop, "proxy": modelConfig.Proxy, "checkEndpoint": modelConfig.CheckEndpoint, "filters.stripParams": modelConfig.Filters.StripParams, diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go index 685687ba..974b01aa 100644 --- a/proxy/config/model_config.go +++ b/proxy/config/model_config.go @@ -12,6 +12,8 @@ const ( type ModelConfig struct { Cmd string `yaml:"cmd"` CmdStop string `yaml:"cmdStop"` + AfterHealthy string `yaml:"afterHealthy"` + BeforeStop string `yaml:"beforeStop"` Proxy string `yaml:"proxy"` Aliases []string `yaml:"aliases"` Env []string `yaml:"env"` diff --git a/proxy/process.go b/proxy/process.go index 41427059..73efa2ee 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -333,6 +333,13 @@ func (p *Process) start() error { } } + if p.config.AfterHealthy != "" { + p.proxyLogger.Debugf("<%s> Running afterHealthy hook: %s", p.ID, p.config.AfterHealthy) + if err := p.runHookCommand(p.config.AfterHealthy); err != nil { + p.proxyLogger.Warnf("<%s> afterHealthy hook failed: %v", p.ID, err) + } + } + if p.config.UnloadAfter > 0 { // start a goroutine to check every second if // the process should be stopped @@ -429,6 +436,13 @@ func (p *Process) stopCommand() { return } + if p.config.BeforeStop != "" { + p.proxyLogger.Debugf("<%s> Running beforeStop hook: %s", p.ID, p.config.BeforeStop) + if err := p.runHookCommand(p.config.BeforeStop); err != nil { + p.proxyLogger.Warnf("<%s> beforeStop hook failed: %v", p.ID, err) + } + } + cancelUpstream() <-cmdWaitChan } @@ -654,6 +668,38 @@ func (p *Process) Logger() *LogMonitor { return p.processLogger } +var hookCommandTimeout = 30 * time.Second + +// runHookCommand executes a hook command, logging its output through the +// process logger. The command inherits the environment of the upstream process. +func (p *Process) runHookCommand(hookCmd string) error { + args, err := config.SanitizeCommand(hookCmd) + if err != nil { + return fmt.Errorf("failed to sanitize hook command %q: %v", hookCmd, err) + } + + ctx, cancel := context.WithTimeout(context.Background(), hookCommandTimeout) + defer cancel() + + cmd := exec.CommandContext(ctx, args[0], args[1:]...) + cmd.Stdout = p.processLogger + cmd.Stderr = p.processLogger + if p.cmd != nil { + cmd.Env = p.cmd.Env + } + setProcAttributes(cmd) + + err = cmd.Run() + if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return fmt.Errorf("hook command timed out after %v: %w", hookCommandTimeout, err) + } + return fmt.Errorf("hook command failed: %w", err) + } + + return nil +} + var loadingRemarks = []string{ "Still faster than your last standup meeting...", "Reticulating splines...", diff --git a/proxy/process_test.go b/proxy/process_test.go index dd9e9d8a..797b360f 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -456,6 +456,54 @@ func TestProcess_ForceStopWithKill(t *testing.T) { <-waitChan } +func TestProcess_AfterHealthyHook(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping AfterHealthy hook test on Windows") + } + + tmpFile := t.TempDir() + "/after_healthy_ran" + conf := getTestSimpleResponderConfig("after_healthy_test") + conf.AfterHealthy = fmt.Sprintf("touch %s", tmpFile) + + process := NewProcess("test-after-healthy", 5, conf, debugLogger, debugLogger) + defer process.Stop() + + err := process.start() + assert.Nil(t, err) + assert.Equal(t, StateReady, process.CurrentState()) + + // The hook should have created the temp file before the process became ready + _, statErr := os.Stat(tmpFile) + assert.Nil(t, statErr, "afterHealthy hook should have created temp file") +} + +func TestProcess_BeforeStopHook(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping BeforeStop hook test on Windows") + } + + tmpFile := t.TempDir() + "/before_stop_ran" + conf := getTestSimpleResponderConfig("before_stop_test") + conf.BeforeStop = fmt.Sprintf("touch %s", tmpFile) + + process := NewProcess("test-before-stop", 5, conf, debugLogger, debugLogger) + + err := process.start() + assert.Nil(t, err) + assert.Equal(t, StateReady, process.CurrentState()) + + // Verify hook hasn't run yet + _, statErr := os.Stat(tmpFile) + assert.True(t, os.IsNotExist(statErr), "beforeStop hook should not have run yet") + + process.Stop() + assert.Equal(t, StateStopped, process.CurrentState()) + + // The hook should have created the temp file before the process was killed + _, statErr = os.Stat(tmpFile) + assert.Nil(t, statErr, "beforeStop hook should have created temp file") +} + func TestProcess_StopCmd(t *testing.T) { conf := getTestSimpleResponderConfig("test_stop_cmd")