diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 4a0d0e07..9896edd8 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -3,6 +3,7 @@ package proxy import ( "bytes" "context" + "encoding/base64" "fmt" "io" "mime/multipart" @@ -825,23 +826,30 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc { xApiKey := c.GetHeader("x-api-key") var bearerKey string + var basicKey string if auth := c.GetHeader("Authorization"); auth != "" { if strings.HasPrefix(auth, "Bearer ") { bearerKey = strings.TrimPrefix(auth, "Bearer ") + } else if strings.HasPrefix(auth, "Basic ") { + // Basic Auth: base64(username:password), password is the API key + encoded := strings.TrimPrefix(auth, "Basic ") + if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil { + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) == 2 { + basicKey = parts[1] // password is the API key + } + } } } - // If both headers present, they must match - if xApiKey != "" && bearerKey != "" && xApiKey != bearerKey { - pm.sendErrorResponse(c, http.StatusBadRequest, "x-api-key and Authorization header values do not match") - c.Abort() - return - } - - // Use x-api-key first, then Authorization - providedKey := xApiKey - if providedKey == "" { + // Use first key found: Basic, then Bearer, then x-api-key + var providedKey string + if basicKey != "" { + providedKey = basicKey + } else if bearerKey != "" { providedKey = bearerKey + } else { + providedKey = xApiKey } // Validate key @@ -854,6 +862,7 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc { } if !valid { + c.Header("WWW-Authenticate", `Basic realm="llama-swap"`) pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key") c.Abort() return diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 2330b32b..bc566fb6 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -3,6 +3,7 @@ package proxy import ( "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "math/rand" @@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool { return r.closeChannel } -func (r *TestResponseRecorder) closeClient() { - r.closeChannel <- true -} - func CreateTestResponseRecorder() *TestResponseRecorder { return &TestResponseRecorder{ httptest.NewRecorder(), @@ -1253,22 +1250,43 @@ func TestProxyManager_APIKeyAuth(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) }) - t.Run("both headers with different keys returns 400", func(t *testing.T) { + t.Run("invalid key returns 401", func(t *testing.T) { reqBody := `{"model":"model1"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - req.Header.Set("x-api-key", "valid-key-1") - req.Header.Set("Authorization", "Bearer valid-key-2") + req.Header.Set("x-api-key", "invalid-key") w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusBadRequest, w.Code) - assert.Contains(t, w.Body.String(), "do not match") + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Contains(t, w.Body.String(), "unauthorized") }) - t.Run("invalid key returns 401", func(t *testing.T) { + t.Run("missing key returns 401", func(t *testing.T) { reqBody := `{"model":"model1"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - req.Header.Set("x-api-key", "invalid-key") + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) + }) + + t.Run("valid key in Basic Auth header", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + // Basic Auth: base64("anyuser:valid-key-1") + credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1")) + req.Header.Set("Authorization", "Basic "+credentials) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key")) + req.Header.Set("Authorization", "Basic "+credentials) w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) @@ -1276,13 +1294,26 @@ func TestProxyManager_APIKeyAuth(t *testing.T) { assert.Contains(t, w.Body.String(), "unauthorized") }) - t.Run("missing key returns 401", func(t *testing.T) { + t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) { + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + req.Header.Set("x-api-key", "valid-key-1") + credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1")) + req.Header.Set("Authorization", "Basic "+credentials) + w := CreateTestResponseRecorder() + + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) { reqBody := `{"model":"model1"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := CreateTestResponseRecorder() proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate")) }) }