diff --git a/config.example.yaml b/config.example.yaml index ab8030e4..e0d61830 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -185,7 +185,7 @@ models: # filters: a dictionary of filter settings # - optional, default: empty dictionary - # - only stripParams is currently supported + # - same capabilities as peer filters (stripParams, setParams) filters: # stripParams: a comma separated list of parameters to remove from the request # - optional, default: "" @@ -195,6 +195,16 @@ models: # - recommended to stick to sampling parameters stripParams: "temperature, top_p, top_k" + # setParams: a dictionary of parameters to set/override in requests + # - optional, default: empty dictionary + # - useful for enforcing specific parameter values + # - protected params like "model" cannot be overridden + # - values can be strings, numbers, booleans, arrays, or objects + setParams: + # Example: enforce specific sampling parameters + temperature: 0.7 + top_p: 0.9 + # metadata: a dictionary of arbitrary values that are included in /v1/models # - optional, default: empty dictionary # - while metadata can contains complex types it is recommended to keep it simple @@ -373,3 +383,23 @@ peers: - z-ai/glm-4.7 - moonshotai/kimi-k2-0905 - minimax/minimax-m2.1 + # filters: a dictionary of filter settings for peer requests + # - optional, default: empty dictionary + # - same capabilities as model filters (stripParams, setParams) + filters: + # stripParams: a comma separated list of parameters to remove from the request + # - optional, default: "" + # - useful for removing parameters that the peer doesn't support + # - the `model` parameter can never be removed + stripParams: "temperature, top_p" + + # setParams: a dictionary of parameters to set/override in requests to this peer + # - optional, default: empty dictionary + # - useful for injecting provider-specific settings like data retention policies + # - protected params like "model" cannot be overridden + # - values can be strings, numbers, booleans, arrays, or objects + setParams: + # Example: enforce zero-data-retention for OpenRouter + provider: + data_collection: "deny" + allow_fallbacks: false diff --git a/proxy/config/filters.go b/proxy/config/filters.go new file mode 100644 index 00000000..39900075 --- /dev/null +++ b/proxy/config/filters.go @@ -0,0 +1,81 @@ +package config + +import ( + "slices" + "sort" + "strings" +) + +// ProtectedParams is a list of parameters that cannot be set or stripped via filters +// These are protected to prevent breaking the proxy's ability to route requests correctly +var ProtectedParams = []string{"model"} + +// Filters contains filter settings for modifying request parameters +// Used by both models and peers +type Filters struct { + // StripParams is a comma-separated list of parameters to remove from requests + // The "model" parameter can never be removed + StripParams string `yaml:"stripParams"` + + // SetParams is a dictionary of parameters to set/override in requests + // Protected params (like "model") cannot be set + SetParams map[string]any `yaml:"setParams"` +} + +// SanitizedStripParams returns a sorted list of parameters to strip, +// with duplicates, empty strings, and protected params removed +func (f Filters) SanitizedStripParams() []string { + if f.StripParams == "" { + return nil + } + + params := strings.Split(f.StripParams, ",") + cleaned := make([]string, 0, len(params)) + seen := make(map[string]bool) + + for _, param := range params { + trimmed := strings.TrimSpace(param) + // Skip protected params, empty strings, and duplicates + if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] { + continue + } + seen[trimmed] = true + cleaned = append(cleaned, trimmed) + } + + if len(cleaned) == 0 { + return nil + } + + slices.Sort(cleaned) + return cleaned +} + +// SanitizedSetParams returns a copy of SetParams with protected params removed +// and keys sorted for consistent iteration order +func (f Filters) SanitizedSetParams() (map[string]any, []string) { + if len(f.SetParams) == 0 { + return nil, nil + } + + result := make(map[string]any, len(f.SetParams)) + keys := make([]string, 0, len(f.SetParams)) + + for key, value := range f.SetParams { + // Skip protected params + if slices.Contains(ProtectedParams, key) { + continue + } + result[key] = value + keys = append(keys, key) + } + + // Sort keys for consistent ordering + sort.Strings(keys) + + if len(result) == 0 { + return nil, nil + } + + return result, keys +} diff --git a/proxy/config/filters_test.go b/proxy/config/filters_test.go new file mode 100644 index 00000000..d1f54dcd --- /dev/null +++ b/proxy/config/filters_test.go @@ -0,0 +1,168 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilters_SanitizedStripParams(t *testing.T) { + tests := []struct { + name string + stripParams string + want []string + }{ + { + name: "empty string", + stripParams: "", + want: nil, + }, + { + name: "single param", + stripParams: "temperature", + want: []string{"temperature"}, + }, + { + name: "multiple params", + stripParams: "temperature, top_p, top_k", + want: []string{"temperature", "top_k", "top_p"}, // sorted + }, + { + name: "model param filtered", + stripParams: "model, temperature, top_p", + want: []string{"temperature", "top_p"}, + }, + { + name: "only model param", + stripParams: "model", + want: nil, + }, + { + name: "duplicates removed", + stripParams: "temperature, top_p, temperature", + want: []string{"temperature", "top_p"}, + }, + { + name: "extra whitespace", + stripParams: " temperature , top_p ", + want: []string{"temperature", "top_p"}, + }, + { + name: "empty values filtered", + stripParams: "temperature,,top_p,", + want: []string{"temperature", "top_p"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := Filters{StripParams: tt.stripParams} + got := f.SanitizedStripParams() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestFilters_SanitizedSetParams(t *testing.T) { + tests := []struct { + name string + setParams map[string]any + wantParams map[string]any + wantKeys []string + }{ + { + name: "empty setParams", + setParams: nil, + wantParams: nil, + wantKeys: nil, + }, + { + name: "empty map", + setParams: map[string]any{}, + wantParams: nil, + wantKeys: nil, + }, + { + name: "normal params", + setParams: map[string]any{ + "temperature": 0.7, + "top_p": 0.9, + }, + wantParams: map[string]any{ + "temperature": 0.7, + "top_p": 0.9, + }, + wantKeys: []string{"temperature", "top_p"}, + }, + { + name: "protected model param filtered", + setParams: map[string]any{ + "model": "should-be-filtered", + "temperature": 0.7, + }, + wantParams: map[string]any{ + "temperature": 0.7, + }, + wantKeys: []string{"temperature"}, + }, + { + name: "only protected param", + setParams: map[string]any{ + "model": "should-be-filtered", + }, + wantParams: nil, + wantKeys: nil, + }, + { + name: "complex nested values", + setParams: map[string]any{ + "provider": map[string]any{ + "data_collection": "deny", + "allow_fallbacks": false, + }, + "transforms": []string{"middle-out"}, + }, + wantParams: map[string]any{ + "provider": map[string]any{ + "data_collection": "deny", + "allow_fallbacks": false, + }, + "transforms": []string{"middle-out"}, + }, + wantKeys: []string{"provider", "transforms"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := Filters{SetParams: tt.setParams} + gotParams, gotKeys := f.SanitizedSetParams() + + assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch") + for i, key := range gotKeys { + assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i) + } + + if tt.wantParams == nil { + assert.Nil(t, gotParams, "expected nil params") + return + } + + assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch") + for key, wantValue := range tt.wantParams { + gotValue, exists := gotParams[key] + assert.True(t, exists, "missing key: %s", key) + // Simple comparison for basic types + switch v := wantValue.(type) { + case string, int, float64, bool: + assert.Equal(t, v, gotValue, "value mismatch for key %s", key) + } + } + }) + } +} + +func TestProtectedParams(t *testing.T) { + // Verify that "model" is protected + assert.Contains(t, ProtectedParams, "model") +} diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go index f1c79e31..9dc37aea 100644 --- a/proxy/config/model_config.go +++ b/proxy/config/model_config.go @@ -3,8 +3,6 @@ package config import ( "errors" "runtime" - "slices" - "strings" ) type ModelConfig struct { @@ -74,16 +72,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) { return SanitizeCommand(m.Cmd) } -// ModelFilters see issue #174 +// ModelFilters embeds Filters and adds legacy support for strip_params field +// See issue #174 type ModelFilters struct { - StripParams string `yaml:"stripParams"` + Filters `yaml:",inline"` } func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error { type rawModelFilters ModelFilters - defaults := rawModelFilters{ - StripParams: "", - } + defaults := rawModelFilters{} if err := unmarshal(&defaults); err != nil { return err @@ -104,25 +101,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } +// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility +// Returns ([]string, error) to match existing API func (f ModelFilters) SanitizedStripParams() ([]string, error) { - if f.StripParams == "" { - return nil, nil - } - - params := strings.Split(f.StripParams, ",") - cleaned := make([]string, 0, len(params)) - seen := make(map[string]bool) - - for _, param := range params { - trimmed := strings.TrimSpace(param) - if trimmed == "model" || trimmed == "" || seen[trimmed] { - continue - } - seen[trimmed] = true - cleaned = append(cleaned, trimmed) - } - - // sort cleaned - slices.Sort(cleaned) - return cleaned, nil + return f.Filters.SanitizedStripParams(), nil } diff --git a/proxy/config/model_config_test.go b/proxy/config/model_config_test.go index 9f1e9b4f..32392952 100644 --- a/proxy/config/model_config_test.go +++ b/proxy/config/model_config_test.go @@ -72,3 +72,35 @@ models: assert.True(t, *config.Models["model2"].SendLoadingState) } } + +func TestConfig_ModelFiltersWithSetParams(t *testing.T) { + content := ` +models: + model1: + cmd: path/to/cmd --port ${PORT} + filters: + stripParams: "top_k" + setParams: + temperature: 0.7 + top_p: 0.9 + stop: + - "<|end|>" + - "<|stop|>" +` + config, err := LoadConfigFromReader(strings.NewReader(content)) + assert.NoError(t, err) + + modelConfig := config.Models["model1"] + + // Check stripParams + stripParams, err := modelConfig.Filters.SanitizedStripParams() + assert.NoError(t, err) + assert.Equal(t, []string{"top_k"}, stripParams) + + // Check setParams + setParams, keys := modelConfig.Filters.SanitizedSetParams() + assert.NotNil(t, setParams) + assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys) + assert.Equal(t, 0.7, setParams["temperature"]) + assert.Equal(t, 0.9, setParams["top_p"]) +} diff --git a/proxy/config/peer.go b/proxy/config/peer.go index 4d5ecfb9..63b0aaf0 100644 --- a/proxy/config/peer.go +++ b/proxy/config/peer.go @@ -11,14 +11,16 @@ type PeerConfig struct { ProxyURL *url.URL `yaml:"-"` ApiKey string `yaml:"apiKey"` Models []string `yaml:"models"` + Filters Filters `yaml:"filters"` } func (c *PeerConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { type rawPeerConfig PeerConfig defaults := rawPeerConfig{ - Proxy: "", - ApiKey: "", - Models: []string{}, + Proxy: "", + ApiKey: "", + Models: []string{}, + Filters: Filters{}, } if err := unmarshal(&defaults); err != nil { diff --git a/proxy/config/peer_test.go b/proxy/config/peer_test.go index d02f619d..c1c455b7 100644 --- a/proxy/config/peer_test.go +++ b/proxy/config/peer_test.go @@ -137,3 +137,73 @@ func searchSubstring(s, substr string) bool { } return false } + +func TestPeerConfig_WithFilters(t *testing.T) { + yamlData := ` +proxy: https://openrouter.ai/api +apiKey: sk-test +models: + - model_a +filters: + setParams: + temperature: 0.7 + provider: + data_collection: deny +` + var config PeerConfig + err := yaml.Unmarshal([]byte(yamlData), &config) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if config.Filters.SetParams == nil { + t.Fatal("Filters.SetParams should not be nil") + } + + if config.Filters.SetParams["temperature"] != 0.7 { + t.Errorf("expected temperature 0.7, got %v", config.Filters.SetParams["temperature"]) + } + + provider, ok := config.Filters.SetParams["provider"].(map[string]any) + if !ok { + t.Fatal("provider should be a map") + } + if provider["data_collection"] != "deny" { + t.Errorf("expected data_collection deny, got %v", provider["data_collection"]) + } +} + +func TestPeerConfig_WithBothFilters(t *testing.T) { + yamlData := ` +proxy: https://openrouter.ai/api +apiKey: sk-test +models: + - model_a +filters: + stripParams: "temperature, top_p" + setParams: + max_tokens: 1000 +` + var config PeerConfig + err := yaml.Unmarshal([]byte(yamlData), &config) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check stripParams + stripParams := config.Filters.SanitizedStripParams() + if len(stripParams) != 2 { + t.Errorf("expected 2 strip params, got %d", len(stripParams)) + } + if stripParams[0] != "temperature" || stripParams[1] != "top_p" { + t.Errorf("unexpected strip params: %v", stripParams) + } + + // Check setParams + if config.Filters.SetParams == nil { + t.Fatal("Filters.SetParams should not be nil") + } + if config.Filters.SetParams["max_tokens"] != 1000 { + t.Errorf("expected max_tokens 1000, got %v", config.Filters.SetParams["max_tokens"]) + } +} diff --git a/proxy/peerproxy.go b/proxy/peerproxy.go index 876f6bff..5f76bf12 100644 --- a/proxy/peerproxy.go +++ b/proxy/peerproxy.go @@ -106,6 +106,20 @@ func (p *PeerProxy) HasPeerModel(modelID string) bool { return found } +// GetPeerFilters returns the filters for a peer model, or empty filters if not found +func (p *PeerProxy) GetPeerFilters(modelID string) config.Filters { + pp, found := p.proxyMap[modelID] + if !found { + return config.Filters{} + } + // Get the peer config using the peerID + peer, found := p.peers[pp.peerID] + if !found { + return config.Filters{} + } + return peer.Filters +} + func (p *PeerProxy) ListPeers() config.PeerDictionaryConfig { return p.peers } diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 687e83c5..cb13fb3d 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -650,13 +650,49 @@ func (pm *ProxyManager) proxyInferenceHandler(c *gin.Context) { } } + // issue #453 set/override parameters in the JSON body + setParams, setParamKeys := pm.config.Models[modelID].Filters.SanitizedSetParams() + for _, key := range setParamKeys { + pm.proxyLogger.Debugf("<%s> setting param: %s", modelID, key) + bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) + return + } + } + pm.proxyLogger.Debugf("ProxyManager using local Process for model: %s", requestedModel) nextHandler = processGroup.ProxyRequest } else if pm.peerProxy != nil && pm.peerProxy.HasPeerModel(requestedModel) { pm.proxyLogger.Debugf("ProxyManager using ProxyPeer for model: %s", requestedModel) modelID = requestedModel - nextHandler = pm.peerProxy.ProxyRequest + // issue #453 apply filters for peer requests + peerFilters := pm.peerProxy.GetPeerFilters(requestedModel) + + // Apply stripParams - remove specified parameters from request + stripParams := peerFilters.SanitizedStripParams() + for _, param := range stripParams { + pm.proxyLogger.Debugf("<%s> stripping param: %s", requestedModel, param) + bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stripping parameter %s from request", param)) + return + } + } + + // Apply setParams - set/override specified parameters in request + setParams, setParamKeys := peerFilters.SanitizedSetParams() + for _, key := range setParamKeys { + pm.proxyLogger.Debugf("<%s> setting param: %s", requestedModel, key) + bodyBytes, err = sjson.SetBytes(bodyBytes, key, setParams[key]) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error setting parameter %s in request", key)) + return + } + } + + nextHandler = pm.peerProxy.ProxyRequest } if nextHandler == nil { diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index dace6f4f..2f01a0ff 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -966,7 +966,9 @@ func TestProxyManager_ChatContentLength(t *testing.T) { func TestProxyManager_FiltersStripParams(t *testing.T) { modelConfig := getTestSimpleResponderConfig("model1") modelConfig.Filters = config.ModelFilters{ - StripParams: "temperature, model, stream", + Filters: config.Filters{ + StripParams: "temperature, model, stream", + }, } config := config.AddDefaultGroupToConfig(config.Config{