Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ models:

# filters: a dictionary of filter settings
# - optional, default: empty dictionary
# - only stripParams is currently supported
# - same capabilities as peer filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
Expand All @@ -195,6 +195,16 @@ models:
# - recommended to stick to sampling parameters
stripParams: "temperature, top_p, top_k"

# setParams: a dictionary of parameters to set/override in requests
# - optional, default: empty dictionary
# - useful for enforcing specific parameter values
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce specific sampling parameters
temperature: 0.7
top_p: 0.9

# metadata: a dictionary of arbitrary values that are included in /v1/models
# - optional, default: empty dictionary
# - while metadata can contains complex types it is recommended to keep it simple
Expand Down Expand Up @@ -373,3 +383,23 @@ peers:
- z-ai/glm-4.7
- moonshotai/kimi-k2-0905
- minimax/minimax-m2.1
# filters: a dictionary of filter settings for peer requests
# - optional, default: empty dictionary
# - same capabilities as model filters (stripParams, setParams)
filters:
# stripParams: a comma separated list of parameters to remove from the request
# - optional, default: ""
# - useful for removing parameters that the peer doesn't support
# - the `model` parameter can never be removed
stripParams: "temperature, top_p"

# setParams: a dictionary of parameters to set/override in requests to this peer
# - optional, default: empty dictionary
# - useful for injecting provider-specific settings like data retention policies
# - protected params like "model" cannot be overridden
# - values can be strings, numbers, booleans, arrays, or objects
setParams:
# Example: enforce zero-data-retention for OpenRouter
provider:
data_collection: "deny"
allow_fallbacks: false
81 changes: 81 additions & 0 deletions proxy/config/filters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package config

import (
"slices"
"sort"
"strings"
)

// ProtectedParams is a list of parameters that cannot be set or stripped via filters
// These are protected to prevent breaking the proxy's ability to route requests correctly
var ProtectedParams = []string{"model"}

// Filters contains filter settings for modifying request parameters
// Used by both models and peers
type Filters struct {
// StripParams is a comma-separated list of parameters to remove from requests
// The "model" parameter can never be removed
StripParams string `yaml:"stripParams"`

// SetParams is a dictionary of parameters to set/override in requests
// Protected params (like "model") cannot be set
SetParams map[string]any `yaml:"setParams"`
}

// SanitizedStripParams returns a sorted list of parameters to strip,
// with duplicates, empty strings, and protected params removed
func (f Filters) SanitizedStripParams() []string {
if f.StripParams == "" {
return nil
}

params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)

for _, param := range params {
trimmed := strings.TrimSpace(param)
// Skip protected params, empty strings, and duplicates
if slices.Contains(ProtectedParams, trimmed) || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}

if len(cleaned) == 0 {
return nil
}

slices.Sort(cleaned)
return cleaned
}

// SanitizedSetParams returns a copy of SetParams with protected params removed
// and keys sorted for consistent iteration order
func (f Filters) SanitizedSetParams() (map[string]any, []string) {
if len(f.SetParams) == 0 {
return nil, nil
}

result := make(map[string]any, len(f.SetParams))
keys := make([]string, 0, len(f.SetParams))

for key, value := range f.SetParams {
// Skip protected params
if slices.Contains(ProtectedParams, key) {
continue
}
result[key] = value
keys = append(keys, key)
}

// Sort keys for consistent ordering
sort.Strings(keys)

if len(result) == 0 {
return nil, nil
}

return result, keys
}
168 changes: 168 additions & 0 deletions proxy/config/filters_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package config

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFilters_SanitizedStripParams(t *testing.T) {
tests := []struct {
name string
stripParams string
want []string
}{
{
name: "empty string",
stripParams: "",
want: nil,
},
{
name: "single param",
stripParams: "temperature",
want: []string{"temperature"},
},
{
name: "multiple params",
stripParams: "temperature, top_p, top_k",
want: []string{"temperature", "top_k", "top_p"}, // sorted
},
{
name: "model param filtered",
stripParams: "model, temperature, top_p",
want: []string{"temperature", "top_p"},
},
{
name: "only model param",
stripParams: "model",
want: nil,
},
{
name: "duplicates removed",
stripParams: "temperature, top_p, temperature",
want: []string{"temperature", "top_p"},
},
{
name: "extra whitespace",
stripParams: " temperature , top_p ",
want: []string{"temperature", "top_p"},
},
{
name: "empty values filtered",
stripParams: "temperature,,top_p,",
want: []string{"temperature", "top_p"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{StripParams: tt.stripParams}
got := f.SanitizedStripParams()
assert.Equal(t, tt.want, got)
})
}
}

