Skip to content

Commit 84e2c07

Browse files
authored
Refactor wildcard out of CORS headers (#81)
Changes to CORS functionality: - `Access-Control-Allow-Origin: *` is set for all requests - for pre-flight OPTIONS requests - specify methods: `Access-Control-Allow-Methods: GET, POST, PUT, PATCH, DELETE, OPTIONS` - if the client sent `Access-Control-Request-Headers` then echo back the same value in `Access-Control-Allow-Headers`. If no `Access-Control-Request-Headers` were sent, then send back a default set - set `Access-Control-Max-Age: 86400` to that may improve performance - Add CORS tests to the proxy-manager
1 parent 680af28 commit 84e2c07

File tree

2 files changed

+107
-4
lines changed

2 files changed

+107
-4
lines changed

proxy/proxymanager.go

+17-4
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,26 @@ func New(config *Config) *ProxyManager {
7272
})
7373
}
7474

75-
// see: https://github.com/mostlygeek/llama-swap/issues/42
75+
// see: issue: #81, #77 and #42 for CORS issues
7676
// respond with permissive OPTIONS for any endpoint
7777
pm.ginEngine.Use(func(c *gin.Context) {
78+
79+
// set this for all requests
80+
c.Header("Access-Control-Allow-Origin", "*")
81+
7882
if c.Request.Method == "OPTIONS" {
79-
c.Header("Access-Control-Allow-Origin", "*")
80-
c.Header("Access-Control-Allow-Methods", "*")
81-
c.Header("Access-Control-Allow-Headers", "*")
83+
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
84+
85+
// allow whatever the client requested by default
86+
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
87+
c.Header("Access-Control-Allow-Headers", headers)
88+
} else {
89+
c.Header(
90+
"Access-Control-Allow-Headers",
91+
"Content-Type, Authorization, Accept, X-Requested-With",
92+
)
93+
}
94+
c.Header("Access-Control-Max-Age", "86400")
8295
c.AbortWithStatus(http.StatusNoContent)
8396
return
8497
}

proxy/proxymanager_test.go

+90
Original file line numberDiff line numberDiff line change
@@ -639,5 +639,95 @@ func TestProxyManager_UseModelName(t *testing.T) {
639639
assert.Equal(t, upstreamModelName, response["model"])
640640
})
641641
}
642+
}
643+
644+
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
645+
config := &Config{
646+
HealthCheckTimeout: 15,
647+
Models: map[string]ModelConfig{
648+
"model1": getTestSimpleResponderConfig("model1"),
649+
},
650+
LogRequests: true,
651+
}
652+
653+
tests := []struct {
654+
name string
655+
method string
656+
requestHeaders map[string]string
657+
expectedStatus int
658+
expectedHeaders map[string]string
659+
}{
660+
{
661+
name: "OPTIONS with no headers",
662+
method: "OPTIONS",
663+
expectedStatus: http.StatusNoContent,
664+
expectedHeaders: map[string]string{
665+
"Access-Control-Allow-Origin": "*",
666+
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
667+
"Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With",
668+
},
669+
},
670+
{
671+
name: "OPTIONS with specific headers",
672+
method: "OPTIONS",
673+
requestHeaders: map[string]string{
674+
"Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header",
675+
},
676+
expectedStatus: http.StatusNoContent,
677+
expectedHeaders: map[string]string{
678+
"Access-Control-Allow-Origin": "*",
679+
"Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS",
680+
"Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header",
681+
},
682+
},
683+
{
684+
name: "Non-OPTIONS request",
685+
method: "GET",
686+
expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined
687+
},
688+
}
689+
690+
for _, tt := range tests {
691+
t.Run(tt.name, func(t *testing.T) {
692+
proxy := New(config)
693+
defer proxy.StopProcesses()
694+
695+
req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil)
696+
for k, v := range tt.requestHeaders {
697+
req.Header.Set(k, v)
698+
}
699+
700+
w := httptest.NewRecorder()
701+
proxy.ginEngine.ServeHTTP(w, req)
702+
703+
assert.Equal(t, tt.expectedStatus, w.Code)
704+
705+
for header, expectedValue := range tt.expectedHeaders {
706+
assert.Equal(t, expectedValue, w.Header().Get(header))
707+
}
708+
})
709+
}
710+
}
642711

712+
func TestProxyManager_CORSHeadersInRegularRequest(t *testing.T) {
713+
config := &Config{
714+
HealthCheckTimeout: 15,
715+
Models: map[string]ModelConfig{
716+
"model1": getTestSimpleResponderConfig("model1"),
717+
},
718+
LogRequests: true,
719+
}
720+
721+
proxy := New(config)
722+
defer proxy.StopProcesses()
723+
724+
// Test that CORS headers are present in regular POST requests
725+
reqBody := `{"model":"model1"}`
726+
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
727+
w := httptest.NewRecorder()
728+
729+
proxy.ginEngine.ServeHTTP(w, req)
730+
731+
assert.Equal(t, http.StatusOK, w.Code)
732+
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"))
643733
}

0 commit comments

Comments
 (0)