diff --git a/internal/extensionserver/extensionserver.go b/internal/extensionserver/extensionserver.go index 6437bd9430..aed97d69a9 100644 --- a/internal/extensionserver/extensionserver.go +++ b/internal/extensionserver/extensionserver.go @@ -210,8 +210,10 @@ func (s *Server) maybeModifyCluster(cluster *clusterv3.Cluster) { extProcConfig.AllowModeOverride = true extProcConfig.RequestAttributes = []string{"xds.upstream_host_metadata"} extProcConfig.ProcessingMode = &extprocv3http.ProcessingMode{ - RequestHeaderMode: extprocv3http.ProcessingMode_SEND, - RequestBodyMode: extprocv3http.ProcessingMode_BUFFERED, + RequestHeaderMode: extprocv3http.ProcessingMode_SEND, + // At the upstream filter, the extproc already holds the original request body in its memory, so it can perform the translation + // as well as the upstream authentication at the request headers phase. Hence, there's no need to send the request body to the extproc. + RequestBodyMode: extprocv3http.ProcessingMode_NONE, ResponseHeaderMode: extprocv3http.ProcessingMode_SEND, ResponseBodyMode: extprocv3http.ProcessingMode_BUFFERED, } diff --git a/internal/extproc/chatcompletion_processor.go b/internal/extproc/chatcompletion_processor.go index b5a4422a53..f9b0afd3a5 100644 --- a/internal/extproc/chatcompletion_processor.go +++ b/internal/extproc/chatcompletion_processor.go @@ -159,31 +159,22 @@ func (c *chatCompletionProcessorUpstreamFilter) selectTranslator(out filterapi.V } // ProcessRequestHeaders implements [Processor.ProcessRequestHeaders]. -func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { - // Start tracking metrics for this request. - c.metrics.StartRequest(c.requestHeaders) - - // The request headers have already been at the time the processor was created. 
- return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{ - RequestHeaders: &extprocv3.HeadersResponse{}, - }}, nil -} - -// ProcessRequestBody implements [Processor.ProcessRequestBody]. -func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.Context, _ *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { +// +// At the upstream filter, we already have the original request body at the request headers phase. +// So, we simply do the translation and upstream auth at this stage, and send the resulting mutations back to Envoy +// with the status CONTINUE_AND_REPLACE. This allows Envoy to not send the request body again +// to the extproc. +func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestHeaders(ctx context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) { defer func() { if err != nil { c.metrics.RecordRequestCompletion(ctx, false) } }() - // TODO: We do not use the body from the extproc request since we might have already translated it - // to the upstream format on the previous retry (if any). If it's possible, we should be able to - // configure the extproc filter to "not send the body but execute the ProcessRequestBody" method. - // Currently, there's no way to do this, hence Envoy has to "unnecessarily" send the entire request body - // to the extproc twice. - + // Start tracking metrics for this request. 
+ c.metrics.StartRequest(c.requestHeaders) c.metrics.SetModel(c.requestHeaders[c.config.modelNameHeaderKey]) + headerMutation, bodyMutation, err := c.translator.RequestBody(c.originalRequestBodyRaw, c.originalRequestBody, c.onRetry) if err != nil { return nil, fmt.Errorf("failed to transform request: %w", err) @@ -201,15 +192,21 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(ctx context.C } } - resp := &extprocv3.ProcessingResponse{ - Response: &extprocv3.ProcessingResponse_RequestBody{ - RequestBody: &extprocv3.BodyResponse{ - Response: &extprocv3.CommonResponse{HeaderMutation: headerMutation, BodyMutation: bodyMutation}, + return &extprocv3.ProcessingResponse{ + Response: &extprocv3.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extprocv3.HeadersResponse{ + Response: &extprocv3.CommonResponse{ + HeaderMutation: headerMutation, BodyMutation: bodyMutation, + Status: extprocv3.CommonResponse_CONTINUE_AND_REPLACE, + }, }, }, - } - c.stream = c.originalRequestBody.Stream - return resp, nil + }, nil +} + +// ProcessRequestBody implements [Processor.ProcessRequestBody]. +func (c *chatCompletionProcessorUpstreamFilter) ProcessRequestBody(context.Context, *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) { + panic("BUG: ProcessRequestBody should not be called in the upstream filter") } // ProcessResponseHeaders implements [Processor.ProcessResponseHeaders]. 
@@ -313,6 +310,7 @@ func (c *chatCompletionProcessorUpstreamFilter) SetBackend(ctx context.Context, c.originalRequestBody = rp.originalRequestBody c.originalRequestBodyRaw = rp.originalRequestBodyRaw c.onRetry = rp.upstreamFilterCount > 1 + c.stream = c.originalRequestBody.Stream return } diff --git a/internal/extproc/chatcompletion_processor_test.go b/internal/extproc/chatcompletion_processor_test.go index 3cac813e44..ea3acd0914 100644 --- a/internal/extproc/chatcompletion_processor_test.go +++ b/internal/extproc/chatcompletion_processor_test.go @@ -133,19 +133,6 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) { }) } -func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { - mm := &mockChatCompletionMetrics{} - p := &chatCompletionProcessorUpstreamFilter{metrics: mm} - res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{ - Headers: []*corev3.HeaderValue{{Key: "foo", Value: "bar"}}, - }) - require.NoError(t, err) - _, ok := res.Response.(*extprocv3.ProcessingResponse_RequestHeaders) - require.True(t, ok) - require.NotZero(t, mm.requestStart) - mm.RequireRequestNotCompleted(t) -} - func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseHeaders(t *testing.T) { t.Run("error translation", func(t *testing.T) { mm := &mockChatCompletionMetrics{} @@ -307,7 +294,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend(t *testing.T) { require.False(t, p.stream) // On error, stream should be false regardless of the input. 
} -func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) { +func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { const modelKey = "x-ai-gateway-model-key" for _, stream := range []bool{false, true} { t.Run(fmt.Sprintf("stream%v", stream), func(t *testing.T) { @@ -328,13 +315,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) translator: tr, originalRequestBodyRaw: someBody, originalRequestBody: &body, + stream: stream, } - _, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{Body: someBody}) + _, err := p.ProcessRequestHeaders(t.Context(), nil) require.ErrorContains(t, err, "failed to transform request: test error") mm.RequireRequestFailure(t) mm.RequireTokensRecorded(t, 0) mm.RequireSelectedModel(t, "some-model") - require.False(t, p.stream) // On error, stream should be false regardless of the input. }) t.Run("ok", func(t *testing.T) { someBody := bodyFromModel(t, "some-model", stream) @@ -357,12 +344,13 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestBody(t *testing.T) translator: mt, originalRequestBodyRaw: someBody, originalRequestBody: &expBody, + stream: stream, } - resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{}) + resp, err := p.ProcessRequestHeaders(t.Context(), nil) require.NoError(t, err) require.Equal(t, mt, p.translator) require.NotNil(t, resp) - commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestBody).RequestBody.Response + commonRes := resp.Response.(*extprocv3.ProcessingResponse_RequestHeaders).RequestHeaders.Response require.Equal(t, headerMut, commonRes.HeaderMutation) require.Equal(t, bodyMut, commonRes.BodyMutation) diff --git a/tests/extproc/envoy.yaml b/tests/extproc/envoy.yaml index 3147a8780b..2073e6a18a 100644 --- a/tests/extproc/envoy.yaml +++ b/tests/extproc/envoy.yaml @@ -168,7 +168,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - 
request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SEND" response_body_mode: "BUFFERED" grpc_service: @@ -217,7 +217,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SEND" response_body_mode: "BUFFERED" grpc_service: @@ -266,7 +266,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SEND" response_body_mode: "BUFFERED" grpc_service: @@ -326,7 +326,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SEND" response_body_mode: "BUFFERED" grpc_service: @@ -392,7 +392,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SEND" response_body_mode: "BUFFERED" grpc_service: @@ -458,7 +458,7 @@ static_resources: - xds.upstream_host_metadata processing_mode: request_header_mode: "SEND" - request_body_mode: "BUFFERED" + request_body_mode: "NONE" response_header_mode: "SEND" response_body_mode: "BUFFERED" grpc_service: