diff --git a/internal/extproc/backendauth/aws.go b/internal/extproc/backendauth/aws.go index c4f37ecd81..c03f40162c 100644 --- a/internal/extproc/backendauth/aws.go +++ b/internal/extproc/backendauth/aws.go @@ -106,6 +106,14 @@ func (a *awsHandler) Do(ctx context.Context, requestHeaders map[string]string, h if err != nil { return fmt.Errorf("cannot create request: %w", err) } + // By setting the content length to -1, we can avoid the inclusion of the `Content-Length` header in the signature. + // https://github.com/aws/aws-sdk-go-v2/blob/755839b2eebb246c7eec79b65404aee105196d5b/aws/signer/v4/v4.go#L427-L431 + // + // We want to avoid this because the ExtProc filter currently removes the content-length header + // from the request. Envoy will instead use "transfer-encoding: chunked" for the request body, + // which should be acceptable for AWS Bedrock or any modern HTTP service. + // https://github.com/envoyproxy/envoy/blob/60b2b5187cf99db79ecfc54675354997af4765ea/source/extensions/filters/http/ext_proc/processor_state.cc#L180-L183 + req.ContentLength = -1 err = a.signer.SignHTTP(ctx, a.credentials, req, hex.EncodeToString(payloadHash[:]), "bedrock", a.region, time.Now()) diff --git a/tests/extproc/real_providers_test.go b/tests/extproc/real_providers_test.go index f061c6bdb8..f3b3572734 100644 --- a/tests/extproc/real_providers_test.go +++ b/tests/extproc/real_providers_test.go @@ -85,7 +85,6 @@ func TestWithRealProviders(t *testing.T) { requireExtProc(t, os.Stdout, extProcExecutablePath(), configPath) t.Run("health-checking", func(t *testing.T) { - client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/")) for _, tc := range []realProvidersTestCase{ {name: "openai", modelName: "gpt-4o-mini", required: internaltesting.RequiredCredentialOpenAI}, {name: "aws-bedrock", modelName: "us.meta.llama3-2-1b-instruct-v1:0", required: internaltesting.RequiredCredentialAWS}, @@ -93,26 +92,7 @@ func TestWithRealProviders(t *testing.T) { } { 
t.Run(tc.modelName, func(t *testing.T) { cc.MaybeSkip(t, tc.required) - require.Eventually(t, func() bool { - chatCompletion, err := client.Chat.Completions.New(t.Context(), openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("Say this is a test"), - }, - Model: tc.modelName, - }) - if err != nil { - t.Logf("error: %v", err) - return false - } - nonEmptyCompletion := false - for _, choice := range chatCompletion.Choices { - t.Logf("choice: %s", choice.Message.Content) - if choice.Message.Content != "" { - nonEmptyCompletion = true - } - } - return nonEmptyCompletion - }, eventuallyTimeout, eventuallyInterval) + requireEventuallyNonStreamingRequestOK(t, tc.modelName, "Say this is a test") }) } }) @@ -322,6 +302,11 @@ func TestWithRealProviders(t *testing.T) { "o1", }, models) }) + t.Run("aws-bedrock-large-body", func(t *testing.T) { + cc.MaybeSkip(t, internaltesting.RequiredCredentialAWS) + requireEventuallyNonStreamingRequestOK(t, + "us.meta.llama3-2-1b-instruct-v1:0", strings.Repeat("Say this is a test", 10000)) + }) } // realProvidersTestCase is a base test case for the real providers, which is mainly for the centralization of the @@ -331,3 +316,27 @@ type realProvidersTestCase struct { modelName string required internaltesting.RequiredCredential } + +func requireEventuallyNonStreamingRequestOK(t *testing.T, modelName, msg string) { + client := openai.NewClient(option.WithBaseURL(listenerAddress + "/v1/")) + require.Eventually(t, func() bool { + chatCompletion, err := client.Chat.Completions.New(t.Context(), openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(msg), + }, + Model: modelName, + }) + if err != nil { + t.Logf("error: %v", err) + return false + } + nonEmptyCompletion := false + for _, choice := range chatCompletion.Choices { + t.Logf("choice: %s", choice.Message.Content) + if choice.Message.Content != "" { + nonEmptyCompletion = true + } + } + 
return nonEmptyCompletion + }, eventuallyTimeout, eventuallyInterval) +}