Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions proxy/proxymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proxy
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"mime/multipart"
Expand Down Expand Up @@ -825,23 +826,30 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
xApiKey := c.GetHeader("x-api-key")

var bearerKey string
var basicKey string
if auth := c.GetHeader("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
bearerKey = strings.TrimPrefix(auth, "Bearer ")
} else if strings.HasPrefix(auth, "Basic ") {
// Basic Auth: base64(username:password), password is the API key
encoded := strings.TrimPrefix(auth, "Basic ")
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
parts := strings.SplitN(string(decoded), ":", 2)
if len(parts) == 2 {
basicKey = parts[1] // password is the API key
}
}
}
}

// If both headers present, they must match
if xApiKey != "" && bearerKey != "" && xApiKey != bearerKey {
pm.sendErrorResponse(c, http.StatusBadRequest, "x-api-key and Authorization header values do not match")
c.Abort()
return
}

// Use x-api-key first, then Authorization
providedKey := xApiKey
if providedKey == "" {
// Use first key found: Basic, then Bearer, then x-api-key
var providedKey string
if basicKey != "" {
providedKey = basicKey
} else if bearerKey != "" {
providedKey = bearerKey
} else {
providedKey = xApiKey
}

// Validate key
Expand All @@ -854,6 +862,7 @@ func (pm *ProxyManager) apiKeyAuth() gin.HandlerFunc {
}

if !valid {
c.Header("WWW-Authenticate", `Basic realm="llama-swap"`)
pm.sendErrorResponse(c, http.StatusUnauthorized, "unauthorized: invalid or missing API key")
c.Abort()
return
Expand Down
55 changes: 43 additions & 12 deletions proxy/proxymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proxy
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"math/rand"
Expand Down Expand Up @@ -36,10 +37,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool {
return r.closeChannel
}

func (r *TestResponseRecorder) closeClient() {
r.closeChannel <- true
}

func CreateTestResponseRecorder() *TestResponseRecorder {
return &TestResponseRecorder{
httptest.NewRecorder(),
Expand Down Expand Up @@ -1253,36 +1250,70 @@ func TestProxyManager_APIKeyAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
})

t.Run("both headers with different keys returns 400", func(t *testing.T) {
t.Run("invalid key returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
req.Header.Set("Authorization", "Bearer valid-key-2")
req.Header.Set("x-api-key", "invalid-key")
w := CreateTestResponseRecorder()

proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
assert.Contains(t, w.Body.String(), "do not match")
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "unauthorized")
})

t.Run("invalid key returns 401", func(t *testing.T) {
t.Run("missing key returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "invalid-key")
w := CreateTestResponseRecorder()

proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
})

t.Run("valid key in Basic Auth header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
// Basic Auth: base64("anyuser:valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()

proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})

t.Run("invalid key in Basic Auth header returns 401", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
credentials := base64.StdEncoding.EncodeToString([]byte("anyuser:wrong-key"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()

proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Contains(t, w.Body.String(), "unauthorized")
})

t.Run("missing key returns 401", func(t *testing.T) {
t.Run("x-api-key and Basic Auth with matching keys", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
req.Header.Set("x-api-key", "valid-key-1")
credentials := base64.StdEncoding.EncodeToString([]byte("user:valid-key-1"))
req.Header.Set("Authorization", "Basic "+credentials)
w := CreateTestResponseRecorder()

proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
})

t.Run("401 response includes WWW-Authenticate header", func(t *testing.T) {
reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()

proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnauthorized, w.Code)
assert.Equal(t, `Basic realm="llama-swap"`, w.Header().Get("WWW-Authenticate"))
})
}

Expand Down
Loading