Skip to content

Commit b3d331d

Browse files
authored
Properly strip profile name slug from model names (fixes #62)
The profile slug in a model name, `profile:model`, is specific to llama-swap. This change strips `profile:` from the model name in the request so that upstreams expecting just `model` work without needing to know about the profile slug.
1 parent 62275e0 commit b3d331d

File tree

5 files changed

+84
-14
lines changed

5 files changed

+84
-14
lines changed

go.mod

+4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ require (
2929
github.com/modern-go/reflect2 v1.0.2 // indirect
3030
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
3131
github.com/pmezard/go-difflib v1.0.0 // indirect
32+
github.com/tidwall/gjson v1.18.0 // indirect
33+
github.com/tidwall/match v1.1.1 // indirect
34+
github.com/tidwall/pretty v1.2.1 // indirect
35+
github.com/tidwall/sjson v1.2.5 // indirect
3236
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
3337
github.com/ugorji/go/codec v1.2.12 // indirect
3438
golang.org/x/arch v0.8.0 // indirect

go.sum

+10
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
5757
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
5858
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
5959
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
60+
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
61+
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
62+
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
63+
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
64+
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
65+
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
66+
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
67+
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
68+
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
69+
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
6070
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
6171
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
6272
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=

misc/simple-responder/simple-responder.go

+21
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@ import (
1212
"time"
1313

1414
"github.com/gin-gonic/gin"
15+
"github.com/tidwall/gjson"
1516
)
1617

1718
func main() {
1819
gin.SetMode(gin.TestMode)
1920
// Define a command-line flag for the port
2021
port := flag.String("port", "8080", "port to listen on")
22+
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
2123

2224
// Define a command-line flag for the response message
2325
responseMessage := flag.String("respond", "hi", "message to respond with")
@@ -41,6 +43,25 @@ func main() {
4143
c.String(200, *responseMessage)
4244
})
4345

46+
// for issue #62 to check model name strips profile slug
47+
// has to be one of the openAI API endpoints that llama-swap proxies
48+
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
49+
r.POST("/v1/audio/speech", func(c *gin.Context) {
50+
body, err := io.ReadAll(c.Request.Body)
51+
if err != nil {
52+
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
53+
return
54+
}
55+
defer c.Request.Body.Close()
56+
modelName := gjson.GetBytes(body, "model").String()
57+
if modelName != *expectedModel {
58+
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
59+
return
60+
} else {
61+
c.JSON(http.StatusOK, gin.H{"message": "ok"})
62+
}
63+
})
64+
4465
r.POST("/v1/completions", func(c *gin.Context) {
4566
c.Header("Content-Type", "text/plain")
4667
c.String(200, *responseMessage)

proxy/proxymanager.go

+26-14
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"time"
1414

1515
"github.com/gin-gonic/gin"
16+
"github.com/tidwall/gjson"
17+
"github.com/tidwall/sjson"
1618
)
1719

1820
const (
@@ -224,11 +226,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
224226
defer pm.Unlock()
225227

226228
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
227-
profileName, modelName := "", requestedModel
228-
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
229-
profileName = requestedModel[:idx]
230-
modelName = requestedModel[idx+1:]
231-
}
229+
profileName, modelName := splitRequestedModel(requestedModel)
232230

233231
if profileName != "" {
234232
if _, found := pm.config.Profiles[profileName]; !found {
@@ -344,21 +342,26 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
344342
return
345343
}
346344

347-
var requestBody map[string]interface{}
348-
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
349-
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
350-
return
351-
}
352-
model, ok := requestBody["model"].(string)
353-
if !ok {
345+
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
346+
if requestedModel == "" {
354347
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
355-
return
356348
}
357349

358-
if process, err := pm.swapModel(model); err != nil {
350+
if process, err := pm.swapModel(requestedModel); err != nil {
359351
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
360352
return
361353
} else {
354+
355+
// strip the llama-swap profile slug from the "model" field before forwarding the body upstream
356+
profileName, modelName := splitRequestedModel(requestedModel)
357+
if profileName != "" {
358+
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
359+
if err != nil {
360+
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
361+
return
362+
}
363+
}
364+
362365
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
363366

364367
// dechunk it as we already have all the body bytes see issue #11
@@ -387,3 +390,12 @@ func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
387390
func ProcessKeyName(groupName, modelName string) string {
388391
return groupName + PROFILE_SPLIT_CHAR + modelName
389392
}
393+
394+
func splitRequestedModel(requestedModel string) (string, string) {
395+
profileName, modelName := "", requestedModel
396+
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
397+
profileName = requestedModel[:idx]
398+
modelName = requestedModel[idx+1:]
399+
}
400+
return profileName, modelName
401+
}

proxy/proxymanager_test.go

+23
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,26 @@ func TestProxyManager_Unload(t *testing.T) {
326326
assert.Equal(t, w.Body.String(), "OK")
327327
assert.Len(t, proxy.currentProcesses, 0)
328328
}
329+
330+
// issue 62, strip profile slug from model name
331+
func TestProxyManager_StripProfileSlug(t *testing.T) {
332+
config := &Config{
333+
HealthCheckTimeout: 15,
334+
Profiles: map[string][]string{
335+
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
336+
},
337+
Models: map[string]ModelConfig{
338+
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
339+
},
340+
}
341+
342+
proxy := New(config)
343+
defer proxy.StopProcesses()
344+
345+
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
346+
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
347+
w := httptest.NewRecorder()
348+
proxy.HandlerFunc(w, req)
349+
assert.Equal(t, http.StatusOK, w.Code)
350+
assert.Contains(t, w.Body.String(), "ok")
351+
}

0 commit comments

Comments
 (0)