Skip to content

Commit

Permalink
Merge pull request #2105 from aws/recursion-detection
Browse files Browse the repository at this point in the history
Add Recursion Detection middleware to all SDK requests
  • Loading branch information
wty-Bryant authored Apr 24, 2023
2 parents 7399331 + 75061a4 commit 44acf0c
Show file tree
Hide file tree
Showing 13,519 changed files with 40,773 additions and 0 deletions.
The diff you're trying to view is too large. We only load the first 3000 changed files.
8 changes: 8 additions & 0 deletions .changelog/d74f8a813ddb431fb6006abefbdaba1b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "d74f8a81-3ddb-431f-b600-6abefbdaba1b",
"type": "feature",
"description": "add recursion detection middleware to all SDK requests to avoid recursion invocation in Lambda",
"modules": [
"."
]
}
94 changes: 94 additions & 0 deletions aws/middleware/recursion_detection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package middleware

import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"os"
)

const envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME"
const envAmznTraceID = "_X_AMZN_TRACE_ID"
const amznTraceIDHeader = "X-Amzn-Trace-Id"

// AddRecursionDetection adds recursionDetection to the middleware stack
func AddRecursionDetection(stack *middleware.Stack) error {
return stack.Build.Add(&RecursionDetection{}, middleware.After)
}

// RecursionDetection detects Lambda environment and sets its X-Ray trace ID to request header if absent
// to avoid recursion invocation in Lambda
type RecursionDetection struct{}

// ID returns the middleware identifier
func (m *RecursionDetection) ID() string {
return "RecursionDetection"
}

// HandleBuild detects Lambda environment and adds its trace ID to request header if absent
func (m *RecursionDetection) HandleBuild(
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}

_, hasLambdaEnv := os.LookupEnv(envAwsLambdaFunctionName)
xAmznTraceID, hasTraceID := os.LookupEnv(envAmznTraceID)
value := req.Header.Get(amznTraceIDHeader)
// only set the X-Amzn-Trace-Id header when it is not set initially, the
// current environment is Lambda and the _X_AMZN_TRACE_ID env variable exists
if value != "" || !hasLambdaEnv || !hasTraceID {
return next.HandleBuild(ctx, in)
}

req.Header.Set(amznTraceIDHeader, percentEncode(xAmznTraceID))
return next.HandleBuild(ctx, in)
}

func percentEncode(s string) string {
upperhex := "0123456789ABCDEF"
hexCount := 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEncode(c) {
hexCount++
}
}

if hexCount == 0 {
return s
}

required := len(s) + 2*hexCount
t := make([]byte, required)
j := 0
for i := 0; i < len(s); i++ {
if c := s[i]; shouldEncode(c) {
t[j] = '%'
t[j+1] = upperhex[c>>4]
t[j+2] = upperhex[c&15]
j += 3
} else {
t[j] = c
j++
}
}
return string(t)
}

func shouldEncode(c byte) bool {
if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
return false
}
switch c {
case '-', '=', ';', ':', '+', '&', '[', ']', '{', '}', '"', '\'', ',':
return false
default:
return true
}
}
87 changes: 87 additions & 0 deletions aws/middleware/recursion_detection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package middleware

import (
"context"
smithymiddleware "github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"os"
"testing"
)

func TestRecursionDetection(t *testing.T) {
cases := map[string]struct {
LambdaFuncName string
TraceID string
HeaderBefore string
HeaderAfter string
}{
"non lambda env and no trace ID header before": {},
"with lambda env but no trace ID env variable, no trace ID header before": {
LambdaFuncName: "some-function1",
},
"with lambda env and trace ID env variable, no trace ID header before": {
LambdaFuncName: "some-function2",
TraceID: "traceID1",
HeaderAfter: "traceID1",
},
"with lambda env and trace ID env variable, has trace ID header before": {
LambdaFuncName: "some-function3",
TraceID: "traceID2",
HeaderBefore: "traceID1",
HeaderAfter: "traceID1",
},
"with lambda env and trace ID (needs encoding) env variable, no trace ID header before": {
LambdaFuncName: "some-function4",
TraceID: "traceID3\n",
HeaderAfter: "traceID3%0A",
},
"with lambda env and trace ID (contains chars must not be encoded) env variable, no trace ID header before": {
LambdaFuncName: "some-function5",
TraceID: "traceID4-=;:+&[]{}\"'",
HeaderAfter: "traceID4-=;:+&[]{}\"'",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
// clear current case's environment variables and restore them at the end of the test func goroutine
restoreEnv := clearEnv()
defer restoreEnv()

setEnvVar(t, envAwsLambdaFunctionName, c.LambdaFuncName)
setEnvVar(t, envAmznTraceID, c.TraceID)

req := smithyhttp.NewStackRequest().(*smithyhttp.Request)
if c.HeaderBefore != "" {
req.Header.Set(amznTraceIDHeader, c.HeaderBefore)
}
var updatedRequest *smithyhttp.Request
m := RecursionDetection{}
_, _, err := m.HandleBuild(context.Background(),
smithymiddleware.BuildInput{Request: req},
smithymiddleware.BuildHandlerFunc(func(ctx context.Context, input smithymiddleware.BuildInput) (
out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) {
updatedRequest = input.Request.(*smithyhttp.Request)
return out, metadata, nil
}),
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.HeaderAfter, updatedRequest.Header.Get(amznTraceIDHeader); e != a {
t.Errorf("expect header value %v found, got %v", e, a)
}
})
}
}

// check if test case has environment variable and set to os if it has
func setEnvVar(t *testing.T, key, value string) {
if value != "" {
err := os.Setenv(key, value)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package software.amazon.smithy.aws.go.codegen.customization;

import software.amazon.smithy.aws.go.codegen.AwsGoDependency;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.utils.ListUtils;

import java.util.List;

/**
* Add middleware during operation builder step, which detects Lambda environment and sets its X-Ray trace ID to
* request header if absent to avoid recursion invocation in Lambda
*/
public class LambdaRecursionDetection implements GoIntegration {
/**
* Gets the sort order of the customization from -128 to 127, with lowest
* executed first.
*
* @return Returns the sort order, defaults to -40.
*/
@Override
public byte getOrder() {
return 126;
}

@Override
public List<RuntimeClientPlugin> getClientPlugins() {
return ListUtils.of(
RuntimeClientPlugin.builder()
.registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(
"AddRecursionDetection", AwsGoDependency.AWS_MIDDLEWARE)
.build())
.build()
)
.build()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ software.amazon.smithy.aws.go.codegen.customization.SQSValidateMessageChecksum
software.amazon.smithy.aws.go.codegen.EndpointDiscoveryGenerator
software.amazon.smithy.aws.go.codegen.customization.S3100Continue
software.amazon.smithy.aws.go.codegen.customization.ApiGatewayExportsNullabilityExceptionIntegration
software.amazon.smithy.aws.go.codegen.customization.LambdaRecursionDetection

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions internal/protocoltest/awsrestjson/api_op_DatetimeOffsets.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions internal/protocoltest/awsrestjson/api_op_DocumentType.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions internal/protocoltest/awsrestjson/api_op_EndpointOperation.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions internal/protocoltest/awsrestjson/api_op_HttpEnumPayload.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions internal/protocoltest/awsrestjson/api_op_HttpPayloadTraits.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 44acf0c

Please sign in to comment.