diff --git a/.github/workflows/go-ci-windows.yml b/.github/workflows/go-ci-windows.yml
index 6cd2747e..ed831b11 100644
--- a/.github/workflows/go-ci-windows.yml
+++ b/.github/workflows/go-ci-windows.yml
@@ -32,7 +32,6 @@ jobs:
# necessary for testing proxy/Process swapping
- name: Create simple-responder
- if: steps.restore-simple-responder.outputs.cache-hit != 'true'
shell: bash
run: make simple-responder-windows
diff --git a/README.md b/README.md
index 0a31333c..9b60e901 100644
--- a/README.md
+++ b/README.md
@@ -29,12 +29,14 @@ Built in Go for performance and simplicity, llama-swap has zero dependencies and
- `/ui` - web UI
- `/upstream/:model_id` - direct access to upstream server ([demo](https://github.com/mostlygeek/llama-swap/pull/31))
- `/models/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58))
+ - `/models/sleep/:model_id` - put a model to sleep (requires sleep/wake configuration)
- `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61))
- `/log` - remote log monitoring
- `/health` - just returns "OK"
- ✅ Customizable
- Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107))
- Automatic unloading of models after timeout by setting a `ttl`
+ - Fast model switching with sleep/wake support (vLLM sleep mode: offloads weights instead of a full restart)
- Reliable Docker and Podman support using `cmd` and `cmdStop` together
- Preload models on startup with `hooks` ([#235](https://github.com/mostlygeek/llama-swap/pull/235))
diff --git a/cmd/simple-responder/simple-responder.go b/cmd/simple-responder/simple-responder.go
index 6c65140e..5167b344 100644
--- a/cmd/simple-responder/simple-responder.go
+++ b/cmd/simple-responder/simple-responder.go
@@ -8,6 +8,7 @@ import (
"net/http"
"os"
"os/signal"
+ "strings"
"syscall"
"time"
@@ -264,6 +265,32 @@ func main() {
c.JSON(200, gin.H{"status": "ok"})
})
+ // Sleep/wake endpoints
+ r.POST("/sleep", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ r.POST("/wake_up", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ r.POST("/wake_up_fail", func(c *gin.Context) {
+ c.Status(http.StatusInternalServerError)
+ })
+
+ r.POST("/collective_rpc", func(c *gin.Context) {
+ body, _ := io.ReadAll(c.Request.Body)
+ if strings.Contains(string(body), "reload_weights") {
+ c.Status(http.StatusOK)
+ } else {
+ c.Status(http.StatusBadRequest)
+ }
+ })
+
+ r.POST("/reset_prefix_cache", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
r.GET("/", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
diff --git a/config-schema.json b/config-schema.json
index 63913981..d7dffa6c 100644
--- a/config-schema.json
+++ b/config-schema.json
@@ -48,6 +48,18 @@
"default": 120,
"description": "Number of seconds to wait for a model to be ready to serve requests."
},
+ "sleepRequestTimeout": {
+ "type": "integer",
+ "minimum": 1,
+ "default": 10,
+ "description": "Number of seconds to wait for each sleep HTTP request to complete. Applies globally to all sleep endpoints unless overridden per-endpoint with timeout field."
+ },
+ "wakeRequestTimeout": {
+ "type": "integer",
+ "minimum": 1,
+ "default": 10,
+ "description": "Number of seconds to wait for each wake HTTP request to complete. Applies globally to all wake endpoints unless overridden per-endpoint with timeout field."
+ },
"logLevel": {
"type": "string",
"enum": [
@@ -214,6 +226,80 @@
"type": "boolean",
"default": false,
"description": "If true the model will not show up in /v1/models responses. It can still be used as normal in API requests."
+ },
+ "sleepMode": {
+ "type": "string",
+ "enum": ["enable", "disable"],
+ "default": "disable",
+ "description": "Explicitly controls sleep/wake behavior. 'enable' activates sleep/wake functionality and requires sleepEndpoints and wakeEndpoints to be defined. 'disable' (the default) turns sleep/wake off and cmdStop is used instead."
+ },
+ "sleepEndpoints": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "required": ["endpoint"],
+ "properties": {
+ "endpoint": {
+ "type": "string",
+ "minLength": 1,
+ "description": "URL path for the sleep endpoint (e.g., /sleep?level=1)."
+ },
+ "method": {
+ "type": "string",
+ "enum": ["GET", "POST", "PUT", "PATCH"],
+ "default": "POST",
+ "description": "HTTP method to use for the request."
+ },
+ "body": {
+ "type": "string",
+ "default": "",
+ "description": "Optional request body (JSON string)."
+ },
+ "timeout": {
+ "type": "integer",
+ "minimum": 0,
+ "default": 0,
+ "description": "Optional per-endpoint timeout in seconds. 0 uses global sleepRequestTimeout."
+ }
+ },
+ "additionalProperties": false
+ },
+ "default": [],
+ "description": "Array of HTTP endpoints to call for putting the model to sleep. Requires sleepMode to be 'enable'. Endpoints are called sequentially in array order. Used instead of cmdStop during model swapping."
+ },
+ "wakeEndpoints": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "required": ["endpoint"],
+ "properties": {
+ "endpoint": {
+ "type": "string",
+ "minLength": 1,
+ "description": "URL path for the wake endpoint (e.g., /wake_up)."
+ },
+ "method": {
+ "type": "string",
+ "enum": ["GET", "POST", "PUT", "PATCH"],
+ "default": "POST",
+ "description": "HTTP method to use for the request."
+ },
+ "body": {
+ "type": "string",
+ "default": "",
+ "description": "Optional request body (JSON string)."
+ },
+ "timeout": {
+ "type": "integer",
+ "minimum": 0,
+ "default": 0,
+ "description": "Optional per-endpoint timeout in seconds. 0 uses global wakeRequestTimeout."
+ }
+ },
+ "additionalProperties": false
+ },
+ "default": [],
+ "description": "Array of HTTP endpoints to call for waking the model. Required when sleepMode is 'enable'. Endpoints are called sequentially in array order."
}
}
}
diff --git a/config.example.yaml b/config.example.yaml
index e6b8c9c2..6949bbe3 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -21,6 +21,18 @@
# - minimum value is 15 seconds, anything less will be set to this value
healthCheckTimeout: 500
+# sleepRequestTimeout: number of seconds to wait for each sleep HTTP request to complete
+# - optional, default: 10
+# - applies globally to all sleep endpoints unless overridden per-endpoint with timeout field
+# - used when putting a model to sleep during model swapping
+sleepRequestTimeout: 20
+
+# wakeRequestTimeout: number of seconds to wait for each wake HTTP request to complete
+# - optional, default: 10
+# - applies globally to all wake endpoints unless overridden per-endpoint with timeout field
+# - used when waking a model from sleep
+wakeRequestTimeout: 20
+
# logLevel: sets the logging value
# - optional, default: info
# - Valid log levels: debug, info, warn, error
@@ -243,6 +255,93 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}
+ # vLLM Sleep Mode Example - Level 1:
+ # vLLM supports sleep/wake functionality for fast model switching
+ # See: https://docs.vllm.ai/en/stable/features/sleep_mode.html
+ # Level 1: offload weights to CPU RAM (faster wake, higher RAM usage, single-step wake)
+ "vllm-sleep-level1":
+ # sleepMode: explicitly controls sleep/wake behavior
+ # - "enable": activates sleep/wake - requires sleepEndpoints and wakeEndpoints
+ # - "disable": disables sleep/wake - uses cmdStop instead
+ # - (omitted): default - sleep mode disabled
+ sleepMode: enable
+
+ cmd: |
+ uv run python -m vllm.entrypoints.openai.api_server
+ --model /path/to/models/my-model
+ --served-model-name ${MODEL_ID}
+ --port ${PORT}
+ --enable-sleep-mode
+ env:
+ # Required to enable sleep mode in vLLM
+ - "VLLM_SERVER_DEV_MODE=1"
+
+ # sleepEndpoints: array of HTTP endpoints to call for putting the model to sleep
+ # - optional, default: []
+ # - if defined along with wakeEndpoints, used instead of cmdStop during model swapping
+ # - HTTP requests are sent to proxy base URL + endpoint
+ # - endpoints are called sequentially in array order
+ # - supports macro substitution: ${PORT}, ${MODEL_ID}
+ # - each endpoint can include query parameters: /sleep?level=1
+ # - vLLM sleep levels:
+ # - level 1: offload weights to CPU RAM (faster wake, higher RAM usage)
+ # - level 2: discard weights entirely (slower wake, minimal RAM usage)
+ sleepEndpoints:
+ - endpoint: /sleep?level=1
+ method: POST
+ # body is optional
+ # timeout is optional - overrides global sleepRequestTimeout for this specific endpoint
+
+ # wakeEndpoints: array of HTTP endpoints to call for waking the model
+ # - required if sleepEndpoints is defined
+ # - used when loading a sleeping model
+ # - HTTP requests are sent to proxy base URL + endpoint
+ # - endpoints are called sequentially in array order
+ # - level 1 sleep requires only single wake step
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+ # timeout is optional - overrides global wakeRequestTimeout for this specific endpoint
+
+ # vLLM Sleep Mode Example - Level 2:
+ # Level 2: discard weights entirely (slower wake, minimal RAM usage, multi-step wake)
+ # Requires a 3-step wake sequence to fully restore the model
+ "vllm-sleep-level2":
+ # Enable sleep/wake functionality
+ sleepMode: enable
+
+ cmd: |
+ uv run python -m vllm.entrypoints.openai.api_server
+ --model /path/to/models/my-large-model
+ --served-model-name ${MODEL_ID}
+ --port ${PORT}
+ --enable-sleep-mode
+ env:
+ # Required to enable sleep mode in vLLM
+ - "VLLM_SERVER_DEV_MODE=1"
+
+ # Level 2 sleep endpoint - discards weights for minimal RAM usage
+ sleepEndpoints:
+ - endpoint: /sleep?level=2
+ method: POST
+ # Optional: override global sleepRequestTimeout
+ timeout: 15
+
+ # Level 2 wake requires multi-step sequence to reload weights and reset cache
+ wakeEndpoints:
+ # Step 1: Wake the model
+ - endpoint: /wake_up
+ method: POST
+ # Step 2: Reload weights
+ - endpoint: /collective_rpc
+ method: POST
+ body: '{"method": "reload_weights"}'
+ # Optional: override timeout for this specific endpoint
+ timeout: 12
+ # Step 3: Reset the prefix cache
+ - endpoint: /reset_prefix_cache
+ method: POST
+
# groups: a dictionary of group settings
# - optional, default: empty dictionary
# - provides advanced controls over model swapping behaviour
diff --git a/docs/configuration.md b/docs/configuration.md
index c253d408..b938d52f 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -72,16 +72,17 @@ models:
llama-swap supports many more features to customize how you want to manage your environment.
-| Feature | Description |
-| --------- | ---------------------------------------------- |
-| `ttl` | automatic unloading of models after a timeout |
-| `macros` | reusable snippets to use in configurations |
-| `groups` | run multiple models at a time |
-| `hooks` | event driven functionality |
-| `env` | define environment variables per model |
-| `aliases` | serve a model with different names |
-| `filters` | modify requests before sending to the upstream |
-| `...` | And many more tweaks |
+| Feature | Description |
+| ------------- | ---------------------------------------------- |
+| `ttl` | automatic unloading of models after a timeout |
+| `sleep/wake` | fast model switching with sleep mode support |
+| `macros` | reusable snippets to use in configurations |
+| `groups` | run multiple models at a time |
+| `hooks` | event driven functionality |
+| `env` | define environment variables per model |
+| `aliases` | serve a model with different names |
+| `filters` | modify requests before sending to the upstream |
+| `...` | And many more tweaks |
## Full Configuration Example
@@ -120,6 +121,18 @@ logLevel: info
# - useful for limiting memory usage when processing large volumes of metrics
metricsMaxInMemory: 1000
+# sleepRequestTimeout: number of seconds to wait for each sleep HTTP request to complete
+# - optional, default: 10
+# - applies globally to all sleep endpoints unless overridden per-endpoint with timeout field
+# - used when putting a model to sleep during model swapping
+sleepRequestTimeout: 20
+
+# wakeRequestTimeout: number of seconds to wait for each wake HTTP request to complete
+# - optional, default: 10
+# - applies globally to all wake endpoints unless overridden per-endpoint with timeout field
+# - used when waking a model from sleep
+wakeRequestTimeout: 20
+
# startPort: sets the starting port number for the automatic ${PORT} macro.
# - optional, default: 5800
# - the ${PORT} macro can be used in model.cmd and model.proxy settings
@@ -305,6 +318,93 @@ models:
# - processes have 5 seconds to shutdown until forceful termination is attempted
cmdStop: docker stop ${MODEL_ID}
+ # vLLM Sleep Mode Example - Level 1:
+ # vLLM supports sleep/wake functionality for fast model switching
+ # See: https://docs.vllm.ai/en/stable/features/sleep_mode.html
+ # Level 1: offload weights to CPU RAM (faster wake, higher RAM usage, single-step wake)
+ "vllm-sleep-level1":
+ # sleepMode: explicitly controls sleep/wake behavior
+ # - "enable": activates sleep/wake - requires sleepEndpoints and wakeEndpoints
+ # - "disable": disables sleep/wake - uses cmdStop instead
+ # - (omitted): default - sleep mode disabled
+ sleepMode: enable
+
+ cmd: |
+ uv run python -m vllm.entrypoints.openai.api_server
+ --model /path/to/models/my-model
+ --served-model-name ${MODEL_ID}
+ --port ${PORT}
+ --enable-sleep-mode
+ env:
+ # Required to enable sleep mode in vLLM
+ - "VLLM_SERVER_DEV_MODE=1"
+
+ # sleepEndpoints: array of HTTP endpoints to call for putting the model to sleep
+ # - optional, default: []
+ # - if defined along with wakeEndpoints, used instead of cmdStop during model swapping
+ # - HTTP requests are sent to proxy base URL + endpoint
+ # - endpoints are called sequentially in array order
+ # - supports macro substitution: ${PORT}, ${MODEL_ID}
+ # - each endpoint can include query parameters: /sleep?level=1
+ # - vLLM sleep levels:
+ # - level 1: offload weights to CPU RAM (faster wake, higher RAM usage)
+ # - level 2: discard weights entirely (slower wake, minimal RAM usage)
+ sleepEndpoints:
+ - endpoint: /sleep?level=1
+ method: POST
+ # body is optional
+ # timeout is optional - overrides global sleepRequestTimeout for this specific endpoint
+
+ # wakeEndpoints: array of HTTP endpoints to call for waking the model
+ # - required if sleepEndpoints is defined
+ # - used when loading a sleeping model
+ # - HTTP requests are sent to proxy base URL + endpoint
+ # - endpoints are called sequentially in array order
+ # - level 1 sleep requires only single wake step
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+ # timeout is optional - overrides global wakeRequestTimeout for this specific endpoint
+
+ # vLLM Sleep Mode Example - Level 2:
+ # Level 2: discard weights entirely (slower wake, minimal RAM usage, multi-step wake)
+ # Requires a 3-step wake sequence to fully restore the model
+ "vllm-sleep-level2":
+ # Enable sleep/wake functionality
+ sleepMode: enable
+
+ cmd: |
+ uv run python -m vllm.entrypoints.openai.api_server
+ --model /path/to/models/my-large-model
+ --served-model-name ${MODEL_ID}
+ --port ${PORT}
+ --enable-sleep-mode
+ env:
+ # Required to enable sleep mode in vLLM
+ - "VLLM_SERVER_DEV_MODE=1"
+
+ # Level 2 sleep endpoint - discards weights for minimal RAM usage
+ sleepEndpoints:
+ - endpoint: /sleep?level=2
+ method: POST
+ # Optional: override global sleepRequestTimeout
+ timeout: 15
+
+ # Level 2 wake requires multi-step sequence to reload weights and reset cache
+ wakeEndpoints:
+ # Step 1: Wake the model
+ - endpoint: /wake_up
+ method: POST
+ # Step 2: Reload weights
+ - endpoint: /collective_rpc
+ method: POST
+ body: '{"method": "reload_weights"}'
+ # Optional: override timeout for this specific endpoint
+ timeout: 12
+ # Step 3: Reset the prefix cache
+ - endpoint: /reset_prefix_cache
+ method: POST
+
# groups: a dictionary of group settings
# - optional, default: empty dictionary
# - provides advanced controls over model swapping behaviour
diff --git a/proxy/config/config.go b/proxy/config/config.go
index 0138e093..b262f62b 100644
--- a/proxy/config/config.go
+++ b/proxy/config/config.go
@@ -110,14 +110,16 @@ type HookOnStartup struct {
}
type Config struct {
- HealthCheckTimeout int `yaml:"healthCheckTimeout"`
- LogRequests bool `yaml:"logRequests"`
- LogLevel string `yaml:"logLevel"`
- LogTimeFormat string `yaml:"logTimeFormat"`
- MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
- Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
- Profiles map[string][]string `yaml:"profiles"`
- Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
+ HealthCheckTimeout int `yaml:"healthCheckTimeout"`
+ SleepRequestTimeout int `yaml:"sleepRequestTimeout"`
+ WakeRequestTimeout int `yaml:"wakeRequestTimeout"`
+ LogRequests bool `yaml:"logRequests"`
+ LogLevel string `yaml:"logLevel"`
+ LogTimeFormat string `yaml:"logTimeFormat"`
+ MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
+ Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
+ Profiles map[string][]string `yaml:"profiles"`
+ Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
// for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint
Macros MacroList `yaml:"macros"`
@@ -173,11 +175,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
// default configuration values
config := Config{
- HealthCheckTimeout: 120,
- StartPort: 5800,
- LogLevel: "info",
- LogTimeFormat: "",
- MetricsMaxInMemory: 1000,
+ HealthCheckTimeout: 120,
+ SleepRequestTimeout: 10,
+ WakeRequestTimeout: 10,
+ StartPort: 5800,
+ LogLevel: "info",
+ LogTimeFormat: "",
+ MetricsMaxInMemory: 1000,
}
err = yaml.Unmarshal(data, &config)
if err != nil {
@@ -189,6 +193,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
config.HealthCheckTimeout = 15
}
+ if config.SleepRequestTimeout < 1 {
+ // set a minimum of 1 second
+ config.SleepRequestTimeout = 1
+ }
+
+ if config.WakeRequestTimeout < 1 {
+ // set a minimum of 1 second
+ config.WakeRequestTimeout = 1
+ }
+
if config.StartPort < 1 {
return Config{}, fmt.Errorf("startPort must be greater than 1")
}
@@ -276,6 +290,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroStr)
modelConfig.Filters.StripParams = strings.ReplaceAll(modelConfig.Filters.StripParams, macroSlug, macroStr)
+ // Substitute in sleep/wake endpoint arrays
+ for j := range modelConfig.SleepEndpoints {
+ modelConfig.SleepEndpoints[j].Endpoint = strings.ReplaceAll(modelConfig.SleepEndpoints[j].Endpoint, macroSlug, macroStr)
+ modelConfig.SleepEndpoints[j].Body = strings.ReplaceAll(modelConfig.SleepEndpoints[j].Body, macroSlug, macroStr)
+ }
+ for j := range modelConfig.WakeEndpoints {
+ modelConfig.WakeEndpoints[j].Endpoint = strings.ReplaceAll(modelConfig.WakeEndpoints[j].Endpoint, macroSlug, macroStr)
+ modelConfig.WakeEndpoints[j].Body = strings.ReplaceAll(modelConfig.WakeEndpoints[j].Body, macroSlug, macroStr)
+ }
+
// Substitute in metadata (recursive)
if len(modelConfig.Metadata) > 0 {
var err error
@@ -306,6 +330,16 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroStr)
modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroStr)
+ // Substitute PORT in sleep/wake endpoint arrays
+ for j := range modelConfig.SleepEndpoints {
+ modelConfig.SleepEndpoints[j].Endpoint = strings.ReplaceAll(modelConfig.SleepEndpoints[j].Endpoint, macroSlug, macroStr)
+ modelConfig.SleepEndpoints[j].Body = strings.ReplaceAll(modelConfig.SleepEndpoints[j].Body, macroSlug, macroStr)
+ }
+ for j := range modelConfig.WakeEndpoints {
+ modelConfig.WakeEndpoints[j].Endpoint = strings.ReplaceAll(modelConfig.WakeEndpoints[j].Endpoint, macroSlug, macroStr)
+ modelConfig.WakeEndpoints[j].Body = strings.ReplaceAll(modelConfig.WakeEndpoints[j].Body, macroSlug, macroStr)
+ }
+
// Substitute PORT in metadata
if len(modelConfig.Metadata) > 0 {
var err error
@@ -344,6 +378,14 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
+ // Check sleep/wake endpoint arrays for unknown macros
+ if err := validateEndpointMacros(modelConfig.SleepEndpoints, modelId, "sleepEndpoints"); err != nil {
+ return Config{}, err
+ }
+ if err := validateEndpointMacros(modelConfig.WakeEndpoints, modelId, "wakeEndpoints"); err != nil {
+ return Config{}, err
+ }
+
// Check for unknown macros in metadata
if len(modelConfig.Metadata) > 0 {
if err := validateMetadataForUnknownMacros(modelConfig.Metadata, modelId); err != nil {
@@ -365,6 +407,19 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
modelConfig.SendLoadingState = &v
}
+ // Set default timeouts on sleep/wake endpoints if not already configured
+ // Use global config timeout values as defaults
+ for i := range modelConfig.SleepEndpoints {
+ if modelConfig.SleepEndpoints[i].Timeout == 0 {
+ modelConfig.SleepEndpoints[i].Timeout = config.SleepRequestTimeout
+ }
+ }
+ for i := range modelConfig.WakeEndpoints {
+ if modelConfig.WakeEndpoints[i].Timeout == 0 {
+ modelConfig.WakeEndpoints[i].Timeout = config.WakeRequestTimeout
+ }
+ }
+
config.Models[modelId] = modelConfig
}
@@ -567,6 +622,23 @@ func validateMetadataForUnknownMacros(value any, modelId string) error {
}
}
+// validateEndpointMacros checks for unknown macros in a list of HTTPEndpoints
+func validateEndpointMacros(endpoints []HTTPEndpoint, modelId, endpointType string) error {
+ for i, endpoint := range endpoints {
+ for _, fieldValue := range []string{endpoint.Endpoint, endpoint.Body} {
+ matches := macroPatternRegex.FindAllStringSubmatch(fieldValue, -1)
+ for _, match := range matches {
+ macroName := match[1]
+ if macroName == "PORT" || macroName == "MODEL_ID" {
+ return fmt.Errorf("macro '${%s}' should have been substituted in %s.%s[%d]", macroName, modelId, endpointType, i)
+ }
+ return fmt.Errorf("unknown macro '${%s}' found in %s.%s[%d]", macroName, modelId, endpointType, i)
+ }
+ }
+ }
+ return nil
+}
+
// substituteMacroInValue recursively substitutes a single macro in a value structure
// This is called once per macro, allowing LIFO substitution order
func substituteMacroInValue(value any, macroName string, macroValue any) (any, error) {
diff --git a/proxy/config/config_posix_test.go b/proxy/config/config_posix_test.go
index 8793319d..59499a52 100644
--- a/proxy/config/config_posix_test.go
+++ b/proxy/config/config_posix_test.go
@@ -185,6 +185,7 @@ groups:
CheckEndpoint: "/health",
Name: "Model 1",
Description: "This is model 1",
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
"model2": {
@@ -193,6 +194,7 @@ groups:
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
"model3": {
@@ -201,6 +203,7 @@ groups:
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
"model4": {
@@ -209,11 +212,14 @@ groups:
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
},
- HealthCheckTimeout: 15,
- MetricsMaxInMemory: 1000,
+ HealthCheckTimeout: 15,
+ SleepRequestTimeout: 10,
+ WakeRequestTimeout: 10,
+ MetricsMaxInMemory: 1000,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
diff --git a/proxy/config/config_test.go b/proxy/config/config_test.go
index e624a8ce..45679fce 100644
--- a/proxy/config/config_test.go
+++ b/proxy/config/config_test.go
@@ -761,3 +761,248 @@ models:
})
}
}
+
+func TestConfig_SleepWakeBasicConfiguration(t *testing.T) {
+ content := `
+startPort: 10000
+sleepRequestTimeout: 15
+wakeRequestTimeout: 20
+
+models:
+ vllm-model:
+ cmd: python -m vllm.entrypoints.openai.api_server --port ${PORT}
+ sleepEndpoints:
+ - endpoint: /sleep?level=1
+ method: POST
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+`
+
+ config, err := LoadConfigFromReader(strings.NewReader(content))
+ assert.NoError(t, err)
+
+ // Verify global timeout settings
+ assert.Equal(t, 15, config.SleepRequestTimeout)
+ assert.Equal(t, 20, config.WakeRequestTimeout)
+
+ // Verify model sleep/wake endpoints
+ model := config.Models["vllm-model"]
+ assert.Len(t, model.SleepEndpoints, 1)
+ assert.Len(t, model.WakeEndpoints, 1)
+
+ // Check sleep endpoint
+ assert.Equal(t, "/sleep?level=1", model.SleepEndpoints[0].Endpoint)
+ assert.Equal(t, "POST", model.SleepEndpoints[0].Method)
+ assert.Equal(t, 15, model.SleepEndpoints[0].Timeout) // inherited from global
+
+ // Check wake endpoint
+ assert.Equal(t, "/wake_up", model.WakeEndpoints[0].Endpoint)
+ assert.Equal(t, "POST", model.WakeEndpoints[0].Method)
+ assert.Equal(t, 20, model.WakeEndpoints[0].Timeout) // inherited from global
+}
+
+func TestConfig_SleepWakeMacroSubstitution(t *testing.T) {
+ content := `
+startPort: 10000
+macros:
+ SLEEP_LEVEL: "2"
+
+models:
+ vllm-model:
+ cmd: python -m vllm --port ${PORT} --model ${MODEL_ID}
+ sleepEndpoints:
+ - endpoint: /sleep?level=${SLEEP_LEVEL}&port=${PORT}
+ method: POST
+ body: '{"model": "${MODEL_ID}"}'
+ wakeEndpoints:
+ - endpoint: /wake_up?port=${PORT}
+ method: POST
+ body: '{"model_id": "${MODEL_ID}", "port": ${PORT}}'
+`
+
+ config, err := LoadConfigFromReader(strings.NewReader(content))
+ assert.NoError(t, err)
+
+ model := config.Models["vllm-model"]
+
+ // Verify macros were substituted in sleep endpoints
+ assert.Equal(t, "/sleep?level=2&port=10000", model.SleepEndpoints[0].Endpoint)
+ assert.Equal(t, `{"model": "vllm-model"}`, model.SleepEndpoints[0].Body)
+
+ // Verify macros were substituted in wake endpoints
+ assert.Equal(t, "/wake_up?port=10000", model.WakeEndpoints[0].Endpoint)
+ assert.Equal(t, `{"model_id": "vllm-model", "port": 10000}`, model.WakeEndpoints[0].Body)
+}
+
+func TestConfig_SleepWakeTimeoutOverrides(t *testing.T) {
+ content := `
+startPort: 10000
+sleepRequestTimeout: 10
+wakeRequestTimeout: 15
+
+models:
+ test-model:
+ cmd: server --port ${PORT}
+ sleepEndpoints:
+ - endpoint: /sleep
+ method: POST
+ timeout: 30
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+ timeout: 45
+ - endpoint: /reset
+ method: POST
+ # This one should inherit global timeout
+`
+
+ config, err := LoadConfigFromReader(strings.NewReader(content))
+ assert.NoError(t, err)
+
+ model := config.Models["test-model"]
+
+ // Verify per-endpoint timeout overrides global
+ assert.Equal(t, 30, model.SleepEndpoints[0].Timeout)
+ assert.Equal(t, 45, model.WakeEndpoints[0].Timeout)
+
+ // Verify second wake endpoint inherits global timeout
+ assert.Equal(t, 15, model.WakeEndpoints[1].Timeout)
+}
+
+func TestConfig_SleepWakeMultiStepWake(t *testing.T) {
+ content := `
+startPort: 10000
+
+models:
+ vllm-level2:
+ cmd: python -m vllm --port ${PORT} --enable-sleep-mode
+ sleepEndpoints:
+ - endpoint: /sleep?level=2
+ method: POST
+ timeout: 15
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+ - endpoint: /collective_rpc
+ method: POST
+ body: '{"method": "reload_weights"}'
+ timeout: 12
+ - endpoint: /reset_prefix_cache
+ method: POST
+`
+
+ config, err := LoadConfigFromReader(strings.NewReader(content))
+ assert.NoError(t, err)
+
+ model := config.Models["vllm-level2"]
+
+ // Verify sleep endpoint
+ assert.Len(t, model.SleepEndpoints, 1)
+ assert.Equal(t, "/sleep?level=2", model.SleepEndpoints[0].Endpoint)
+ assert.Equal(t, 15, model.SleepEndpoints[0].Timeout)
+
+ // Verify multi-step wake sequence
+ assert.Len(t, model.WakeEndpoints, 3)
+
+ // Step 1: Wake up
+ assert.Equal(t, "/wake_up", model.WakeEndpoints[0].Endpoint)
+ assert.Equal(t, "POST", model.WakeEndpoints[0].Method)
+ assert.Equal(t, 10, model.WakeEndpoints[0].Timeout) // default
+
+ // Step 2: Reload weights
+ assert.Equal(t, "/collective_rpc", model.WakeEndpoints[1].Endpoint)
+ assert.Equal(t, "POST", model.WakeEndpoints[1].Method)
+ assert.Equal(t, `{"method": "reload_weights"}`, model.WakeEndpoints[1].Body)
+ assert.Equal(t, 12, model.WakeEndpoints[1].Timeout)
+
+ // Step 3: Reset cache
+ assert.Equal(t, "/reset_prefix_cache", model.WakeEndpoints[2].Endpoint)
+ assert.Equal(t, "POST", model.WakeEndpoints[2].Method)
+ assert.Equal(t, 10, model.WakeEndpoints[2].Timeout) // default
+}
+
+func TestConfig_SleepWakeUnknownMacroInEndpoints(t *testing.T) {
+ content := `
+startPort: 10000
+
+models:
+ test-model:
+ cmd: server --port ${PORT}
+ sleepEndpoints:
+ - endpoint: /sleep?level=${UNKNOWN_MACRO}
+ method: POST
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+`
+
+ _, err := LoadConfigFromReader(strings.NewReader(content))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "UNKNOWN_MACRO")
+ assert.Contains(t, err.Error(), "test-model")
+ assert.Contains(t, err.Error(), "sleepEndpoints")
+}
+
+func TestConfig_SleepWakeDefaultTimeouts(t *testing.T) {
+ content := `
+startPort: 10000
+
+models:
+ test-model:
+ cmd: server --port ${PORT}
+ sleepEndpoints:
+ - endpoint: /sleep
+ method: POST
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+`
+
+ config, err := LoadConfigFromReader(strings.NewReader(content))
+ assert.NoError(t, err)
+
+ model := config.Models["test-model"]
+
+ // Verify default timeouts are applied (10 seconds from config.go defaults)
+ assert.Equal(t, 10, model.SleepEndpoints[0].Timeout)
+ assert.Equal(t, 10, model.WakeEndpoints[0].Timeout)
+}
+
+func TestConfig_SleepWakeModelLevelMacros(t *testing.T) {
+ content := `
+startPort: 10000
+macros:
+ LEVEL: "1"
+
+models:
+ model1:
+ macros:
+ LEVEL: "2"
+ cmd: server --port ${PORT}
+ sleepEndpoints:
+ - endpoint: /sleep?level=${LEVEL}
+ method: POST
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+
+ model2:
+ cmd: server --port ${PORT}
+ sleepEndpoints:
+ - endpoint: /sleep?level=${LEVEL}
+ method: POST
+ wakeEndpoints:
+ - endpoint: /wake_up
+ method: POST
+`
+
+ config, err := LoadConfigFromReader(strings.NewReader(content))
+ assert.NoError(t, err)
+
+ // model1 should use model-level macro override
+ assert.Equal(t, "/sleep?level=2", config.Models["model1"].SleepEndpoints[0].Endpoint)
+
+ // model2 should use global macro
+ assert.Equal(t, "/sleep?level=1", config.Models["model2"].SleepEndpoints[0].Endpoint)
+}
diff --git a/proxy/config/config_windows_test.go b/proxy/config/config_windows_test.go
index 9e633a70..79010258 100644
--- a/proxy/config/config_windows_test.go
+++ b/proxy/config/config_windows_test.go
@@ -171,6 +171,7 @@ groups:
Aliases: []string{"m1", "model-one"},
Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health",
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
"model2": {
@@ -180,6 +181,7 @@ groups:
Aliases: []string{"m2"},
Env: []string{},
CheckEndpoint: "/",
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
"model3": {
@@ -189,6 +191,7 @@ groups:
Aliases: []string{"mthree"},
Env: []string{},
CheckEndpoint: "/",
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
"model4": {
@@ -198,11 +201,14 @@ groups:
CheckEndpoint: "/",
Aliases: []string{},
Env: []string{},
+ SleepMode: SleepModeDisable,
SendLoadingState: &modelLoadingState,
},
},
- HealthCheckTimeout: 15,
- MetricsMaxInMemory: 1000,
+ HealthCheckTimeout: 15,
+ SleepRequestTimeout: 10,
+ WakeRequestTimeout: 10,
+ MetricsMaxInMemory: 1000,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go
index f1c79e31..f8e0d783 100644
--- a/proxy/config/model_config.go
+++ b/proxy/config/model_config.go
@@ -2,11 +2,28 @@ package config
import (
"errors"
+ "fmt"
"runtime"
"slices"
"strings"
)
+// HTTPEndpoint represents a single HTTP endpoint configuration
+type HTTPEndpoint struct {
+ Endpoint string `yaml:"endpoint"` // URL path (e.g., "/wake_up")
+ Method string `yaml:"method"` // HTTP method (GET, POST, PUT, PATCH)
+ Body string `yaml:"body"` // Optional request body (JSON string)
+ Timeout int `yaml:"timeout"` // Optional per-endpoint timeout (seconds)
+}
+
+// SleepMode represents the sleep/wake behavior mode
+type SleepMode string
+
+const (
+ SleepModeEnable SleepMode = SleepMode("enable")
+ SleepModeDisable SleepMode = SleepMode("disable")
+)
+
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
@@ -18,6 +35,15 @@ type ModelConfig struct {
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
+ // SleepMode explicitly controls sleep/wake behavior
+ // Valid values: SleepModeEnable, SleepModeDisable
+ // Future values may include: "auto", "level1", "level2"
+ SleepMode SleepMode `yaml:"sleepMode"`
+
+ // Array-based sleep/wake configuration
+ SleepEndpoints []HTTPEndpoint `yaml:"sleepEndpoints"`
+ WakeEndpoints []HTTPEndpoint `yaml:"wakeEndpoints"`
+
// #179 for /v1/models
Name string `yaml:"name"`
Description string `yaml:"description"`
@@ -55,6 +81,7 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
ConcurrencyLimit: 0,
Name: "",
Description: "",
+ SleepMode: SleepModeDisable,
}
// the default cmdStop to taskkill /f /t /pid ${PID}
@@ -67,6 +94,65 @@ func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
}
*m = ModelConfig(defaults)
+
+ // Validate sleepMode field
+ switch m.SleepMode {
+ case SleepModeEnable, SleepModeDisable:
+ // Valid values
+ default:
+ return fmt.Errorf("invalid sleepMode value '%s': must be 'enable' or 'disable'", m.SleepMode)
+ }
+
+ // Require endpoints when sleepMode is "enable"
+ if m.SleepMode == SleepModeEnable {
+ if len(m.SleepEndpoints) == 0 {
+ return errors.New("sleepEndpoints required when sleepMode is 'enable'")
+ }
+ if len(m.WakeEndpoints) == 0 {
+ return errors.New("wakeEndpoints required when sleepMode is 'enable'")
+ }
+ }
+
+ // Validate and normalize each endpoint
+ for i := range m.SleepEndpoints {
+ if err := m.validateEndpoint(&m.SleepEndpoints[i]); err != nil {
+ return fmt.Errorf("sleepEndpoints[%d]: %v", i, err)
+ }
+ }
+
+ for i := range m.WakeEndpoints {
+ if err := m.validateEndpoint(&m.WakeEndpoints[i]); err != nil {
+ return fmt.Errorf("wakeEndpoints[%d]: %v", i, err)
+ }
+ }
+
+ return nil
+}
+
+func (m *ModelConfig) validateEndpoint(ep *HTTPEndpoint) error {
+ // Endpoint path is required
+ if ep.Endpoint == "" {
+ return errors.New("endpoint path is required")
+ }
+
+ // Default method to POST if not specified
+ if ep.Method == "" {
+ ep.Method = "POST"
+ }
+
+ // Validate HTTP method
+ validMethods := map[string]bool{"GET": true, "POST": true, "PUT": true, "PATCH": true}
+ upperMethod := strings.ToUpper(ep.Method)
+ if !validMethods[upperMethod] {
+ return fmt.Errorf("invalid method %q (must be GET, POST, PUT, or PATCH)", ep.Method)
+ }
+ ep.Method = upperMethod
+
+ // Timeout validation (must be non-negative)
+ if ep.Timeout < 0 {
+ return fmt.Errorf("timeout must be non-negative, got %d", ep.Timeout)
+ }
+
return nil
}
diff --git a/proxy/process.go b/proxy/process.go
index 91bfbd44..d87baa80 100644
--- a/proxy/process.go
+++ b/proxy/process.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"math/rand"
"net"
"net/http"
@@ -31,6 +32,14 @@ const (
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
+
+ // sleep/wake states
+ StateSleepPending ProcessState = ProcessState("sleepPending")
+ StateAsleep ProcessState = ProcessState("asleep")
+ StateWaking ProcessState = ProcessState("waking")
+
+ // httpDialTimeout is the timeout for establishing TCP connections
+ httpDialTimeout = 500 * time.Millisecond
)
type StopStrategy int
@@ -71,6 +80,12 @@ type Process struct {
// used to block on multiple start() calls
waitStarting sync.WaitGroup
+ // used to block on multiple Sleep() calls
+ waitSleeping sync.WaitGroup
+
+ // used to block on multiple Wake() calls
+ waitWaking sync.WaitGroup
+
// for managing concurrency limits
concurrencyLimitSemaphore chan struct{}
@@ -170,10 +185,15 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
p.state = newState
- // Atomically increment waitStarting when entering StateStarting
- // This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
- if newState == StateStarting {
+ // Atomically increment WaitGroups when entering transitional states
+ // This ensures any thread that sees the transitional state will also see the WaitGroup counter incremented
+ switch newState {
+ case StateStarting:
p.waitStarting.Add(1)
+ case StateSleepPending:
+ p.waitSleeping.Add(1)
+ case StateWaking:
+ p.waitWaking.Add(1)
}
p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
@@ -189,11 +209,17 @@ func isValidTransition(from, to ProcessState) bool {
case StateStarting:
return to == StateReady || to == StateStopping || to == StateStopped
case StateReady:
- return to == StateStopping
+ return to == StateStopping || to == StateSleepPending
case StateStopping:
return to == StateStopped || to == StateShutdown
case StateShutdown:
- return false // No transitions allowed from these states
+ return false // No transitions allowed from this state
+ case StateSleepPending:
+ return to == StateAsleep || to == StateStopping || to == StateStopped
+ case StateAsleep:
+ return to == StateWaking || to == StateStopping
+ case StateWaking:
+ return to == StateReady || to == StateStopping || to == StateStopped
}
return false
}
@@ -213,6 +239,25 @@ func (p *Process) forceState(newState ProcessState) {
p.state = newState
}
+func (p *Process) makeReady() error {
+ currentState := p.CurrentState()
+ if currentState == StateSleepPending || currentState == StateAsleep || currentState == StateWaking {
+ p.proxyLogger.Debugf("<%s> Process is sleeping, sleep pending, or already waking, use wake() instead of start()", p.ID)
+ return p.wake()
+ } else {
+ return p.start()
+ }
+}
+
+// MakeIdle transitions the process to an idle state, using sleep mode if configured, otherwise stopping.
+func (p *Process) MakeIdle() {
+ if p.isSleepEnabled() {
+ p.Sleep()
+ } else {
+ p.Stop()
+ }
+}
+
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
@@ -296,10 +341,9 @@ func (p *Process) start() error {
// a "none" means don't check for health ... I could have picked a better word :facepalm:
if checkEndpoint != "none" {
- proxyTo := p.config.Proxy
- healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
+ healthURL, err := p.buildFullURL(checkEndpoint)
if err != nil {
- return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
+ return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", p.config.Proxy, checkEndpoint)
}
// Ready Check loop
@@ -317,7 +361,7 @@ func (p *Process) start() error {
return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds())
}
- if err := p.checkHealthEndpoint(healthURL); err == nil {
+ if err := p.checkHealthEndpoint(checkEndpoint); err == nil {
p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL)
break
} else {
@@ -332,6 +376,17 @@ func (p *Process) start() error {
}
}
+ if curState, err := p.swapState(StateStarting, StateReady); err != nil {
+ return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
+ } else {
+ p.failedStartCount = 0
+ p.startUnloadMonitoring()
+ return nil
+ }
+}
+
+// startUnloadMonitoring begins TTL monitoring for automatic model unloading.
+func (p *Process) startUnloadMonitoring() {
if p.config.UnloadAfter > 0 {
// start a goroutine to check every second if
// the process should be stopped
@@ -339,7 +394,11 @@ func (p *Process) start() error {
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for range time.Tick(time.Second) {
- if p.CurrentState() != StateReady {
+ curState := p.CurrentState()
+ if curState != StateReady &&
+ curState != StateSleepPending &&
+ curState != StateAsleep &&
+ curState != StateWaking {
return
}
@@ -356,13 +415,6 @@ func (p *Process) start() error {
}
}()
}
-
- if curState, err := p.swapState(StateStarting, StateReady); err != nil {
- return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err)
- } else {
- p.failedStartCount = 0
- return nil
- }
}
// Stop will wait for inflight requests to complete before stopping the process.
@@ -380,13 +432,14 @@ func (p *Process) Stop() {
// StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM.
// If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL.
func (p *Process) StopImmediately() {
- if !isValidTransition(p.CurrentState(), StateStopping) {
+ initState := p.CurrentState()
+ if !isValidTransition(initState, StateStopping) {
return
}
- p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState())
- if curState, err := p.swapState(StateReady, StateStopping); err != nil {
- p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState)
+ p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, initState)
+ if curState, err := p.swapState(initState, StateStopping); err != nil {
+ p.proxyLogger.Infof("<%s> Stop() %v -> StateStopping err: %v, current state: %v", p.ID, initState, err, curState)
return
}
@@ -407,6 +460,149 @@ func (p *Process) Shutdown() {
p.forceState(StateShutdown)
}
+// sendSleepRequests sends all sleep requests in sequence
+func (p *Process) sendSleepRequests() error {
+ if len(p.config.SleepEndpoints) == 0 {
+ return fmt.Errorf("no sleep endpoints configured")
+ }
+
+ p.proxyLogger.Infof("<%s> Executing %d sleep request(s)", p.ID, len(p.config.SleepEndpoints))
+
+ for i, endpoint := range p.config.SleepEndpoints {
+ p.proxyLogger.Debugf("<%s> Sleep step %d/%d: %s %s (timeout: %ds)",
+ p.ID, i+1, len(p.config.SleepEndpoints), endpoint.Method, endpoint.Endpoint, endpoint.Timeout)
+
+ if err := p.sendHTTPRequest(endpoint); err != nil {
+ return fmt.Errorf("sleep step %d/%d failed: %v", i+1, len(p.config.SleepEndpoints), err)
+ }
+ }
+
+ p.proxyLogger.Infof("<%s> All %d sleep request(s) completed successfully",
+ p.ID, len(p.config.SleepEndpoints))
+ return nil
+}
+
+// sendWakeRequests sends all wake requests in sequence
+func (p *Process) sendWakeRequests() error {
+ if len(p.config.WakeEndpoints) == 0 {
+ return fmt.Errorf("no wake endpoints configured")
+ }
+
+ p.proxyLogger.Infof("<%s> Executing %d wake request(s)", p.ID, len(p.config.WakeEndpoints))
+
+ for i, endpoint := range p.config.WakeEndpoints {
+ p.proxyLogger.Debugf("<%s> Wake step %d/%d: %s %s (timeout: %ds)",
+ p.ID, i+1, len(p.config.WakeEndpoints), endpoint.Method, endpoint.Endpoint, endpoint.Timeout)
+
+ if err := p.sendHTTPRequest(endpoint); err != nil {
+ return fmt.Errorf("wake step %d/%d failed: %v", i+1, len(p.config.WakeEndpoints), err)
+ }
+ }
+
+ p.proxyLogger.Infof("<%s> All %d wake request(s) completed successfully",
+ p.ID, len(p.config.WakeEndpoints))
+ return nil
+}
+
+// isSleepEnabled returns true if sleep mode is explicitly enabled
+func (p *Process) isSleepEnabled() bool {
+ return p.config.SleepMode == config.SleepModeEnable
+}
+
+// Sleep transitions the process to a sleeping state by executing sleep HTTP requests if defined.
+func (p *Process) Sleep() {
+ if !p.isSleepEnabled() {
+ p.proxyLogger.Errorf("<%s> sleep not configured", p.ID)
+ return
+ }
+
+ currentState := p.CurrentState()
+
+ // If sleep is already in progress, wait for it to complete
+ if currentState == StateSleepPending {
+ p.proxyLogger.Debugf("<%s> Sleep already in progress, waiting for completion", p.ID)
+ p.waitSleeping.Wait()
+ if state := p.CurrentState(); state == StateAsleep {
+ p.proxyLogger.Debugf("<%s> Sleep completed by concurrent call", p.ID)
+ return
+ } else {
+ p.proxyLogger.Warnf("<%s> Sleep operation failed, state: %v", p.ID, state)
+ return
+ }
+ }
+
+ if !isValidTransition(currentState, StateSleepPending) {
+ p.proxyLogger.Warnf("<%s> Cannot sleep from state %s", p.ID, currentState)
+ return
+ }
+
+ p.proxyLogger.Debugf("<%s> Sleep(): Waiting for inflight requests to complete", p.ID)
+ p.inFlightRequests.Wait()
+
+ if curState, err := p.swapState(StateReady, StateSleepPending); err != nil {
+ p.proxyLogger.Warnf("<%s> failed to transition to sleep pending: current state: %v, error: %v", p.ID, curState, err)
+ return
+ }
+
+ // waitSleeping.Add(1) is called atomically in swapState()
+ defer p.waitSleeping.Done()
+
+ sleepStartTime := time.Now()
+ if err := p.sendSleepRequests(); err != nil {
+ p.proxyLogger.Errorf("<%s> sendSleepRequests failed, falling back to StopImmediately(): %v", p.ID, err)
+ p.StopImmediately()
+ return
+ }
+
+ if curState, err := p.swapState(StateSleepPending, StateAsleep); err != nil {
+ p.proxyLogger.Errorf("<%s> failed to transition to asleep: current state: %v, error: %v", p.ID, curState, err)
+ // If we can't transition to asleep, fall back to stopping
+ p.StopImmediately()
+ return
+ }
+
+ p.proxyLogger.Infof("<%s> Model sleep completed in %v", p.ID, time.Since(sleepStartTime))
+}
+
+// wake transitions the process from asleep to ready.
+func (p *Process) wake() error {
+ if curState, err := p.swapState(StateAsleep, StateWaking); err != nil {
+ if err == ErrExpectedStateMismatch {
+ // already waking, just wait for it to complete and expect
+ // it to be in the Ready state after. If not, return an error
+ if curState == StateWaking {
+ p.waitWaking.Wait()
+ if state := p.CurrentState(); state == StateReady {
+ return nil
+ } else {
+ return fmt.Errorf("process was already waking but wound up in state %v", state)
+ }
+ } else {
+ return fmt.Errorf("processes was in state %v when wake() was called", curState)
+ }
+ } else {
+ return fmt.Errorf("failed to set Process state to waking: current state: %v, error: %v", curState, err)
+ }
+ }
+
+ // waitWaking.Add(1) is called atomically in swapState()
+ defer p.waitWaking.Done()
+
+ wakeStartTime := time.Now()
+ if err := p.sendWakeRequests(); err != nil {
+ p.proxyLogger.Errorf("<%s> sendWakeRequests failed, falling back to restarting the process: %v", p.ID, err)
+ p.StopImmediately()
+ return p.start()
+ }
+
+ if curState, err := p.swapState(StateWaking, StateReady); err != nil {
+ return fmt.Errorf("failed to transition to ready after wake: current state: %v, error: %v", curState, err)
+ }
+
+ p.proxyLogger.Infof("<%s> Model wake completed in %v", p.ID, time.Since(wakeStartTime))
+ return nil
+}
+
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand() {
@@ -429,24 +625,56 @@ func (p *Process) stopCommand() {
<-cmdWaitChan
}
-func (p *Process) checkHealthEndpoint(healthURL string) error {
+// buildFullURL builds a full URL from the proxy base URL and an endpoint path
+func (p *Process) buildFullURL(endpoint string) (string, error) {
+ if endpoint == "" {
+ return "", fmt.Errorf("endpoint is empty")
+ }
+
+ baseURL, err := url.Parse(p.config.Proxy)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse proxy URL: %v", err)
+ }
+
+ endpointURL, err := url.Parse(endpoint)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse endpoint: %v", err)
+ }
+
+ return baseURL.ResolveReference(endpointURL).String(), nil
+}
+
+// sendHTTPRequest sends a single HTTP request based on endpoint config
+func (p *Process) sendHTTPRequest(endpoint config.HTTPEndpoint) error {
+ fullURL, err := p.buildFullURL(endpoint.Endpoint)
+ if err != nil {
+ return err
+ }
+
+ timeout := time.Duration(endpoint.Timeout) * time.Second
+ // Create HTTP client with timeout
client := &http.Client{
- // wait a short time for a tcp connection to be established
Transport: &http.Transport{
DialContext: (&net.Dialer{
- Timeout: 500 * time.Millisecond,
+ Timeout: httpDialTimeout,
}).DialContext,
},
+ Timeout: timeout,
+ }
- // give a long time to respond to the health check endpoint
- // after the connection is established. See issue: 276
- Timeout: 5000 * time.Millisecond,
+ var bodyReader io.Reader
+ if endpoint.Body != "" {
+ bodyReader = strings.NewReader(endpoint.Body)
}
- req, err := http.NewRequest("GET", healthURL, nil)
+ req, err := http.NewRequest(endpoint.Method, fullURL, bodyReader)
if err != nil {
- return err
+ return fmt.Errorf("failed to create request: %v", err)
+ }
+
+ if endpoint.Body != "" {
+ req.Header.Set("Content-Type", "application/json")
}
resp, err := client.Do(req)
@@ -455,7 +683,6 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
}
defer resp.Body.Close()
- // got a response but it was not an OK
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code: %d", resp.StatusCode)
}
@@ -463,6 +690,19 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
return nil
}
+func (p *Process) checkHealthEndpoint(endpoint string) error {
+ // Create HTTP endpoint config for health check
+ // Health check gets 5 seconds to respond after connection is established (see issue: 276)
+ healthEndpoint := config.HTTPEndpoint{
+ Method: "GET",
+ Endpoint: endpoint,
+ Timeout: 5,
+ Body: "",
+ }
+
+ return p.sendHTTPRequest(healthEndpoint)
+}
+
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if p.reverseProxy == nil {
@@ -475,7 +715,7 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
- if currentState == StateShutdown || currentState == StateStopping {
+ if currentState == StateShutdown || currentState == StateStopping || currentState == StateSleepPending {
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
return
}
@@ -514,8 +754,8 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
}
beginStartTime := time.Now()
- if err := p.start(); err != nil {
- errstr := fmt.Sprintf("unable to start process: %s", err)
+ if err := p.makeReady(); err != nil {
+ errstr := fmt.Sprintf("unable to makeReady process: %s", err)
cancelLoadCtx()
if srw != nil {
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
diff --git a/proxy/process_test.go b/proxy/process_test.go
index de226614..0ec60969 100644
--- a/proxy/process_test.go
+++ b/proxy/process_test.go
@@ -103,7 +103,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code)
- assert.Contains(t, w.Body.String(), "unable to start process")
+ assert.Contains(t, w.Body.String(), "unable to makeReady process")
w = httptest.NewRecorder()
process.ProxyRequest(w, req)
@@ -260,6 +260,18 @@ func TestProcess_SwapState(t *testing.T) {
{"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown},
{"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown},
{"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped},
+ // Sleep/wake state transitions
+ {"Ready to SleepPending", StateReady, StateReady, StateSleepPending, nil, StateSleepPending},
+ {"SleepPending to Asleep", StateSleepPending, StateSleepPending, StateAsleep, nil, StateAsleep},
+ {"SleepPending to Stopping", StateSleepPending, StateSleepPending, StateStopping, nil, StateStopping},
+ {"SleepPending to Stopped", StateSleepPending, StateSleepPending, StateStopped, nil, StateStopped},
+ {"Asleep to Waking", StateAsleep, StateAsleep, StateWaking, nil, StateWaking},
+ {"Asleep to Stopping", StateAsleep, StateAsleep, StateStopping, nil, StateStopping},
+ {"Waking to Ready", StateWaking, StateWaking, StateReady, nil, StateReady},
+ {"Waking to Stopping", StateWaking, StateWaking, StateStopping, nil, StateStopping},
+ {"Waking to Stopped", StateWaking, StateWaking, StateStopped, nil, StateStopped},
+ {"Ready to Asleep Invalid", StateReady, StateReady, StateAsleep, ErrInvalidStateTransition, StateReady},
+ {"Asleep to Ready Invalid", StateAsleep, StateAsleep, StateReady, ErrInvalidStateTransition, StateAsleep},
}
for _, test := range tests {
@@ -565,3 +577,131 @@ func (w *panicOnWriteResponseWriter) Write(b []byte) (int, error) {
}
return w.ResponseRecorder.Write(b)
}
+
+// TestProcess_SleepAndWakeBasic tests the basic sleep/wake cycle
+func TestProcess_SleepAndWakeBasic(t *testing.T) {
+ expectedMessage := "sleep_wake_test"
+
+ // Get base config and modify sleep/wake fields (like TestProcess_StopCmd modifies CmdStop)
+ cfg := getTestSimpleResponderConfig(expectedMessage)
+ cfg.SleepMode = config.SleepModeEnable
+ cfg.SleepEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/sleep", Method: "POST", Timeout: 5},
+ }
+ cfg.WakeEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/wake_up", Method: "POST", Timeout: 5},
+ }
+
+ process := NewProcess("sleep-wake-basic", 5, cfg, debugLogger, debugLogger)
+ defer process.Stop()
+
+ // Start the process
+ err := process.start()
+ assert.Nil(t, err)
+ assert.Equal(t, StateReady, process.CurrentState())
+
+ // Put it to sleep
+ process.Sleep()
+ assert.Equal(t, StateAsleep, process.CurrentState())
+
+ // Wake it up
+ err = process.wake()
+ assert.Nil(t, err)
+ assert.Equal(t, StateReady, process.CurrentState())
+}
+
+// TestProcess_MultiStepWakeSequence tests multi-step wake sequences like vLLM level 2
+func TestProcess_MultiStepWakeSequence(t *testing.T) {
+ expectedMessage := "multi_step_wake"
+
+ cfg := getTestSimpleResponderConfig(expectedMessage)
+ cfg.SleepMode = config.SleepModeEnable
+ cfg.SleepEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/sleep?level=2", Method: "POST", Timeout: 5},
+ }
+ cfg.WakeEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/wake_up", Method: "POST", Timeout: 5},
+ {Endpoint: "/collective_rpc", Method: "POST", Body: `{"method": "reload_weights"}`, Timeout: 5},
+ {Endpoint: "/reset_prefix_cache", Method: "POST", Timeout: 5},
+ }
+
+ process := NewProcess("multi-step-wake", 5, cfg, debugLogger, debugLogger)
+ defer process.Stop()
+
+ // Start and sleep the process
+ err := process.start()
+ assert.Nil(t, err)
+ process.Sleep()
+ assert.Equal(t, StateAsleep, process.CurrentState())
+
+ // Wake it up - should execute all three steps successfully
+ err = process.wake()
+ assert.Nil(t, err)
+ assert.Equal(t, StateReady, process.CurrentState())
+}
+
+// TestProcess_SleepInsteadOfStopWithSwap tests that sleep is used instead of Stop when swapping models
+func TestProcess_SleepInsteadOfStopWithSwap(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping sleep/wake swap test")
+ }
+
+ expectedMessage := "sleep_swap_test"
+
+ cfg := getTestSimpleResponderConfig(expectedMessage)
+ cfg.SleepMode = config.SleepModeEnable
+ cfg.SleepEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/sleep", Method: "POST", Timeout: 5},
+ }
+ cfg.WakeEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/wake_up", Method: "POST", Timeout: 5},
+ }
+
+ process := NewProcess("sleep-swap", 5, cfg, debugLogger, debugLogger)
+ defer process.Stop()
+
+ // Start the process
+ err := process.start()
+ assert.Nil(t, err)
+ assert.Equal(t, StateReady, process.CurrentState())
+
+ // Call MakeIdle which should use Sleep instead of Stop
+ process.MakeIdle()
+
+ // Process should be asleep, not stopped
+ assert.Equal(t, StateAsleep, process.CurrentState(), "Process should be asleep")
+}
+
+// TestProcess_WakeFailureFallsBackToStart tests that wake failures trigger a full restart
+func TestProcess_WakeFailureFallsBackToStart(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping wake failure test")
+ }
+
+ expectedMessage := "wake_failure_test"
+
+ cfg := getTestSimpleResponderConfig(expectedMessage)
+ cfg.SleepMode = config.SleepModeEnable
+ cfg.SleepEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/sleep", Method: "POST", Timeout: 5},
+ }
+ cfg.WakeEndpoints = []config.HTTPEndpoint{
+ {Endpoint: "/wake_up_fail", Method: "POST", Timeout: 5}, // Use failing endpoint
+ }
+
+ process := NewProcess("wake-failure", 5, cfg, debugLogger, debugLogger)
+ defer process.Stop()
+
+ // Start and sleep the process
+ err := process.start()
+ assert.Nil(t, err)
+ process.Sleep()
+ assert.Equal(t, StateAsleep, process.CurrentState())
+
+ // Try to wake - should fall back to start()
+ err = process.wake()
+ assert.Nil(t, err)
+
+ // Process should be ready
+ assert.Equal(t, StateReady, process.CurrentState())
+}
diff --git a/proxy/processgroup.go b/proxy/processgroup.go
index e0b06008..f4827b29 100644
--- a/proxy/processgroup.go
+++ b/proxy/processgroup.go
@@ -65,7 +65,8 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter,
// is there something already running?
if pg.lastUsedProcess != "" {
- pg.processes[pg.lastUsedProcess].Stop()
+ lastProcess := pg.processes[pg.lastUsedProcess]
+ lastProcess.MakeIdle()
}
// wait for the request to the new model to be fully handled
@@ -111,6 +112,26 @@ func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error
return nil
}
+func (pg *ProcessGroup) SleepProcess(modelID string) error {
+ pg.Lock()
+
+ process, exists := pg.processes[modelID]
+ if !exists {
+ pg.Unlock()
+ return fmt.Errorf("process not found for %s", modelID)
+ }
+
+ if !process.isSleepEnabled() {
+ pg.Unlock()
+ return fmt.Errorf("model does not support sleep mode")
+ }
+
+ pg.Unlock()
+
+ process.Sleep()
+ return nil
+}
+
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock()
defer pg.Unlock()
diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go
index a296ee8c..1614781b 100644
--- a/proxy/proxymanager_api.go
+++ b/proxy/proxymanager_api.go
@@ -18,6 +18,7 @@ type Model struct {
Description string `json:"description"`
State string `json:"state"`
Unlisted bool `json:"unlisted"`
+ SleepMode string `json:"sleepMode"`
}
func addApiHandlers(pm *ProxyManager) {
@@ -26,6 +27,7 @@ func addApiHandlers(pm *ProxyManager) {
{
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
+ apiGroup.POST("/models/sleep/*model", pm.apiSleepSingleModelHandler)
apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics)
apiGroup.GET("/version", pm.apiGetVersion)
@@ -67,6 +69,12 @@ func (pm *ProxyManager) getModelStatus() []Model {
stateStr = "shutdown"
case StateStopped:
stateStr = "stopped"
+ case StateSleepPending:
+ stateStr = "sleepPending"
+ case StateAsleep:
+ stateStr = "asleep"
+ case StateWaking:
+ stateStr = "waking"
default:
stateStr = "unknown"
}
@@ -79,6 +87,7 @@ func (pm *ProxyManager) getModelStatus() []Model {
Description: pm.config.Models[modelID].Description,
State: state,
Unlisted: pm.config.Models[modelID].Unlisted,
+ SleepMode: string(pm.config.Models[modelID].SleepMode),
})
}
@@ -229,6 +238,28 @@ func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
}
}
+func (pm *ProxyManager) apiSleepSingleModelHandler(c *gin.Context) {
+ requestedModel := strings.TrimPrefix(c.Param("model"), "/")
+ realModelName, found := pm.config.RealModelName(requestedModel)
+ if !found {
+ pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
+ return
+ }
+
+ processGroup := pm.findGroupByModelName(realModelName)
+ if processGroup == nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
+ return
+ }
+
+ if err := processGroup.SleepProcess(realModelName); err != nil {
+ pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error sleeping process: %s", err.Error()))
+ return
+ } else {
+ c.String(http.StatusOK, "OK")
+ }
+}
+
func (pm *ProxyManager) apiGetVersion(c *gin.Context) {
c.JSON(http.StatusOK, map[string]string{
"version": pm.version,
diff --git a/ui/src/contexts/APIProvider.tsx b/ui/src/contexts/APIProvider.tsx
index 3740a1f6..49b5a2e5 100644
--- a/ui/src/contexts/APIProvider.tsx
+++ b/ui/src/contexts/APIProvider.tsx
@@ -1,7 +1,7 @@
import { createContext, useState, useContext, useEffect, useCallback, useMemo, type ReactNode } from "react";
import type { ConnectionState } from "../lib/types";
-type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "unknown";
+type ModelStatus = "ready" | "starting" | "stopping" | "stopped" | "shutdown" | "sleepPending" | "asleep" | "waking" | "unknown";
const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */
export interface Model {
@@ -10,6 +10,7 @@ export interface Model {
name: string;
description: string;
unlisted: boolean;
+ sleepMode: string;
}
interface APIProviderType {
@@ -18,6 +19,7 @@ interface APIProviderType {
unloadAllModels: () => Promise
@@ -164,23 +164,41 @@ function ModelsPanel() {
{showIdorName === "id" ? "Model ID" : "Name"}
-
+ Actions
State