Skip to content
Merged
4 changes: 1 addition & 3 deletions internal/extensionserver/extensionserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ func (s *Server) maybeModifyCluster(cluster *clusterv3.Cluster) {
extProcConfig.RequestAttributes = []string{"xds.upstream_host_metadata"}
extProcConfig.ProcessingMode = &extprocv3http.ProcessingMode{
RequestHeaderMode: extprocv3http.ProcessingMode_SEND,
// At the upstream filter, it can access the original body in its memory, so it can perform the translation
// as well as the authentication at the request headers. Hence, there's no need to send the request body to the extproc.
RequestBodyMode: extprocv3http.ProcessingMode_NONE,
RequestBodyMode: extprocv3http.ProcessingMode_BUFFERED,
// Response will be handled at the router filter level so that we could avoid the shenanigans around the retry+the upstream filter.
ResponseHeaderMode: extprocv3http.ProcessingMode_SKIP,
ResponseBodyMode: extprocv3http.ProcessingMode_NONE,
Expand Down
47 changes: 25 additions & 22 deletions internal/extproc/chatcompletion_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,31 @@ func (c *chatCompletionProcessorUpstreamFilter) selectTranslator(out filterapi.V
}

// ProcessRequestHeaders implements [Processor.ProcessRequestHeaders].
//
// At the upstream filter, we already have the original request body at request headers phase.
// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
// with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
// to the extproc.
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
// Start tracking metrics for this request.
c.metrics.StartRequest(c.requestHeaders)

// The request headers have already been at the time the processor was created.
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{},
}}, nil
}

// ProcessRequestBody implements [Processor.ProcessRequestBody].
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.Context, _ *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
defer func() {
if err != nil {
c.metrics.RecordRequestCompletion(ctx, false)
}
}()

// Start tracking metrics for this request.
c.metrics.StartRequest(c.requestHeaders)
c.metrics.SetModel(c.requestHeaders[c.config.modelNameHeaderKey])
// TODO: We do not use the body from the extproc request since we might have already translated it
// to the upstream format on the previous retry (if any). If it's possible, we should be able to
// configure the extproc filter to "not send the body but execute the ProcessRequestBody" method.
// Currently, there's no way to do this, hence Envoy has to "unnecessarily" send the entire request body
// to the extproc twice.

c.metrics.SetModel(c.requestHeaders[c.config.modelNameHeaderKey])
headerMutation, bodyMutation, err := c.translator.RequestBody(c.originalRequestBodyRaw, c.originalRequestBody, c.onRetry)
if err != nil {
return nil, fmt.Errorf("failed to transform request: %w", err)
Expand All @@ -220,21 +229,15 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx contex
}
}

return &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{
Response: &extprocv3.CommonResponse{
HeaderMutation: headerMutation, BodyMutation: bodyMutation,
Status: extprocv3.CommonResponse_CONTINUE_AND_REPLACE,
},
resp := &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_RequestBody{
RequestBody: &extprocv3.BodyResponse{
Response: &extprocv3.CommonResponse{HeaderMutation: headerMutation, BodyMutation: bodyMutation},
},
},
}, nil
}

// ProcessRequestBody implements [Processor.ProcessRequestBody].
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
panic("BUG: ProcessRequestBody should not be called in the upstream filter")
}
c.stream = c.originalRequestBody.Stream
return resp, nil
}

// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
Expand Down
24 changes: 18 additions & 6 deletions internal/extproc/chatcompletion_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,19 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) {
})
}

func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) {
mm := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{metrics: mm}
res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{
Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}},
})
require.NoError(t, err)
_, ok := res.Response.(*extprocv3.ProcessingResponse_RequestHeaders)
require.True(t, ok)
require.NotZero(t, mm.requestStart)
mm.RequireRequestNotCompleted(t)
}

