From bd92b19e1e4334b36dfb8143070f4a0661fc7e9e Mon Sep 17 00:00:00 2001 From: Dan Sun Date: Tue, 1 Jul 2025 20:06:06 -0400 Subject: [PATCH] Revert "extproc: reverts the use of REPLACE_AND_CONTINUE (#730)" This reverts commit 1bed302ec6cd7fb54dc6e8a0422fc6ede37fc92d. Signed-off-by: Dan Sun --- internal/extensionserver/extensionserver.go | 4 +- internal/extproc/chatcompletion_processor.go | 47 +++++++++---------- .../extproc/chatcompletion_processor_test.go | 24 +++------- tests/extproc/envoy.yaml | 24 +++++----- .../testupstreamlib/testupstream/main.go | 14 ------ .../testupstreamlib/testupstream/main_test.go | 4 +- 6 files changed, 45 insertions(+), 72 deletions(-) diff --git a/internal/extensionserver/extensionserver.go b/internal/extensionserver/extensionserver.go index 9aebbee6a5..49b91015d1 100644 --- a/internal/extensionserver/extensionserver.go +++ b/internal/extensionserver/extensionserver.go @@ -242,7 +242,9 @@ func (s *Server) maybeModifyCluster(cluster *clusterv3.Cluster) { extProcConfig.RequestAttributes = []string{"xds.upstream_host_metadata"} extProcConfig.ProcessingMode = &extprocv3http.ProcessingMode{ RequestHeaderMode: extprocv3http.ProcessingMode_SEND, - RequestBodyMode: extprocv3http.ProcessingMode_BUFFERED, + // At the upstream filter, it can access the original body in its memory, so it can perform the translation + // as well as the authentication at the request headers. Hence, there's no need to send the request body to the extproc. + RequestBodyMode: extprocv3http.ProcessingMode_NONE, // Response will be handled at the router filter level so that we could avoid the shenanigans around the retry+the upstream filter. 
ResponseHeaderMode: extprocv3http.ProcessingMode_SKIP, ResponseBodyMode: extprocv3http.ProcessingMode_NONE, diff --git a/internal/extproc/chatcompletion_processor.go b/internal/extproc/chatcompletion_processor.go index d9f5d3dfa6..d5bb52ef35 100644 --- a/internal/extproc/chatcompletion_processor.go +++ b/internal/extproc/chatcompletion_processor.go @@ -188,31 +188,22 @@ func (c *chatCompletionProcessorUpstreamFilter) selectTranslator(out filterapi.V } // ProcessRequestHeaders implements [Processor.ProcessRequestHeaders]. -func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { - // Start tracking metrics for this request. - c.metrics.StartRequest(c.requestHeaders) - - // The request headers have already been at the time the processor was created. - return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{ - RequestHeaders: &extprocv3.HeadersResponse{}, - }}, nil -} - -// ProcessRequestBody implements [Processor.ProcessRequestBody]. -func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.Context, _ *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { +// +// At the upstream filter, we already have the original request body at request headers phase. +// So, we simply do the translation and upstream auth at this stage, and send them back to Envoy +// with the status CONTINUE_AND_REPLACE. This will allow Envoy to not send the request body again +// to the extproc. +func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { defer func() { if err != nil { c.metrics.RecordRequestCompletion(ctx, false) } }() - // TODO: We do not use the body from the extproc request since we might have already translated it - // to the upstream format on the previous retry (if any). 
If it's possible, we should be able to - // configure the extproc filter to "not send the body but execute the ProcessRequestBody" method. - // Currently, there's no way to do this, hence Envoy has to "unnecessarily" send the entire request body - // to the extproc twice. - + // Start tracking metrics for this request. + c.metrics.StartRequest(c.requestHeaders) c.metrics.SetModel(c.requestHeaders[c.config.modelNameHeaderKey]) + headerMutation, bodyMutation, err := c.translator.RequestBody(c.originalRequestBodyRaw, c.originalRequestBody, c.onRetry) if err != nil { return nil, fmt.Errorf("failed to transform request: %w", err) @@ -230,15 +221,21 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.C } } - resp := &extprocv3.ProcessingResponse{ - Response: &extprocv3.ProcessingResponse_RequestBody{ - RequestBody: &extprocv3.BodyResponse{ - Response: &extprocv3.CommonResponse{HeaderMutation: headerMutation, BodyMutation: bodyMutation}, + return &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extprocv3.HeadersResponse{ + Response: &extprocv3.CommonResponse{ + HeaderMutation: headerMutation, BodyMutation: bodyMutation, + Status: extprocv3.CommonResponse_CONTINUE_AND_REPLACE, + }, }, }, - } - c.stream = c.originalRequestBody.Stream - return resp, nil + }, nil +} + +// ProcessRequestBody implements [Processor.ProcessRequestBody]. +func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { + panic("BUG: ProcessRequestBody should not be called in the upstream filter") } // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders]. 
diff --git a/internal/extproc/chatcompletion_processor_test.go b/internal/extproc/chatcompletion_processor_test.go index 596be05117..91dc9d1b4c 100644 --- a/internal/extproc/chatcompletion_processor_test.go +++ b/internal/extproc/chatcompletion_processor_test.go @@ -134,19 +134,6 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) { }) } -func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { - mm := &mockChatCompletionMetrics{} - p := &chatCompletionProcessorUpstreamFilter{metrics: mm} - res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{ - Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}}, - }) - require.NoError(t, err) - _, ok := res.Response.(*extprocv3.ProcessingResponse_RequestHeaders) - require.True(t, ok) - require.NotZero(t, mm.requestStart) - mm.RequireRequestNotCompleted(t) -} - func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) { t.Run("error translation", func(t *testing.T) { mm := &mockChatCompletionMetrics{} @@ -317,7 +304,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend(t *testing.T) { require.False(t, p.stream) // On error, stream should be false regardless of the input. 
} -func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) { +func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { const modelKey = "x-ai-gateway-model-key" for _, stream := range []bool{false, true} { t.Run(fmt.Sprintf("stream%v", stream), func(t *testing.T) { @@ -338,13 +325,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) translator: tr, originalRequestBodyRaw: someBody, originalRequestBody: &body, + stream: stream, } - _, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: someBody}) + _, err := p.ProcessRequestHeaders(t.Context(), nil) require.ErrorContains(t, err, "failed to transform request: test error") mm.RequireRequestFailure(t) mm.RequireTokensRecorded(t, 0) mm.RequireSelectedModel(t, "some-model") - require.False(t, p.stream) // On error, stream should be false regardless of the input. }) t.Run("ok", func(t *testing.T) { someBody := bodyFromModel(t, "some-model", stream) @@ -367,12 +354,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) translator: mt, originalRequestBodyRaw: someBody, originalRequestBody: &expBody, + stream: stream, } - resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{}) + resp, err := p.ProcessRequestHeaders(t.Context(), nil) require.NoError(t, err) require.Equal(t, mt, p.translator) require.NotNil(t, resp) - commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestBody).RequestBody.Response + commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response require.Equal(t, headerMut, commonRes.HeaderMutation) require.Equal(t, bodyMut, commonRes.BodyMutation) diff --git a/tests/extproc/envoy.yaml b/tests/extproc/envoy.yaml index 35b58fcbed..2997a18e4f 100644 --- a/tests/extproc/envoy.yaml +++ b/tests/extproc/envoy.yaml @@ -225,7 +225,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - 
request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -270,7 +270,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -315,7 +315,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -360,7 +360,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -416,7 +416,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -478,7 +478,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -540,7 +540,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -610,7 +610,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -661,7 +661,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" 
response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -712,7 +712,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -763,7 +763,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: @@ -814,7 +814,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SKIP" response_body_mode: "NONE" grpc_service: diff --git a/tests/internal/testupstreamlib/testupstream/main.go b/tests/internal/testupstreamlib/testupstream/main.go index c05aaf5afc..9af0c81c13 100644 --- a/tests/internal/testupstreamlib/testupstream/main.go +++ b/tests/internal/testupstreamlib/testupstream/main.go @@ -15,7 +15,6 @@ import ( "net" "net/http" "os" - "slices" "strconv" "time" @@ -162,19 +161,6 @@ func handler(w http.ResponseWriter, r *http.Request) { return } - // At least for the endpoints we want to support, all requests should have a Content-Length header - // and should not use chunked transfer encoding. 
- if r.Header.Get("Content-Length") == "" { - logger.Println("no Content-Length header, using request body length:", len(requestBody)) - http.Error(w, "no Content-Length header, using request body length: "+strconv.Itoa(len(requestBody)), http.StatusBadRequest) - return - } - if slices.Contains(r.TransferEncoding, "chunked") { - logger.Println("chunked transfer encoding detected") - http.Error(w, "chunked transfer encoding is not supported", http.StatusBadRequest) - return - } - if expectedReqBody := r.Header.Get(testupstreamlib.ExpectedRequestBodyHeaderKey); expectedReqBody != "" { var expectedBody []byte expectedBody, err = base64.StdEncoding.DecodeString(expectedReqBody) diff --git a/tests/internal/testupstreamlib/testupstream/main_test.go b/tests/internal/testupstreamlib/testupstream/main_test.go index b05c06c034..8d848a6bbf 100644 --- a/tests/internal/testupstreamlib/testupstream/main_test.go +++ b/tests/internal/testupstreamlib/testupstream/main_test.go @@ -43,7 +43,7 @@ func Test_main(t *testing.T) { t.Run("sse", func(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", strings.NewReader("some-body")) + request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", nil) require.NoError(t, err) request.Header.Set(testupstreamlib.ResponseTypeKey, "sse") request.Header.Set(testupstreamlib.ResponseBodyHeaderKey, @@ -271,7 +271,7 @@ func Test_main(t *testing.T) { t.Run("aws-event-stream", func(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", strings.NewReader("some-body")) + request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", nil) require.NoError(t, err) request.Header.Set(testupstreamlib.ResponseTypeKey, "aws-event-stream") request.Header.Set(testupstreamlib.ResponseBodyHeaderKey,