diff --git a/internal/extensionserver/extensionserver.go b/internal/extensionserver/extensionserver.go index 49b91015d1..9aebbee6a5 100644 --- a/internal/extensionserver/extensionserver.go +++ b/internal/extensionserver/extensionserver.go @@ -242,9 +242,7 @@ func (s *Server) maybeModifyCluster(cluster *clusterv3.Cluster) { extProcConfig.RequestAttributes = []string{"xds.upstream_host_metadata"} extProcConfig.ProcessingMode = &extprocv3http.ProcessingMode{ RequestHeaderMode: extprocv3http.ProcessingMode_SEND, - // At the upstream filter, it can access the original body in its memory, so it can perform the translation - // as well as the authentication at the request headers. Hence, there's no need to send the request body to the extproc. - RequestBodyMode: extprocv3http.ProcessingMode_NONE, + RequestBodyMode: extprocv3http.ProcessingMode_BUFFERED, // Response will be handled at the router filter level so that we could avoid the shenanigans around the retry+the upstream filter. ResponseHeaderMode: extprocv3http.ProcessingMode_SKIP, ResponseBodyMode: extprocv3http.ProcessingMode_NONE, diff --git a/internal/extproc/chatcompletion_processor.go b/internal/extproc/chatcompletion_processor.go index f20282af47..c684c9c907 100644 --- a/internal/extproc/chatcompletion_processor.go +++ b/internal/extproc/chatcompletion_processor.go @@ -187,22 +187,31 @@ func (c *chatCompletionProcessorUpstreamFilter) selectTranslator(out filterapi.V } // ProcessRequestHeaders implements [Processor.ProcessRequestHeaders]. -// -// At the upstream filter, we already have the original request body at request headers phase. -// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy -// with the status CONTINUE_AND_REPLACE. This will allows Envoy to not send the request body again -// to the extproc. 
-func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { +func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { + // Start tracking metrics for this request. + c.metrics.StartRequest(c.requestHeaders) + + // The request headers have already been set at the time the processor was created. + return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extprocv3.HeadersResponse{}, + }}, nil +} + +// ProcessRequestBody implements [Processor.ProcessRequestBody]. +func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.Context, _ *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { defer func() { if err != nil { c.metrics.RecordRequestCompletion(ctx, false) } }() - // Start tracking metrics for this request. - c.metrics.StartRequest(c.requestHeaders) - c.metrics.SetModel(c.requestHeaders[c.config.modelNameHeaderKey]) + // TODO: We do not use the body from the extproc request since we might have already translated it + // to the upstream format on the previous retry (if any). If it's possible, we should be able to + // configure the extproc filter to "not send the body but execute the ProcessRequestBody" method. + // Currently, there's no way to do this, hence Envoy has to "unnecessarily" send the entire request body + // to the extproc twice. 
+ c.metrics.SetModel(c.requestHeaders[c.config.modelNameHeaderKey]) headerMutation, bodyMutation, err := c.translator.RequestBody(c.originalRequestBodyRaw, c.originalRequestBody, c.onRetry) if err != nil { return nil, fmt.Errorf("failed to transform request: %w", err) @@ -220,21 +229,15 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx contex } } - return &extprocv3.ProcessingResponse{ - Response: &extprocv3.ProcessingResponse_RequestHeaders{ - RequestHeaders: &extprocv3.HeadersResponse{ - Response: &extprocv3.CommonResponse{ - HeaderMutation: headerMutation, BodyMutation: bodyMutation, - Status: extprocv3.CommonResponse_CONTINUE_AND_REPLACE, - }, + resp := &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_RequestBody{ + RequestBody: &extprocv3.BodyResponse{ + Response: &extprocv3.CommonResponse{HeaderMutation: headerMutation, BodyMutation: bodyMutation}, }, }, - }, nil -} - -// ProcessRequestBody implements [Processor.ProcessRequestBody]. -func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { - panic("BUG: ProcessRequestBody should not be called in the upstream filter") + } + c.stream = c.originalRequestBody.Stream + return resp, nil } // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders]. 
diff --git a/internal/extproc/chatcompletion_processor_test.go b/internal/extproc/chatcompletion_processor_test.go index ea3acd0914..3cac813e44 100644 --- a/internal/extproc/chatcompletion_processor_test.go +++ b/internal/extproc/chatcompletion_processor_test.go @@ -133,6 +133,19 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) { }) } +func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { + mm := &mockChatCompletionMetrics{} + p := &chatCompletionProcessorUpstreamFilter{metrics: mm} + res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{ + Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}}, + }) + require.NoError(t, err) + _, ok := res.Response.(*extprocv3.ProcessingResponse_RequestHeaders) + require.True(t, ok) + require.NotZero(t, mm.requestStart) + mm.RequireRequestNotCompleted(t) +} + func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) { t.Run("error translation", func(t *testing.T) { mm := &mockChatCompletionMetrics{} @@ -294,7 +307,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend(t *testing.T) { require.False(t, p.stream) // On error, stream should be false regardless of the input. 
} -func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { +func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) { const modelKey = "x-ai-gateway-model-key" for _, stream := range []bool{false, true} { t.Run(fmt.Sprintf("stream%v", stream), func(t *testing.T) { @@ -315,13 +328,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing translator: tr, originalRequestBodyRaw: someBody, originalRequestBody: &body, - stream: stream, } - _, err := p.ProcessRequestHeaders(t.Context(), nil) + _, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: someBody}) require.ErrorContains(t, err, "failed to transform request: test error") mm.RequireRequestFailure(t) mm.RequireTokensRecorded(t, 0) mm.RequireSelectedModel(t, "some-model") + require.False(t, p.stream) // On error, stream should be false regardless of the input. }) t.Run("ok", func(t *testing.T) { someBody := bodyFromModel(t, "some-model", stream) @@ -344,13 +357,12 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing translator: mt, originalRequestBodyRaw: someBody, originalRequestBody: &expBody, - stream: stream, } - resp, err := p.ProcessRequestHeaders(t.Context(), nil) + resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{}) require.NoError(t, err) require.Equal(t, mt, p.translator) require.NotNil(t, resp) - commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response + commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestBody).RequestBody.Response require.Equal(t, headerMut, commonRes.HeaderMutation) require.Equal(t, bodyMut, commonRes.BodyMutation) diff --git a/tests/extproc/envoy.yaml b/tests/extproc/envoy.yaml index 2997a18e4f..35b58fcbed 100644 --- a/tests/extproc/envoy.yaml +++ b/tests/extproc/envoy.yaml @@ -225,7 +225,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - 
request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -270,7 +270,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -315,7 +315,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -360,7 +360,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -416,7 +416,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -478,7 +478,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -540,7 +540,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -610,7 +610,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -661,7 +661,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" 
response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -712,7 +712,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -763,7 +763,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -814,7 +814,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "NONE" + request_body_mode: "BUFFERED" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: diff --git a/tests/internal/testupstreamlib/testupstream/main.go b/tests/internal/testupstreamlib/testupstream/main.go index 9af0c81c13..c05aaf5afc 100644 --- a/tests/internal/testupstreamlib/testupstream/main.go +++ b/tests/internal/testupstreamlib/testupstream/main.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "os" + "slices" "strconv" "time" @@ -161,6 +162,19 @@ func handler(w http.ResponseWriter, r *http.Request) { return } + // At least for the endpoints we want to support, all requests should have a Content-Length header + // and should not use chunked transfer encoding. 
+ if r.Header.Get("Content-Length") == "" { + logger.Println("no Content-Length header, using request body length:", len(requestBody)) + http.Error(w, "no Content-Length header, using request body length: "+strconv.Itoa(len(requestBody)), http.StatusBadRequest) + return + } + if slices.Contains(r.TransferEncoding, "chunked") { + logger.Println("chunked transfer encoding detected") + http.Error(w, "chunked transfer encoding is not supported", http.StatusBadRequest) + return + } + if expectedReqBody := r.Header.Get(testupstreamlib.ExpectedRequestBodyHeaderKey); expectedReqBody != "" { var expectedBody []byte expectedBody, err = base64.StdEncoding.DecodeString(expectedReqBody) diff --git a/tests/internal/testupstreamlib/testupstream/main_test.go b/tests/internal/testupstreamlib/testupstream/main_test.go index 8d848a6bbf..b05c06c034 100644 --- a/tests/internal/testupstreamlib/testupstream/main_test.go +++ b/tests/internal/testupstreamlib/testupstream/main_test.go @@ -43,7 +43,7 @@ func Test_main(t *testing.T) { t.Run("sse", func(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", nil) + request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", strings.NewReader("some-body")) require.NoError(t, err) request.Header.Set(testupstreamlib.ResponseTypeKey, "sse") request.Header.Set(testupstreamlib.ResponseBodyHeaderKey, @@ -271,7 +271,7 @@ func Test_main(t *testing.T) { t.Run("aws-event-stream", func(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", nil) + request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", strings.NewReader("some-body")) require.NoError(t, err) request.Header.Set(testupstreamlib.ResponseTypeKey, "aws-event-stream") request.Header.Set(testupstreamlib.ResponseBodyHeaderKey,