Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions internal/apiserver/batch/batch_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,18 @@ func (c *BatchAPIHandler) CreateBatch(w http.ResponseWriter, r *http.Request) {

// Capture configured pass-through headers into tags with "pth:" prefix
for _, headerName := range c.config.BatchAPI.PassThroughHeaders {
if value := r.Header.Get(headerName); value != "" {
tags[batch_types.TagPrefixPassThroughHeader+headerName] = value
// The external auth service (via Envoy ext_authz) may append
// request headers as separate entries instead of overwriting them. If a client
// sends a spoofed pass-through header, the auth service appends the real value as a
// second entry. We take the last entry from r.Header.Values() because Envoy's
// ext_authz pipeline guarantees auth-injected entries come after client-supplied
// ones.
if values := r.Header.Values(headerName); len(values) > 0 {
// Skip empty last values to avoid persisting blank tags (e.g. when the
// auth service clears a spoofed header by appending an empty entry).
if last := values[len(values)-1]; last != "" {
tags[batch_types.TagPrefixPassThroughHeader+headerName] = last
}
}
}

Expand Down
332 changes: 331 additions & 1 deletion internal/apiserver/batch/batch_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ import (
)

func setupTestHandler() *BatchAPIHandler {
config := &common.ServerConfig{}
return setupTestHandlerWithConfig(&common.ServerConfig{})
}

func setupTestHandlerWithConfig(config *common.ServerConfig) *BatchAPIHandler {
clients := &clientset.Clientset{
Inference: nil,
File: nil,
Expand Down Expand Up @@ -280,6 +282,334 @@ func TestBatchHandler(t *testing.T) {
}
})
})