func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) {
t.Run("error translation", func(t *testing.T) {
mm := &mockChatCompletionMetrics{}
Expand Down Expand Up @@ -294,7 +307,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend(t *testing.T) {
require.False(t, p.stream) // On error, stream should be false regardless of the input.
}

func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) {
func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) {
const modelKey = "x-ai-gateway-model-key"
for _, stream := range []bool{false, true} {
t.Run(fmt.Sprintf("stream%v", stream), func(t *testing.T) {
Expand All @@ -315,13 +328,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing
translator: tr,
originalRequestBodyRaw: someBody,
originalRequestBody: &body,
stream: stream,
}
_, err := p.ProcessRequestHeaders(t.Context(), nil)
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: someBody})
require.ErrorContains(t, err, "failed to transform request: test error")
mm.RequireRequestFailure(t)
mm.RequireTokensRecorded(t, 0)
mm.RequireSelectedModel(t, "some-model")
require.False(t, p.stream) // On error, stream should be false regardless of the input.
})
t.Run("ok", func(t *testing.T) {
someBody := bodyFromModel(t, "some-model", stream)
Expand All @@ -344,13 +357,12 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing
translator: mt,
originalRequestBodyRaw: someBody,
originalRequestBody: &expBody,
stream: stream,
}
resp, err := p.ProcessRequestHeaders(t.Context(), nil)
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.NoError(t, err)
require.Equal(t, mt, p.translator)
require.NotNil(t, resp)
commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response
commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestBody).RequestBody.Response
require.Equal(t, headerMut, commonRes.HeaderMutation)
require.Equal(t, bodyMut, commonRes.BodyMutation)

Expand Down
24 changes: 12 additions & 12 deletions tests/extproc/envoy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -270,7 +270,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -315,7 +315,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -360,7 +360,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -416,7 +416,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -478,7 +478,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -540,7 +540,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -610,7 +610,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -661,7 +661,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -712,7 +712,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -763,7 +763,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -814,7 +814,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "NONE"
request_body_mode: "BUFFERED"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down
14 changes: 14 additions & 0 deletions tests/internal/testupstreamlib/testupstream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net"
"net/http"
"os"
"slices"
"strconv"
"time"

Expand Down Expand Up @@ -161,6 +162,19 @@ func handler(w http.ResponseWriter, r *http.Request) {
return
}

// At least for the endpoints we want to support, all requests should have a Content-Length header
// and should not use chunked transfer encoding.
if r.Header.Get("Content-Length") == "" {
logger.Println("no Content-Length header, using request body length:", len(requestBody))
http.Error(w, "no Content-Length header, using request body length: "+strconv.Itoa(len(requestBody)), http.StatusBadRequest)
return
}
if slices.Contains(r.TransferEncoding, "chunked") {
logger.Println("chunked transfer encoding detected")
http.Error(w, "chunked transfer encoding is not supported", http.StatusBadRequest)
return
}

if expectedReqBody := r.Header.Get(testupstreamlib.ExpectedRequestBodyHeaderKey); expectedReqBody != "" {
var expectedBody []byte
expectedBody, err = base64.StdEncoding.DecodeString(expectedReqBody)
Expand Down
4 changes: 2 additions & 2 deletions tests/internal/testupstreamlib/testupstream/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func Test_main(t *testing.T) {

t.Run("sse", func(t *testing.T) {
t.Parallel()
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", nil)
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", strings.NewReader("some-body"))
require.NoError(t, err)
request.Header.Set(testupstreamlib.ResponseTypeKey, "sse")
request.Header.Set(testupstreamlib.ResponseBodyHeaderKey,
Expand Down Expand Up @@ -271,7 +271,7 @@ func Test_main(t *testing.T) {

t.Run("aws-event-stream", func(t *testing.T) {
t.Parallel()
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", nil)
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", strings.NewReader("some-body"))
require.NoError(t, err)
request.Header.Set(testupstreamlib.ResponseTypeKey, "aws-event-stream")
request.Header.Set(testupstreamlib.ResponseBodyHeaderKey,
Expand Down