func TestFilters_SanitizedSetParams(t *testing.T) {
tests := []struct {
name string
setParams map[string]any
wantParams map[string]any
wantKeys []string
}{
{
name: "empty setParams",
setParams: nil,
wantParams: nil,
wantKeys: nil,
},
{
name: "empty map",
setParams: map[string]any{},
wantParams: nil,
wantKeys: nil,
},
{
name: "normal params",
setParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantParams: map[string]any{
"temperature": 0.7,
"top_p": 0.9,
},
wantKeys: []string{"temperature", "top_p"},
},
{
name: "protected model param filtered",
setParams: map[string]any{
"model": "should-be-filtered",
"temperature": 0.7,
},
wantParams: map[string]any{
"temperature": 0.7,
},
wantKeys: []string{"temperature"},
},
{
name: "only protected param",
setParams: map[string]any{
"model": "should-be-filtered",
},
wantParams: nil,
wantKeys: nil,
},
{
name: "complex nested values",
setParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantParams: map[string]any{
"provider": map[string]any{
"data_collection": "deny",
"allow_fallbacks": false,
},
"transforms": []string{"middle-out"},
},
wantKeys: []string{"provider", "transforms"},
},
Comment on lines +116 to +133
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Complex nested values are not fully asserted.

The "complex nested values" test case at lines 117-133 includes a nested map and slice, but the assertion logic at lines 156-159 only compares basic types (string, int, float64, bool). The nested provider map and transforms slice are not actually verified for correctness.

Proposed fix to assert nested values
 			for key, wantValue := range tt.wantParams {
 				gotValue, exists := gotParams[key]
 				assert.True(t, exists, "missing key: %s", key)
-				// Simple comparison for basic types
-				switch v := wantValue.(type) {
-				case string, int, float64, bool:
-					assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
-				}
+				assert.Equal(t, wantValue, gotValue, "value mismatch for key %s", key)
 			}

Also applies to: 151-160

}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := Filters{SetParams: tt.setParams}
gotParams, gotKeys := f.SanitizedSetParams()

assert.Equal(t, len(tt.wantKeys), len(gotKeys), "keys length mismatch")
for i, key := range gotKeys {
assert.Equal(t, tt.wantKeys[i], key, "key mismatch at %d", i)
}

if tt.wantParams == nil {
assert.Nil(t, gotParams, "expected nil params")
return
}

assert.Equal(t, len(tt.wantParams), len(gotParams), "params length mismatch")
for key, wantValue := range tt.wantParams {
gotValue, exists := gotParams[key]
assert.True(t, exists, "missing key: %s", key)
// Simple comparison for basic types
switch v := wantValue.(type) {
case string, int, float64, bool:
assert.Equal(t, v, gotValue, "value mismatch for key %s", key)
}
}
})
}
}

func TestProtectedParams(t *testing.T) {
// Verify that "model" is protected
assert.Contains(t, ProtectedParams, "model")
}
34 changes: 7 additions & 27 deletions proxy/config/model_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package config
import (
"errors"
"runtime"
"slices"
"strings"
)

type ModelConfig struct {
Expand Down Expand Up @@ -74,16 +72,15 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
return SanitizeCommand(m.Cmd)
}

// ModelFilters see issue #174
// ModelFilters embeds Filters and adds legacy support for strip_params field
// See issue #174
type ModelFilters struct {
StripParams string `yaml:"stripParams"`
Filters `yaml:",inline"`
}

func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
type rawModelFilters ModelFilters
defaults := rawModelFilters{
StripParams: "",
}
defaults := rawModelFilters{}

if err := unmarshal(&defaults); err != nil {
return err
Expand All @@ -104,25 +101,8 @@ func (m *ModelFilters) UnmarshalYAML(unmarshal func(interface{}) error) error {
return nil
}

// SanitizedStripParams wraps Filters.SanitizedStripParams for backwards compatibility
// Returns ([]string, error) to match existing API
func (f ModelFilters) SanitizedStripParams() ([]string, error) {
if f.StripParams == "" {
return nil, nil
}

params := strings.Split(f.StripParams, ",")
cleaned := make([]string, 0, len(params))
seen := make(map[string]bool)

for _, param := range params {
trimmed := strings.TrimSpace(param)
if trimmed == "model" || trimmed == "" || seen[trimmed] {
continue
}
seen[trimmed] = true
cleaned = append(cleaned, trimmed)
}

// sort cleaned
slices.Sort(cleaned)
return cleaned, nil
return f.Filters.SanitizedStripParams(), nil
}
32 changes: 32 additions & 0 deletions proxy/config/model_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,35 @@ models:
assert.True(t, *config.Models["model2"].SendLoadingState)
}
}

func TestConfig_ModelFiltersWithSetParams(t *testing.T) {
content := `
models:
model1:
cmd: path/to/cmd --port ${PORT}
filters:
stripParams: "top_k"
setParams:
temperature: 0.7
top_p: 0.9
stop:
- "<|end|>"
- "<|stop|>"
`
config, err := LoadConfigFromReader(strings.NewReader(content))
assert.NoError(t, err)

modelConfig := config.Models["model1"]

// Check stripParams
stripParams, err := modelConfig.Filters.SanitizedStripParams()
assert.NoError(t, err)
assert.Equal(t, []string{"top_k"}, stripParams)

// Check setParams
setParams, keys := modelConfig.Filters.SanitizedSetParams()
assert.NotNil(t, setParams)
assert.Equal(t, []string{"stop", "temperature", "top_p"}, keys)
assert.Equal(t, 0.7, setParams["temperature"])
assert.Equal(t, 0.9, setParams["top_p"])
}
Loading
Loading