t.Run("PassThroughHeaders", func(t *testing.T) {
t.Run("SingleValue", func(t *testing.T) {
handler := setupTestHandlerWithConfig(&common.ServerConfig{
BatchAPI: common.BatchAPIConfig{
PassThroughHeaders: []string{"X-Custom-Header"},
},
})

fileItem := &dbapi.FileItem{
BaseIndexes: dbapi.BaseIndexes{
ID: "file-pth-single",
TenantID: common.DefaultTenantID,
},
}
ctx := context.Background()
if err := handler.clients.FileDB.DBStore(ctx, fileItem); err != nil {
t.Fatalf("Failed to store file: %v", err)
}

reqBody := openai.CreateBatchRequest{
InputFileID: "file-pth-single",
Endpoint: openai.EndpointChatCompletions,
CompletionWindow: "24h",
}
body, err := json.Marshal(reqBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/batches", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Custom-Header", "custom-value")
rr := httptest.NewRecorder()
handler.CreateBatch(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Fatalf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
}

var batch openai.Batch
if err := json.NewDecoder(rr.Body).Decode(&batch); err != nil {
t.Fatalf("Failed to decode response body: %v", err)
}

// Verify the pass-through header was stored as a tag
query := &dbapi.BatchQuery{
BaseQuery: dbapi.BaseQuery{
IDs: []string{batch.ID},
TenantID: common.DefaultTenantID,
},
}
items, _, _, err := handler.clients.BatchDB.DBGet(ctx, query, true, 0, 1)
if err != nil {
t.Fatalf("Failed to retrieve batch from database: %v", err)
}
if len(items) == 0 {
t.Fatal("Batch not found in database")
}

if got := items[0].Tags["pth:X-Custom-Header"]; got != "custom-value" {
t.Errorf("Expected tag pth:X-Custom-Header to be 'custom-value', got %q", got)
}
})

// When multiple values exist for the same header (e.g. client-supplied
// followed by Envoy ext_authz-injected), the handler must use the last
// value because the auth service appends after any client-spoofed entry.
t.Run("MultipleValuesUsesLast", func(t *testing.T) {
handler := setupTestHandlerWithConfig(&common.ServerConfig{
BatchAPI: common.BatchAPIConfig{
PassThroughHeaders: []string{"X-Auth-User"},
},
})

fileItem := &dbapi.FileItem{
BaseIndexes: dbapi.BaseIndexes{
ID: "file-pth-multi",
TenantID: common.DefaultTenantID,
},
}
ctx := context.Background()
if err := handler.clients.FileDB.DBStore(ctx, fileItem); err != nil {
t.Fatalf("Failed to store file: %v", err)
}

reqBody := openai.CreateBatchRequest{
InputFileID: "file-pth-multi",
Endpoint: openai.EndpointChatCompletions,
CompletionWindow: "24h",
}
body, err := json.Marshal(reqBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/batches", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// Simulate a spoofed client header followed by an auth-injected header.
// Header.Add appends additional values for the same key.
req.Header.Set("X-Auth-User", "spoofed-user")
req.Header.Add("X-Auth-User", "real-user")
rr := httptest.NewRecorder()
handler.CreateBatch(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Fatalf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
}

var batch openai.Batch
if err := json.NewDecoder(rr.Body).Decode(&batch); err != nil {
t.Fatalf("Failed to decode response body: %v", err)
}

query := &dbapi.BatchQuery{
BaseQuery: dbapi.BaseQuery{
IDs: []string{batch.ID},
TenantID: common.DefaultTenantID,
},
}
items, _, _, err := handler.clients.BatchDB.DBGet(ctx, query, true, 0, 1)
if err != nil {
t.Fatalf("Failed to retrieve batch from database: %v", err)
}
if len(items) == 0 {
t.Fatal("Batch not found in database")
}

// Must be "real-user" (the last/auth-injected value), not "spoofed-user"
if got := items[0].Tags["pth:X-Auth-User"]; got != "real-user" {
t.Errorf("Expected tag pth:X-Auth-User to be 'real-user', got %q", got)
}
})

// When the auth service clears a spoofed header by appending an empty
// entry, the empty last value must not produce a tag — otherwise
// downstream consumers could misinterpret an empty string as valid.
t.Run("EmptyLastValueSkipped", func(t *testing.T) {
handler := setupTestHandlerWithConfig(&common.ServerConfig{
BatchAPI: common.BatchAPIConfig{
PassThroughHeaders: []string{"X-Auth-User"},
},
})

fileItem := &dbapi.FileItem{
BaseIndexes: dbapi.BaseIndexes{
ID: "file-pth-empty",
TenantID: common.DefaultTenantID,
},
}
ctx := context.Background()
if err := handler.clients.FileDB.DBStore(ctx, fileItem); err != nil {
t.Fatalf("Failed to store file: %v", err)
}

reqBody := openai.CreateBatchRequest{
InputFileID: "file-pth-empty",
Endpoint: openai.EndpointChatCompletions,
CompletionWindow: "24h",
}
body, err := json.Marshal(reqBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/batches", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// Client sends a spoofed value, auth service clears it with an empty entry.
req.Header.Set("X-Auth-User", "spoofed-user")
req.Header.Add("X-Auth-User", "")
rr := httptest.NewRecorder()
handler.CreateBatch(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Fatalf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
}

var batch openai.Batch
if err := json.NewDecoder(rr.Body).Decode(&batch); err != nil {
t.Fatalf("Failed to decode response body: %v", err)
}

query := &dbapi.BatchQuery{
BaseQuery: dbapi.BaseQuery{
IDs: []string{batch.ID},
TenantID: common.DefaultTenantID,
},
}
items, _, _, err := handler.clients.BatchDB.DBGet(ctx, query, true, 0, 1)
if err != nil {
t.Fatalf("Failed to retrieve batch from database: %v", err)
}
if len(items) == 0 {
t.Fatal("Batch not found in database")
}

// The spoofed value must NOT leak through; the empty last value
// should cause the tag to be omitted entirely.
if _, exists := items[0].Tags["pth:X-Auth-User"]; exists {
t.Errorf("Expected no tag for empty last value, but pth:X-Auth-User was set to %q",
items[0].Tags["pth:X-Auth-User"])
}
})

t.Run("HeaderNotPresent", func(t *testing.T) {
handler := setupTestHandlerWithConfig(&common.ServerConfig{
BatchAPI: common.BatchAPIConfig{
PassThroughHeaders: []string{"X-Missing-Header"},
},
})

fileItem := &dbapi.FileItem{
BaseIndexes: dbapi.BaseIndexes{
ID: "file-pth-missing",
TenantID: common.DefaultTenantID,
},
}
ctx := context.Background()
if err := handler.clients.FileDB.DBStore(ctx, fileItem); err != nil {
t.Fatalf("Failed to store file: %v", err)
}

reqBody := openai.CreateBatchRequest{
InputFileID: "file-pth-missing",
Endpoint: openai.EndpointChatCompletions,
CompletionWindow: "24h",
}
body, err := json.Marshal(reqBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/batches", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
// Do not set X-Missing-Header
rr := httptest.NewRecorder()
handler.CreateBatch(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Fatalf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
}

var batch openai.Batch
if err := json.NewDecoder(rr.Body).Decode(&batch); err != nil {
t.Fatalf("Failed to decode response body: %v", err)
}

query := &dbapi.BatchQuery{
BaseQuery: dbapi.BaseQuery{
IDs: []string{batch.ID},
TenantID: common.DefaultTenantID,
},
}
items, _, _, err := handler.clients.BatchDB.DBGet(ctx, query, true, 0, 1)
if err != nil {
t.Fatalf("Failed to retrieve batch from database: %v", err)
}
if len(items) == 0 {
t.Fatal("Batch not found in database")
}

// Tag should not exist when the header is absent
if _, exists := items[0].Tags["pth:X-Missing-Header"]; exists {
t.Error("Expected no tag for absent header, but pth:X-Missing-Header was set")
}
})

t.Run("MultipleConfiguredHeaders", func(t *testing.T) {
handler := setupTestHandlerWithConfig(&common.ServerConfig{
BatchAPI: common.BatchAPIConfig{
PassThroughHeaders: []string{"X-Header-A", "X-Header-B"},
},
})

fileItem := &dbapi.FileItem{
BaseIndexes: dbapi.BaseIndexes{
ID: "file-pth-multiple",
TenantID: common.DefaultTenantID,
},
}
ctx := context.Background()
if err := handler.clients.FileDB.DBStore(ctx, fileItem); err != nil {
t.Fatalf("Failed to store file: %v", err)
}

reqBody := openai.CreateBatchRequest{
InputFileID: "file-pth-multiple",
Endpoint: openai.EndpointChatCompletions,
CompletionWindow: "24h",
}
body, err := json.Marshal(reqBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/batches", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Header-A", "value-a")
req.Header.Set("X-Header-B", "value-b")
rr := httptest.NewRecorder()
handler.CreateBatch(rr, req)

if status := rr.Code; status != http.StatusOK {
t.Fatalf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
}

var batch openai.Batch
if err := json.NewDecoder(rr.Body).Decode(&batch); err != nil {
t.Fatalf("Failed to decode response body: %v", err)
}

query := &dbapi.BatchQuery{
BaseQuery: dbapi.BaseQuery{
IDs: []string{batch.ID},
TenantID: common.DefaultTenantID,
},
}
items, _, _, err := handler.clients.BatchDB.DBGet(ctx, query, true, 0, 1)
if err != nil {
t.Fatalf("Failed to retrieve batch from database: %v", err)
}
if len(items) == 0 {
t.Fatal("Batch not found in database")
}

if got := items[0].Tags["pth:X-Header-A"]; got != "value-a" {
t.Errorf("Expected tag pth:X-Header-A to be 'value-a', got %q", got)
}
if got := items[0].Tags["pth:X-Header-B"]; got != "value-b" {
t.Errorf("Expected tag pth:X-Header-B to be 'value-b', got %q", got)
}
})
})
})

t.Run("RetrieveBatch", func(t *testing.T) {
Expand Down