Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/extensionserver/extensionserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ func (s *Server) maybeModifyCluster(cluster *clusterv3.Cluster) {
extProcConfig.RequestAttributes = []string{"xds.upstream_host_metadata"}
extProcConfig.ProcessingMode = &extprocv3http.ProcessingMode{
RequestHeaderMode: extprocv3http.ProcessingMode_SEND,
RequestBodyMode: extprocv3http.ProcessingMode_BUFFERED,
// At the upstream filter, it can access the original body in its memory, so it can perform the translation
// as well as the authentication at the request headers. Hence, there's no need to send the request body to the extproc.
RequestBodyMode: extprocv3http.ProcessingMode_NONE,
// Response will be handled at the router filter level so that we could avoid the shenanigans around the retry+the upstream filter.
ResponseHeaderMode: extprocv3http.ProcessingMode_SKIP,
ResponseBodyMode: extprocv3http.ProcessingMode_NONE,
Expand Down
47 changes: 22 additions & 25 deletions internal/extproc/chatcompletion_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,31 +188,22 @@ func (c *chatCompletionProcessorUpstreamFilter) selectTranslator(out filterapi.V
}

// ProcessRequestHeaders implements [Processor.ProcessRequestHeaders].
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
// Start tracking metrics for this request.
c.metrics.StartRequest(c.requestHeaders)

// The request headers have already been at the time the processor was created.
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{},
}}, nil
}

// ProcessRequestBody implements [Processor.ProcessRequestBody].
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.Context, _ *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
//
// At the upstream filter, we already have the original request body at request headers phase.
// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy
// with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again
// to the extproc.
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
defer func() {
if err != nil {
c.metrics.RecordRequestCompletion(ctx, false)
}
}()

// TODO: We do not use the body from the extproc request since we might have already translated it
// to the upstream format on the previous retry (if any). If it's possible, we should be able to
// configure the extproc filter to "not send the body but execute the ProcessRequestBody" method.
// Currently, there's no way to do this, hence Envoy has to "unnecessarily" send the entire request body
// to the extproc twice.

// Start tracking metrics for this request.
c.metrics.StartRequest(c.requestHeaders)
c.metrics.SetModel(c.requestHeaders[c.config.modelNameHeaderKey])

headerMutation, bodyMutation, err := c.translator.RequestBody(c.originalRequestBodyRaw, c.originalRequestBody, c.onRetry)
if err != nil {
return nil, fmt.Errorf("failed to transform request: %w", err)
Expand All @@ -230,15 +221,21 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.C
}
}

resp := &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_RequestBody{
RequestBody: &extprocv3.BodyResponse{
Response: &extprocv3.CommonResponse{HeaderMutation: headerMutation, BodyMutation: bodyMutation},
return &extprocv3.ProcessingResponse{
Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{
Response: &extprocv3.CommonResponse{
HeaderMutation: headerMutation, BodyMutation: bodyMutation,
Status: extprocv3.CommonResponse_CONTINUE_AND_REPLACE,
},
},
},
}
c.stream = c.originalRequestBody.Stream
return resp, nil
}, nil
}

// ProcessRequestBody implements [Processor.ProcessRequestBody].
func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
panic("BUG: ProcessRequestBody should not be called in the upstream filter")
}

// ProcessResponseHeaders implements [Processor.ProcessResponseHeaders].
Expand Down
24 changes: 6 additions & 18 deletions internal/extproc/chatcompletion_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,6 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) {
})
}

func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) {
mm := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{metrics: mm}
res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{
Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}},
})
require.NoError(t, err)
_, ok := res.Response.(*extprocv3.ProcessingResponse_RequestHeaders)
require.True(t, ok)
require.NotZero(t, mm.requestStart)
mm.RequireRequestNotCompleted(t)
}

