diff --git a/lambda/handler.go b/lambda/handler.go index c455656d..3a644dc2 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -12,6 +12,7 @@ import ( "io/ioutil" // nolint:staticcheck "reflect" "strings" + "sync" "github.com/aws/aws-lambda-go/lambda/handlertrace" ) @@ -31,6 +32,7 @@ type handlerOptions struct { jsonResponseIndentValue string enableSIGTERM bool sigtermCallbacks []func() + jsonOutBufferPool *sync.Pool // contains *jsonOutBuffer } type Option func(*handlerOptions) @@ -227,12 +229,17 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions { if h, ok := handlerFunc.(*handlerOptions); ok { return h } + pool := &sync.Pool{} + pool.New = func() interface{} { + return &jsonOutBuffer{pool, bytes.NewBuffer(nil)} + } h := &handlerOptions{ baseContext: context.Background(), contextValues: map[interface{}]interface{}{}, jsonResponseEscapeHTML: false, jsonResponseIndentPrefix: "", jsonResponseIndentValue: "", + jsonOutBufferPool: pool, } for _, option := range options { option(h) @@ -280,6 +287,7 @@ func errorHandler(err error) handlerFunc { } type jsonOutBuffer struct { + pool *sync.Pool *bytes.Buffer } @@ -287,6 +295,12 @@ func (j *jsonOutBuffer) ContentType() string { return contentTypeJSON } +func (j *jsonOutBuffer) Close() error { + j.Reset() + j.pool.Put(j) + return nil +} + func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { if f == nil { return errorHandler(errors.New("handler is nil")) @@ -318,9 +332,7 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { return errorHandler(err) } - out := &jsonOutBuffer{bytes.NewBuffer(nil)} - return func(ctx context.Context, payload []byte) (io.Reader, error) { - out.Reset() + return func(ctx context.Context, payload []byte) (outFinal io.Reader, _ error) { in := bytes.NewBuffer(payload) decoder := json.NewDecoder(in) if h.jsonRequestUseNumber { @@ -329,6 +341,15 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { if h.jsonRequestDisallowUnknownFields { decoder.DisallowUnknownFields() } + + out := h.jsonOutBufferPool.Get().(*jsonOutBuffer) + defer func() { + // If the final return value is not our buffer, reset and return it to the pool. + // The caller of the handlerFunc does this otherwise. + if outFinal != out { + out.Close() + } + }() encoder := json.NewEncoder(out) encoder.SetEscapeHTML(h.jsonResponseEscapeHTML) encoder.SetIndent(h.jsonResponseIndentPrefix, h.jsonResponseIndentValue) diff --git a/lambda/invoke_loop.go b/lambda/invoke_loop.go index 338237ea..89cb1f4c 100644 --- a/lambda/invoke_loop.go +++ b/lambda/invoke_loop.go @@ -27,16 +27,13 @@ func unixMS(ms int64) time.Time { return time.Unix(ms/msPerS, (ms%msPerS)*nsPerMS) } -// startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error -func startRuntimeAPILoop(api string, handler Handler) error { - client := newRuntimeAPIClient(api) - h := newHandler(handler) +func doRuntimeAPILoop(ctx context.Context, client *runtimeAPIClient, handler *handlerOptions) error { for { - invoke, err := client.next() + invoke, err := client.next(ctx) if err != nil { return err } - if err = handleInvoke(invoke, h); err != nil { + if err := handleInvoke(invoke, handler); err != nil { return err } } @@ -72,7 +69,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error { ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID) // call the handler, marshal any returned error - response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc) + response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload.Bytes(), handler.handlerFunc) if invokeErr != nil { if err := reportFailure(invoke, invokeErr); err != nil { return err diff --git a/lambda/invoke_loop_gte_go122.go b/lambda/invoke_loop_gte_go122.go new file mode 100644 index 00000000..b3254143 --- /dev/null +++ b/lambda/invoke_loop_gte_go122.go @@ -0,0 +1,41 @@ +//go:build go1.22 +// +build go1.22 + +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved + +package lambda + +import ( + "context" + "errors" + "sync" + + "github.com/aws/aws-lambda-go/lambdacontext" +) + +func startRuntimeAPILoop(api string, handler Handler) error { + return startRuntimeAPILoopWithConcurrency(api, handler, lambdacontext.MaxConcurrency()) +} + +func startRuntimeAPILoopWithConcurrency(api string, handler Handler, concurrency int) error { + h := newHandler(handler) + client := newRuntimeAPIClient(api) + if concurrency <= 1 { + return doRuntimeAPILoop(context.Background(), client, h) + } + + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("no handlers run")) + + wg := &sync.WaitGroup{} + wg.Add(concurrency) + for range concurrency { + go func() { + cancel(doRuntimeAPILoop(ctx, client, h)) + wg.Done() + }() + } + wg.Wait() + + return context.Cause(ctx) +} diff --git a/lambda/invoke_loop_gte_go122_test.go b/lambda/invoke_loop_gte_go122_test.go new file mode 100644 index 00000000..cf3a906a --- /dev/null +++ b/lambda/invoke_loop_gte_go122_test.go @@ -0,0 +1,241 @@ +//go:build go1.22 +// +build go1.22 + +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved + +package lambda + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "math" + "math/rand" + "net" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-lambda-go/lambdacontext" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRuntimeAPILoopWithConcurrency(t *testing.T) { + nInvokes := 100 + concurrency := 5 + + metadata := make([]eventMetadata, nInvokes) + for i := range nInvokes { + m := defaultInvokeMetadata() + m.requestID = fmt.Sprintf("request-%d", i) + metadata[i] = m + } + + ts, record := runtimeAPIServer(``, nInvokes, metadata...) + defer ts.Close() + + active := atomic.Int32{} + maxActive := atomic.Int32{} + handler := NewHandler(func(ctx context.Context) (string, error) { + activeNow := active.Add(1) + defer active.Add(-1) + for pr := maxActive.Load(); activeNow > pr; pr = maxActive.Load() { + if maxActive.CompareAndSwap(pr, activeNow) { + break + } + } + lc, _ := lambdacontext.FromContext(ctx) + time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) + switch lc.AwsRequestID[len(lc.AwsRequestID)-1:] { + case "6", "7": + return "", fmt.Errorf("error-%s", lc.AwsRequestID) + default: + return lc.AwsRequestID, nil + } + }) + endpoint := strings.Split(ts.URL, "://")[1] + expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) + assert.EqualError(t, startRuntimeAPILoopWithConcurrency(endpoint, handler, concurrency), expectedError) + assert.GreaterOrEqual(t, record.nGets, nInvokes+1) + assert.Equal(t, nInvokes, record.nPosts) + assert.Equal(t, int32(concurrency), maxActive.Load()) + responses := make(map[string]int) + for _, response := range record.responses { + responses[string(response)]++ + } + assert.Len(t, responses, nInvokes) + for response, count := range responses { + assert.Equal(t, 1, count, "response %s seen %d times", response, count) + } + for i := range nInvokes { + switch i % 10 { + case 6, 7: + assert.Contains(t, responses, fmt.Sprintf(`{"errorMessage":"error-request-%d","errorType":"errorString"}`, i)) + default: + assert.Contains(t, responses, fmt.Sprintf(`"request-%d"`, i)) + } + } +} + +func TestRuntimeAPILoopSingleConcurrency(t *testing.T) { + nInvokes := 10 + + ts, record := runtimeAPIServer(``, nInvokes) + defer ts.Close() + + var counter atomic.Int32 + handler := NewHandler(func(ctx context.Context) (string, error) { + counter.Add(1) + return "Hello!", nil + }) + endpoint := strings.Split(ts.URL, "://")[1] + expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) + assert.EqualError(t, startRuntimeAPILoopWithConcurrency(endpoint, handler, 1), expectedError) + assert.Equal(t, nInvokes+1, record.nGets) + assert.Equal(t, nInvokes, record.nPosts) + assert.Equal(t, int32(nInvokes), counter.Load()) +} + +func TestRuntimeAPILoopWithConcurrencyPanic(t *testing.T) { + concurrency := 3 + + ts, record := runtimeAPIServer(``, 100) + defer ts.Close() + + var logBuf bytes.Buffer + log.SetOutput(&logBuf) + defer log.SetOutput(os.Stderr) + + var counter atomic.Int32 + handler := NewHandler(func() error { + n := counter.Add(1) + time.Sleep(time.Duration(n) * 10 * time.Millisecond) + panic(fmt.Errorf("panic %d", n)) + }) + endpoint := strings.Split(ts.URL, "://")[1] + err := startRuntimeAPILoopWithConcurrency(endpoint, handler, concurrency) + require.Error(t, err) + assert.Contains(t, err.Error(), "calling the handler function resulted in a panic, the process should exit") + assert.Equal(t, concurrency, record.nGets) + assert.Equal(t, concurrency, record.nPosts) + assert.Equal(t, int32(concurrency), counter.Load()) + assert.Contains(t, string(record.responses[0]), "panic 1") + logs := logBuf.String() + idx1 := strings.Index(logs, "panic 1") + idx2 := strings.Index(logs, "panic 2") + idx3 := strings.Index(logs, "panic 3") + assert.Greater(t, idx1, -1) + assert.Greater(t, idx2, idx1) + assert.Greater(t, idx3, idx2) +} + +func TestConcurrencyWithRIE(t *testing.T) { + containerCmd := "" + if _, err := exec.LookPath("finch"); err == nil { + containerCmd = "finch" + } else if _, err := exec.LookPath("docker"); err == nil { + containerCmd = "docker" + } else { + t.Skip("finch or docker required") + } + + testDir := t.TempDir() + handlerBuild := exec.Command("go", "build", "-o", filepath.Join(testDir, "bootstrap"), "./testdata/sleep.go") + handlerBuild.Env = append(os.Environ(), "GOOS=linux") + require.NoError(t, handlerBuild.Run()) + + nInvokes := 10 + concurrency := 3 + sleepMs := 1000 + batches := int(math.Ceil(float64(nInvokes) / float64(concurrency))) + expectedMaxDuration := time.Duration(float64(batches*sleepMs)*1.1) * time.Millisecond // 10% margin for retries, network overhead, scheduling + + // Find an available port + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cmd := exec.Command(containerCmd, "run", "--rm", + "-v", testDir+":/var/runtime:ro,delegated", + "-p", fmt.Sprintf("%d:8080", port), + "-e", fmt.Sprintf("AWS_LAMBDA_MAX_CONCURRENCY=%d", concurrency), + "public.ecr.aws/lambda/provided:al2023", + "bootstrap") + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + stderr, err := cmd.StderrPipe() + require.NoError(t, err) + + var logBuf strings.Builder + logDone := make(chan struct{}) + go func() { + _, _ = io.Copy(io.MultiWriter(os.Stderr, &logBuf), io.MultiReader(stdout, stderr)) + close(logDone) + + }() + + require.NoError(t, cmd.Start()) + t.Cleanup(func() { _ = cmd.Process.Kill() }) + + time.Sleep(5 * time.Second) // Wait for container to start and pull image if needed + + client := &http.Client{Timeout: 15 * time.Second} + invokeURL := fmt.Sprintf("http://127.0.0.1:%d/2015-03-31/functions/function/invocations", port) + + start := time.Now() + var wg sync.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + for range nInvokes { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + default: + } + time.Sleep(50 * time.Millisecond) + body := strings.NewReader(fmt.Sprintf(`{"sleep_ms":%d}`, sleepMs)) + resp, err := client.Post(invokeURL, "application/json", body) + if err != nil { + continue + } + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + if resp.StatusCode == 400 { + continue + } + return + } + }() + } + wg.Wait() + duration := time.Since(start) + + t.Logf("Completed %d invocations in %v", nInvokes, duration) + + _ = cmd.Process.Kill() + _ = cmd.Wait() + <-logDone + + logs := logBuf.String() + processingCount := strings.Count(logs, "processing") + completedCount := strings.Count(logs, "completed") + + assert.Equal(t, nInvokes, processingCount, "expected %d processing logs", nInvokes) + assert.Equal(t, nInvokes, completedCount, "expected %d completed logs", nInvokes) + assert.Less(t, duration, expectedMaxDuration, "concurrent execution should complete faster than sequential") + +} diff --git a/lambda/invoke_loop_lte_go121.go b/lambda/invoke_loop_lte_go121.go new file mode 100644 index 00000000..f91a4e77 --- /dev/null +++ b/lambda/invoke_loop_lte_go121.go @@ -0,0 +1,14 @@ +//go:build !go1.22 +// +build !go1.22 + +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved + +package lambda + +import ( + "context" +) + +func startRuntimeAPILoop(api string, handler Handler) error { + return doRuntimeAPILoop(context.Background(), newRuntimeAPIClient(api), newHandler(handler)) +} diff --git a/lambda/invoke_loop_test.go b/lambda/invoke_loop_test.go index 3374dc2a..1f4dbd18 100644 --- a/lambda/invoke_loop_test.go +++ b/lambda/invoke_loop_test.go @@ -13,6 +13,7 @@ import ( "net/http/httptest" "os" "strings" + "sync" "testing" "unicode/utf8" @@ -371,6 +372,7 @@ func TestSafeMarshal_SerializationError(t *testing.T) { } type requestRecord struct { + lock sync.Mutex nGets int nPosts int responses [][]byte @@ -410,20 +412,27 @@ func defaultInvokeMetadata() eventMetadata { func runtimeAPIServer(eventPayload string, failAfter int, overrides ...eventMetadata) (*httptest.Server, *requestRecord) { numInvokesRequested := 0 + numInvokesRequestedLock := sync.Mutex{} record := &requestRecord{} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: + numInvokesRequestedLock.Lock() metadata := defaultInvokeMetadata() if numInvokesRequested < len(overrides) { metadata = overrides[numInvokesRequested] } - record.nGets++ numInvokesRequested++ - if numInvokesRequested > failAfter { + shouldFail := numInvokesRequested > failAfter + numInvokesRequestedLock.Unlock() + record.lock.Lock() + record.nGets++ + record.lock.Unlock() + if shouldFail { w.WriteHeader(http.StatusGone) _, _ = w.Write([]byte("END THE TEST!")) + return } w.Header().Add(string(headerAWSRequestID), metadata.requestID) w.Header().Add(string(headerDeadlineMS), metadata.deadline) @@ -434,14 +443,16 @@ func runtimeAPIServer(eventPayload string, failAfter int, overrides ...eventMeta w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(eventPayload)) case http.MethodPost: - record.nPosts++ response := bytes.NewBuffer(nil) _, _ = io.Copy(response, r.Body) _ = r.Body.Close() w.WriteHeader(http.StatusAccepted) + record.lock.Lock() + record.nPosts++ record.responses = append(record.responses, response.Bytes()) record.contentTypes = append(record.contentTypes, r.Header.Get("Content-Type")) record.xrayCauses = append(record.xrayCauses, r.Header.Get(headerXRayErrorCause)) + record.lock.Unlock() default: w.WriteHeader(http.StatusBadRequest) } diff --git a/lambda/runtime_api_client.go b/lambda/runtime_api_client.go index 158bd6b4..1d268cc6 100644 --- a/lambda/runtime_api_client.go +++ b/lambda/runtime_api_client.go @@ -6,6 +6,7 @@ package lambda import ( "bytes" + "context" "encoding/base64" "fmt" "io" @@ -13,6 +14,7 @@ import ( "log" "net/http" "runtime" + "sync" ) const ( @@ -35,7 +37,7 @@ type runtimeAPIClient struct { baseURL string userAgent string httpClient *http.Client - buffer *bytes.Buffer + pool *sync.Pool } func newRuntimeAPIClient(address string) *runtimeAPIClient { @@ -44,12 +46,17 @@ func newRuntimeAPIClient(address string) *runtimeAPIClient { } endpoint := "http://" + address + "/" + apiVersion + "/runtime/invocation/" userAgent := "aws-lambda-go/" + runtime.Version() - return &runtimeAPIClient{endpoint, userAgent, client, bytes.NewBuffer(nil)} + pool := &sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(nil) + }, + } + return &runtimeAPIClient{endpoint, userAgent, client, pool} } type invoke struct { id string - payload []byte + payload *bytes.Buffer headers http.Header client *runtimeAPIClient } @@ -58,6 +65,9 @@ type invoke struct { // Notes: // - An invoke is not complete until next() is called again! func (i *invoke) success(body io.Reader, contentType string) error { + defer i.client.pool.Put(i.payload) + defer i.payload.Reset() + url := i.client.baseURL + i.id + "/response" return i.client.post(url, body, contentType, nil) } @@ -68,15 +78,18 @@ func (i *invoke) success(body io.Reader, contentType string) error { // - A Lambda Function continues to be re-used for future invokes even after a failure. // If the error is fatal (panic, unrecoverable state), exit the process immediately after calling failure() func (i *invoke) failure(body io.Reader, contentType string, causeForXRay []byte) error { + defer i.client.pool.Put(i.payload) + defer i.payload.Reset() + url := i.client.baseURL + i.id + "/error" return i.client.post(url, body, contentType, causeForXRay) } // next connects to the Runtime API and waits for a new invoke Request to be available. // Note: After a call to Done() or Error() has been made, a call to next() will complete the in-flight invoke. -func (c *runtimeAPIClient) next() (*invoke, error) { +func (c *runtimeAPIClient) next(ctx context.Context) (*invoke, error) { url := c.baseURL + "next" - req, err := http.NewRequest(http.MethodGet, url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to construct GET request to %s: %v", url, err) } @@ -96,15 +109,17 @@ func (c *runtimeAPIClient) next() (*invoke, error) { return nil, fmt.Errorf("failed to GET %s: got unexpected status code: %d", url, resp.StatusCode) } - c.buffer.Reset() - _, err = c.buffer.ReadFrom(resp.Body) + payload := c.pool.Get().(*bytes.Buffer) + _, err = payload.ReadFrom(resp.Body) if err != nil { + payload.Reset() + c.pool.Put(payload) return nil, fmt.Errorf("failed to read the invoke payload: %v", err) } return &invoke{ id: resp.Header.Get(headerAWSRequestID), - payload: c.buffer.Bytes(), + payload: payload, headers: resp.Header, client: c, }, nil diff --git a/lambda/runtime_api_client_test.go b/lambda/runtime_api_client_test.go index 3f41403f..96b4b7e3 100644 --- a/lambda/runtime_api_client_test.go +++ b/lambda/runtime_api_client_test.go @@ -4,6 +4,7 @@ package lambda import ( "bytes" + "context" "fmt" "io/ioutil" //nolint: staticcheck "net/http" @@ -11,6 +12,8 @@ import ( "strings" "testing" + "io" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -38,17 +41,24 @@ func TestClientNext(t *testing.T) { defer returnsNoBody.Close() t.Run("handles regular response", func(t *testing.T) { - invoke, err := newRuntimeAPIClient(serverAddress(returnsBody)).next() + invoke, err := newRuntimeAPIClient(serverAddress(returnsBody)).next(context.Background()) require.NoError(t, err) assert.Equal(t, dummyRequestID, invoke.id) - assert.Equal(t, dummyPayload, string(invoke.payload)) + assert.Equal(t, dummyPayload, invoke.payload.String()) }) t.Run("handles no body", func(t *testing.T) { - invoke, err := newRuntimeAPIClient(serverAddress(returnsNoBody)).next() + invoke, err := newRuntimeAPIClient(serverAddress(returnsNoBody)).next(context.Background()) require.NoError(t, err) assert.Equal(t, dummyRequestID, invoke.id) - assert.Equal(t, 0, len(invoke.payload)) + assert.Equal(t, 0, len(invoke.payload.Bytes())) + }) + + t.Run("error on context canceled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := newRuntimeAPIClient(serverAddress(returnsNoBody)).next(ctx) + require.Error(t, err) }) } @@ -84,8 +94,9 @@ func TestClientDoneAndError(t *testing.T) { expectedPayloadsRecived := [][]byte{{}, {}, []byte("hello")} // nil payload expected to be read as empty bytes by the server for i, payload := range inputPayloads { invoke := &invoke{ - id: invokeID, - client: client, + id: invokeID, + client: client, + payload: bytes.NewBuffer(nil), } t.Run(fmt.Sprintf("happy Done with payload[%d]", i), func(t *testing.T) { err := invoke.success(bytes.NewReader(payload), contentTypeJSON) @@ -101,11 +112,11 @@ func TestClientDoneAndError(t *testing.T) { } func TestInvalidRequestsForMalformedEndpoint(t *testing.T) { - _, err := newRuntimeAPIClient("🚨").next() + _, err := newRuntimeAPIClient("🚨").next(context.Background()) require.Error(t, err) - err = (&invoke{client: newRuntimeAPIClient("🚨")}).success(nil, "") + err = (&invoke{client: newRuntimeAPIClient("🚨"), payload: bytes.NewBuffer(nil)}).success(nil, "") require.Error(t, err) - err = (&invoke{client: newRuntimeAPIClient("🚨")}).failure(nil, "", nil) + err = (&invoke{client: newRuntimeAPIClient("🚨"), payload: bytes.NewBuffer(nil)}).failure(nil, "", nil) require.Error(t, err) } @@ -115,22 +126,22 @@ func TestStatusCodes(t *testing.T) { url := fmt.Sprintf("status-%d", i) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = ioutil.ReadAll(r.Body) + _, _ = ioutil.ReadAll(io.Reader(r.Body)) w.WriteHeader(i) })) defer ts.Close() client := newRuntimeAPIClient(serverAddress(ts)) - invoke := &invoke{id: url, client: client} + invoke := &invoke{id: url, client: client, payload: bytes.NewBuffer(nil)} if i == http.StatusOK { t.Run("next should not error", func(t *testing.T) { - _, err := client.next() + _, err := client.next(context.Background()) require.NoError(t, err) }) } else { t.Run("next should error", func(t *testing.T) { - _, err := client.next() + _, err := client.next(context.Background()) require.Error(t, err) if i != 301 && i != 302 && i != 303 { assert.Contains(t, err.Error(), "unexpected status code") diff --git a/lambda/testdata/sleep.go b/lambda/testdata/sleep.go new file mode 100644 index 00000000..6a74ab50 --- /dev/null +++ b/lambda/testdata/sleep.go @@ -0,0 +1,29 @@ +package main + +import ( + "context" + "log/slog" + "time" + + "github.com/aws/aws-lambda-go/lambda" + "github.com/aws/aws-lambda-go/lambdacontext" +) + +type Event struct { + SleepMilliseconds int `json:"sleep_ms"` +} + +func handler(ctx context.Context, event Event) (string, error) { + lc, _ := lambdacontext.FromContext(ctx) + logger := slog.Default().With("handler", "sleep-test") + + logger.Info("processing", "request_id", lc.AwsRequestID, "sleep_ms", event.SleepMilliseconds) + time.Sleep(time.Duration(event.SleepMilliseconds) * time.Millisecond) + logger.Info("completed", "request_id", lc.AwsRequestID) + + return "ok", nil +} + +func main() { + lambda.Start(handler) +} diff --git a/lambdacontext/context.go b/lambdacontext/context.go index bd2e1664..658d870c 100644 --- a/lambdacontext/context.go +++ b/lambdacontext/context.go @@ -27,6 +27,8 @@ var MemoryLimitInMB int // FunctionVersion is the published version of the current instance of the Lambda Function var FunctionVersion string +var maxConcurrency int + func init() { LogGroupName = os.Getenv("AWS_LAMBDA_LOG_GROUP_NAME") LogStreamName = os.Getenv("AWS_LAMBDA_LOG_STREAM_NAME") @@ -37,6 +39,15 @@ func init() { MemoryLimitInMB = limit } FunctionVersion = os.Getenv("AWS_LAMBDA_FUNCTION_VERSION") + if v, err := strconv.Atoi(os.Getenv("AWS_LAMBDA_MAX_CONCURRENCY")); err != nil || v < 1 { + maxConcurrency = 1 + } else { + maxConcurrency = v + } +} + +func MaxConcurrency() int { + return maxConcurrency } // ClientApplication is metadata about the calling application. diff --git a/lambdaurl/http_handler_test.go b/lambdaurl/http_handler_test.go index 419cbd7c..d4d990b7 100644 --- a/lambdaurl/http_handler_test.go +++ b/lambdaurl/http_handler_test.go @@ -10,7 +10,7 @@ import ( _ "embed" "encoding/json" "io" - "io/ioutil" + "io/ioutil" //nolint: staticcheck "log" "net/http" "os"