Skip to content

Commit 36c1cc3

Browse files
authored
Add support for AWS_LAMBDA_MAX_CONCURRENCY (#600)
1 parent 867cb12 commit 36c1cc3

File tree

11 files changed

+426
-35
lines changed

11 files changed

+426
-35
lines changed

lambda/handler.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"io/ioutil" // nolint:staticcheck
1313
"reflect"
1414
"strings"
15+
"sync"
1516

1617
"github.com/aws/aws-lambda-go/lambda/handlertrace"
1718
)
@@ -31,6 +32,7 @@ type handlerOptions struct {
3132
jsonResponseIndentValue string
3233
enableSIGTERM bool
3334
sigtermCallbacks []func()
35+
jsonOutBufferPool *sync.Pool // contains *jsonOutBuffer
3436
}
3537

3638
type Option func(*handlerOptions)
@@ -227,12 +229,17 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
227229
if h, ok := handlerFunc.(*handlerOptions); ok {
228230
return h
229231
}
232+
pool := &sync.Pool{}
233+
pool.New = func() interface{} {
234+
return &jsonOutBuffer{pool, bytes.NewBuffer(nil)}
235+
}
230236
h := &handlerOptions{
231237
baseContext: context.Background(),
232238
contextValues: map[interface{}]interface{}{},
233239
jsonResponseEscapeHTML: false,
234240
jsonResponseIndentPrefix: "",
235241
jsonResponseIndentValue: "",
242+
jsonOutBufferPool: pool,
236243
}
237244
for _, option := range options {
238245
option(h)
@@ -280,13 +287,20 @@ func errorHandler(err error) handlerFunc {
280287
}
281288

282289
type jsonOutBuffer struct {
290+
pool *sync.Pool
283291
*bytes.Buffer
284292
}
285293

286294
func (j *jsonOutBuffer) ContentType() string {
287295
return contentTypeJSON
288296
}
289297

298+
func (j *jsonOutBuffer) Close() error {
299+
j.Reset()
300+
j.pool.Put(j)
301+
return nil
302+
}
303+
290304
func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
291305
if f == nil {
292306
return errorHandler(errors.New("handler is nil"))
@@ -318,9 +332,7 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
318332
return errorHandler(err)
319333
}
320334

321-
out := &jsonOutBuffer{bytes.NewBuffer(nil)}
322-
return func(ctx context.Context, payload []byte) (io.Reader, error) {
323-
out.Reset()
335+
return func(ctx context.Context, payload []byte) (outFinal io.Reader, _ error) {
324336
in := bytes.NewBuffer(payload)
325337
decoder := json.NewDecoder(in)
326338
if h.jsonRequestUseNumber {
@@ -329,6 +341,15 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc {
329341
if h.jsonRequestDisallowUnknownFields {
330342
decoder.DisallowUnknownFields()
331343
}
344+
345+
out := h.jsonOutBufferPool.Get().(*jsonOutBuffer)
346+
defer func() {
347+
// If the final return value is not our buffer, reset and return it to the pool.
348+
// The caller of the handlerFunc does this otherwise.
349+
if outFinal != out {
350+
out.Close()
351+
}
352+
}()
332353
encoder := json.NewEncoder(out)
333354
encoder.SetEscapeHTML(h.jsonResponseEscapeHTML)
334355
encoder.SetIndent(h.jsonResponseIndentPrefix, h.jsonResponseIndentValue)

lambda/invoke_loop.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,13 @@ func unixMS(ms int64) time.Time {
2727
return time.Unix(ms/msPerS, (ms%msPerS)*nsPerMS)
2828
}
2929

30-
// startRuntimeAPILoop will return an error if handling a particular invoke resulted in a non-recoverable error
31-
func startRuntimeAPILoop(api string, handler Handler) error {
32-
client := newRuntimeAPIClient(api)
33-
h := newHandler(handler)
30+
func doRuntimeAPILoop(ctx context.Context, client *runtimeAPIClient, handler *handlerOptions) error {
3431
for {
35-
invoke, err := client.next()
32+
invoke, err := client.next(ctx)
3633
if err != nil {
3734
return err
3835
}
39-
if err = handleInvoke(invoke, h); err != nil {
36+
if err := handleInvoke(invoke, handler); err != nil {
4037
return err
4138
}
4239
}
@@ -72,7 +69,7 @@ func handleInvoke(invoke *invoke, handler *handlerOptions) error {
7269
ctx = context.WithValue(ctx, "x-amzn-trace-id", traceID)
7370

7471
// call the handler, marshal any returned error
75-
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload, handler.handlerFunc)
72+
response, invokeErr := callBytesHandlerFunc(ctx, invoke.payload.Bytes(), handler.handlerFunc)
7673
if invokeErr != nil {
7774
if err := reportFailure(invoke, invokeErr); err != nil {
7875
return err

lambda/invoke_loop_gte_go122.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//go:build go1.22
2+
// +build go1.22
3+
4+
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved
5+
6+
package lambda
7+
8+
import (
9+
"context"
10+
"errors"
11+
"sync"
12+
13+
"github.com/aws/aws-lambda-go/lambdacontext"
14+
)
15+
16+
func startRuntimeAPILoop(api string, handler Handler) error {
17+
return startRuntimeAPILoopWithConcurrency(api, handler, lambdacontext.MaxConcurrency())
18+
}
19+
20+
func startRuntimeAPILoopWithConcurrency(api string, handler Handler, concurrency int) error {
21+
h := newHandler(handler)
22+
client := newRuntimeAPIClient(api)
23+
if concurrency <= 1 {
24+
return doRuntimeAPILoop(context.Background(), client, h)
25+
}
26+
27+
ctx, cancel := context.WithCancelCause(context.Background())
28+
defer cancel(errors.New("no handlers run"))
29+
30+
wg := &sync.WaitGroup{}
31+
wg.Add(concurrency)
32+
for range concurrency {
33+
go func() {
34+
cancel(doRuntimeAPILoop(ctx, client, h))
35+
wg.Done()
36+
}()
37+
}
38+
wg.Wait()
39+
40+
return context.Cause(ctx)
41+
}
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
//go:build go1.22
2+
// +build go1.22
3+
4+
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved
5+
6+
package lambda
7+
8+
import (
9+
"bytes"
10+
"context"
11+
"fmt"
12+
"io"
13+
"log"
14+
"math"
15+
"math/rand"
16+
"net"
17+
"net/http"
18+
"os"
19+
"os/exec"
20+
"path/filepath"
21+
"strings"
22+
"sync"
23+
"sync/atomic"
24+
"testing"
25+
"time"
26+
27+
"github.com/aws/aws-lambda-go/lambdacontext"
28+
"github.com/stretchr/testify/assert"
29+
"github.com/stretchr/testify/require"
30+
)
31+
32+
func TestRuntimeAPILoopWithConcurrency(t *testing.T) {
33+
nInvokes := 100
34+
concurrency := 5
35+
36+
metadata := make([]eventMetadata, nInvokes)
37+
for i := range nInvokes {
38+
m := defaultInvokeMetadata()
39+
m.requestID = fmt.Sprintf("request-%d", i)
40+
metadata[i] = m
41+
}
42+
43+
ts, record := runtimeAPIServer(``, nInvokes, metadata...)
44+
defer ts.Close()
45+
46+
active := atomic.Int32{}
47+
maxActive := atomic.Int32{}
48+
handler := NewHandler(func(ctx context.Context) (string, error) {
49+
activeNow := active.Add(1)
50+
defer active.Add(-1)
51+
for pr := maxActive.Load(); activeNow > pr; pr = maxActive.Load() {
52+
if maxActive.CompareAndSwap(pr, activeNow) {
53+
break
54+
}
55+
}
56+
lc, _ := lambdacontext.FromContext(ctx)
57+
time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond)
58+
switch lc.AwsRequestID[len(lc.AwsRequestID)-1:] {
59+
case "6", "7":
60+
return "", fmt.Errorf("error-%s", lc.AwsRequestID)
61+
default:
62+
return lc.AwsRequestID, nil
63+
}
64+
})
65+
endpoint := strings.Split(ts.URL, "://")[1]
66+
expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint)
67+
assert.EqualError(t, startRuntimeAPILoopWithConcurrency(endpoint, handler, concurrency), expectedError)
68+
assert.GreaterOrEqual(t, record.nGets, nInvokes+1)
69+
assert.Equal(t, nInvokes, record.nPosts)
70+
assert.Equal(t, int32(concurrency), maxActive.Load())
71+
responses := make(map[string]int)
72+
for _, response := range record.responses {
73+
responses[string(response)]++
74+
}
75+
assert.Len(t, responses, nInvokes)
76+
for response, count := range responses {
77+
assert.Equal(t, 1, count, "response %s seen %d times", response, count)
78+
}
79+
for i := range nInvokes {
80+
switch i % 10 {
81+
case 6, 7:
82+
assert.Contains(t, responses, fmt.Sprintf(`{"errorMessage":"error-request-%d","errorType":"errorString"}`, i))
83+
default:
84+
assert.Contains(t, responses, fmt.Sprintf(`"request-%d"`, i))
85+
}
86+
}
87+
}
88+
89+
func TestRuntimeAPILoopSingleConcurrency(t *testing.T) {
90+
nInvokes := 10
91+
92+
ts, record := runtimeAPIServer(``, nInvokes)
93+
defer ts.Close()
94+
95+
var counter atomic.Int32
96+
handler := NewHandler(func(ctx context.Context) (string, error) {
97+
counter.Add(1)
98+
return "Hello!", nil
99+
})
100+
endpoint := strings.Split(ts.URL, "://")[1]
101+
expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint)
102+
assert.EqualError(t, startRuntimeAPILoopWithConcurrency(endpoint, handler, 1), expectedError)
103+
assert.Equal(t, nInvokes+1, record.nGets)
104+
assert.Equal(t, nInvokes, record.nPosts)
105+
assert.Equal(t, int32(nInvokes), counter.Load())
106+
}
107+
108+
func TestRuntimeAPILoopWithConcurrencyPanic(t *testing.T) {
109+
concurrency := 3
110+
111+
ts, record := runtimeAPIServer(``, 100)
112+
defer ts.Close()
113+
114+
var logBuf bytes.Buffer
115+
log.SetOutput(&logBuf)
116+
defer log.SetOutput(os.Stderr)
117+
118+
var counter atomic.Int32
119+
handler := NewHandler(func() error {
120+
n := counter.Add(1)
121+
time.Sleep(time.Duration(n) * 10 * time.Millisecond)
122+
panic(fmt.Errorf("panic %d", n))
123+
})
124+
endpoint := strings.Split(ts.URL, "://")[1]
125+
err := startRuntimeAPILoopWithConcurrency(endpoint, handler, concurrency)
126+
require.Error(t, err)
127+
assert.Contains(t, err.Error(), "calling the handler function resulted in a panic, the process should exit")
128+
assert.Equal(t, concurrency, record.nGets)
129+
assert.Equal(t, concurrency, record.nPosts)
130+
assert.Equal(t, int32(concurrency), counter.Load())
131+
assert.Contains(t, string(record.responses[0]), "panic 1")
132+
logs := logBuf.String()
133+
idx1 := strings.Index(logs, "panic 1")
134+
idx2 := strings.Index(logs, "panic 2")
135+
idx3 := strings.Index(logs, "panic 3")
136+
assert.Greater(t, idx1, -1)
137+
assert.Greater(t, idx2, idx1)
138+
assert.Greater(t, idx3, idx2)
139+
}
140+
141+
func TestConcurrencyWithRIE(t *testing.T) {
142+
containerCmd := ""
143+
if _, err := exec.LookPath("finch"); err == nil {
144+
containerCmd = "finch"
145+
} else if _, err := exec.LookPath("docker"); err == nil {
146+
containerCmd = "docker"
147+
} else {
148+
t.Skip("finch or docker required")
149+
}
150+
151+
testDir := t.TempDir()
152+
handlerBuild := exec.Command("go", "build", "-o", filepath.Join(testDir, "bootstrap"), "./testdata/sleep.go")
153+
handlerBuild.Env = append(os.Environ(), "GOOS=linux")
154+
require.NoError(t, handlerBuild.Run())
155+
156+
nInvokes := 10
157+
concurrency := 3
158+
sleepMs := 1000
159+
batches := int(math.Ceil(float64(nInvokes) / float64(concurrency)))
160+
expectedMaxDuration := time.Duration(float64(batches*sleepMs)*1.1) * time.Millisecond // 10% margin for retries, network overhead, scheduling
161+
162+
// Find an available port
163+
listener, err := net.Listen("tcp", "127.0.0.1:0")
164+
require.NoError(t, err)
165+
port := listener.Addr().(*net.TCPAddr).Port
166+
listener.Close()
167+
168+
cmd := exec.Command(containerCmd, "run", "--rm",
169+
"-v", testDir+":/var/runtime:ro,delegated",
170+
"-p", fmt.Sprintf("%d:8080", port),
171+
"-e", fmt.Sprintf("AWS_LAMBDA_MAX_CONCURRENCY=%d", concurrency),
172+
"public.ecr.aws/lambda/provided:al2023",
173+
"bootstrap")
174+
stdout, err := cmd.StdoutPipe()
175+
require.NoError(t, err)
176+
stderr, err := cmd.StderrPipe()
177+
require.NoError(t, err)
178+
179+
var logBuf strings.Builder
180+
logDone := make(chan struct{})
181+
go func() {
182+
_, _ = io.Copy(io.MultiWriter(os.Stderr, &logBuf), io.MultiReader(stdout, stderr))
183+
close(logDone)
184+
185+
}()
186+
187+
require.NoError(t, cmd.Start())
188+
t.Cleanup(func() { _ = cmd.Process.Kill() })
189+
190+
time.Sleep(5 * time.Second) // Wait for container to start and pull image if needed
191+
192+
client := &http.Client{Timeout: 15 * time.Second}
193+
invokeURL := fmt.Sprintf("http://127.0.0.1:%d/2015-03-31/functions/function/invocations", port)
194+
195+
start := time.Now()
196+
var wg sync.WaitGroup
197+
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
198+
defer cancel()
199+
for range nInvokes {
200+
wg.Add(1)
201+
go func() {
202+
defer wg.Done()
203+
for {
204+
select {
205+
case <-ctx.Done():
206+
return
207+
default:
208+
}
209+
time.Sleep(50 * time.Millisecond)
210+
body := strings.NewReader(fmt.Sprintf(`{"sleep_ms":%d}`, sleepMs))
211+
resp, err := client.Post(invokeURL, "application/json", body)
212+
if err != nil {
213+
continue
214+
}
215+
_, _ = io.Copy(io.Discard, resp.Body)
216+
_ = resp.Body.Close()
217+
if resp.StatusCode == 400 {
218+
continue
219+
}
220+
return
221+
}
222+
}()
223+
}
224+
wg.Wait()
225+
duration := time.Since(start)
226+
227+
t.Logf("Completed %d invocations in %v", nInvokes, duration)
228+
229+
_ = cmd.Process.Kill()
230+
_ = cmd.Wait()
231+
<-logDone
232+
233+
logs := logBuf.String()
234+
processingCount := strings.Count(logs, "processing")
235+
completedCount := strings.Count(logs, "completed")
236+
237+
assert.Equal(t, nInvokes, processingCount, "expected %d processing logs", nInvokes)
238+
assert.Equal(t, nInvokes, completedCount, "expected %d completed logs", nInvokes)
239+
assert.Less(t, duration, expectedMaxDuration, "concurrent execution should complete faster than sequential")
240+
241+
}

0 commit comments

Comments
 (0)