func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) {
t.Run("error translation", func(t *testing.T) {
mm := &mockChatCompletionMetrics{}
Expand Down Expand Up @@ -317,7 +304,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend(t *testing.T) {
require.False(t, p.stream) // On error, stream should be false regardless of the input.
}

func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) {
func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) {
const modelKey = "x-ai-gateway-model-key"
for _, stream := range []bool{false, true} {
t.Run(fmt.Sprintf("stream%v", stream), func(t *testing.T) {
Expand All @@ -338,13 +325,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T)
translator: tr,
originalRequestBodyRaw: someBody,
originalRequestBody: &body,
stream: stream,
}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: someBody})
_, err := p.ProcessRequestHeaders(t.Context(), nil)
require.ErrorContains(t, err, "failed to transform request: test error")
mm.RequireRequestFailure(t)
mm.RequireTokensRecorded(t, 0)
mm.RequireSelectedModel(t, "some-model")
require.False(t, p.stream) // On error, stream should be false regardless of the input.
})
t.Run("ok", func(t *testing.T) {
someBody := bodyFromModel(t, "some-model", stream)
Expand All @@ -367,12 +354,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T)
translator: mt,
originalRequestBodyRaw: someBody,
originalRequestBody: &expBody,
stream: stream,
}
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
resp, err := p.ProcessRequestHeaders(t.Context(), nil)
require.NoError(t, err)
require.Equal(t, mt, p.translator)
require.NotNil(t, resp)
commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestBody).RequestBody.Response
commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response
require.Equal(t, headerMut, commonRes.HeaderMutation)
require.Equal(t, bodyMut, commonRes.BodyMutation)

Expand Down
24 changes: 12 additions & 12 deletions tests/extproc/envoy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -270,7 +270,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -315,7 +315,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -360,7 +360,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -416,7 +416,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -478,7 +478,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -540,7 +540,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -610,7 +610,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -661,7 +661,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -712,7 +712,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -763,7 +763,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down Expand Up @@ -814,7 +814,7 @@ static_resources:
- xds.upstream_host_metadata
processing_mode:
request_header_mode: "SEND"
request_body_mode: "BUFFERED"
request_body_mode: "NONE"
response_header_mode: "SKIP"
response_body_mode: "NONE"
grpc_service:
Expand Down
14 changes: 0 additions & 14 deletions tests/internal/testupstreamlib/testupstream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"net"
"net/http"
"os"
"slices"
"strconv"
"time"

Expand Down Expand Up @@ -162,19 +161,6 @@ func handler(w http.ResponseWriter, r *http.Request) {
return
}

// At least for the endpoints we want to support, all requests should have a Content-Length header
// and should not use chunked transfer encoding.
if r.Header.Get("Content-Length") == "" {
logger.Println("no Content-Length header, using request body length:", len(requestBody))
http.Error(w, "no Content-Length header, using request body length: "+strconv.Itoa(len(requestBody)), http.StatusBadRequest)
return
}
if slices.Contains(r.TransferEncoding, "chunked") {
logger.Println("chunked transfer encoding detected")
http.Error(w, "chunked transfer encoding is not supported", http.StatusBadRequest)
return
}

if expectedReqBody := r.Header.Get(testupstreamlib.ExpectedRequestBodyHeaderKey); expectedReqBody != "" {
var expectedBody []byte
expectedBody, err = base64.StdEncoding.DecodeString(expectedReqBody)
Expand Down
4 changes: 2 additions & 2 deletions tests/internal/testupstreamlib/testupstream/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func Test_main(t *testing.T) {

t.Run("sse", func(t *testing.T) {
t.Parallel()
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", strings.NewReader("some-body"))
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", nil)
require.NoError(t, err)
request.Header.Set(testupstreamlib.ResponseTypeKey, "sse")
request.Header.Set(testupstreamlib.ResponseBodyHeaderKey,
Expand Down Expand Up @@ -271,7 +271,7 @@ func Test_main(t *testing.T) {

t.Run("aws-event-stream", func(t *testing.T) {
t.Parallel()
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", strings.NewReader("some-body"))
request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", nil)
require.NoError(t, err)
request.Header.Set(testupstreamlib.ResponseTypeKey, "aws-event-stream")
request.Header.Set(testupstreamlib.ResponseBodyHeaderKey,
Expand Down
Loading