|
| 1 | +//go:build go1.22 |
| 2 | +// +build go1.22 |
| 3 | + |
| 4 | +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved |
| 5 | + |
| 6 | +package lambda |
| 7 | + |
| 8 | +import ( |
| 9 | + "bytes" |
| 10 | + "context" |
| 11 | + "fmt" |
| 12 | + "io" |
| 13 | + "log" |
| 14 | + "math" |
| 15 | + "math/rand" |
| 16 | + "net" |
| 17 | + "net/http" |
| 18 | + "os" |
| 19 | + "os/exec" |
| 20 | + "path/filepath" |
| 21 | + "strings" |
| 22 | + "sync" |
| 23 | + "sync/atomic" |
| 24 | + "testing" |
| 25 | + "time" |
| 26 | + |
| 27 | + "github.com/aws/aws-lambda-go/lambdacontext" |
| 28 | + "github.com/stretchr/testify/assert" |
| 29 | + "github.com/stretchr/testify/require" |
| 30 | +) |
| 31 | + |
| 32 | +func TestRuntimeAPILoopWithConcurrency(t *testing.T) { |
| 33 | + nInvokes := 100 |
| 34 | + concurrency := 5 |
| 35 | + |
| 36 | + metadata := make([]eventMetadata, nInvokes) |
| 37 | + for i := range nInvokes { |
| 38 | + m := defaultInvokeMetadata() |
| 39 | + m.requestID = fmt.Sprintf("request-%d", i) |
| 40 | + metadata[i] = m |
| 41 | + } |
| 42 | + |
| 43 | + ts, record := runtimeAPIServer(``, nInvokes, metadata...) |
| 44 | + defer ts.Close() |
| 45 | + |
| 46 | + active := atomic.Int32{} |
| 47 | + maxActive := atomic.Int32{} |
| 48 | + handler := NewHandler(func(ctx context.Context) (string, error) { |
| 49 | + activeNow := active.Add(1) |
| 50 | + defer active.Add(-1) |
| 51 | + for pr := maxActive.Load(); activeNow > pr; pr = maxActive.Load() { |
| 52 | + if maxActive.CompareAndSwap(pr, activeNow) { |
| 53 | + break |
| 54 | + } |
| 55 | + } |
| 56 | + lc, _ := lambdacontext.FromContext(ctx) |
| 57 | + time.Sleep(time.Duration(rand.Intn(20)) * time.Millisecond) |
| 58 | + switch lc.AwsRequestID[len(lc.AwsRequestID)-1:] { |
| 59 | + case "6", "7": |
| 60 | + return "", fmt.Errorf("error-%s", lc.AwsRequestID) |
| 61 | + default: |
| 62 | + return lc.AwsRequestID, nil |
| 63 | + } |
| 64 | + }) |
| 65 | + endpoint := strings.Split(ts.URL, "://")[1] |
| 66 | + expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) |
| 67 | + assert.EqualError(t, startRuntimeAPILoopWithConcurrency(endpoint, handler, concurrency), expectedError) |
| 68 | + assert.GreaterOrEqual(t, record.nGets, nInvokes+1) |
| 69 | + assert.Equal(t, nInvokes, record.nPosts) |
| 70 | + assert.Equal(t, int32(concurrency), maxActive.Load()) |
| 71 | + responses := make(map[string]int) |
| 72 | + for _, response := range record.responses { |
| 73 | + responses[string(response)]++ |
| 74 | + } |
| 75 | + assert.Len(t, responses, nInvokes) |
| 76 | + for response, count := range responses { |
| 77 | + assert.Equal(t, 1, count, "response %s seen %d times", response, count) |
| 78 | + } |
| 79 | + for i := range nInvokes { |
| 80 | + switch i % 10 { |
| 81 | + case 6, 7: |
| 82 | + assert.Contains(t, responses, fmt.Sprintf(`{"errorMessage":"error-request-%d","errorType":"errorString"}`, i)) |
| 83 | + default: |
| 84 | + assert.Contains(t, responses, fmt.Sprintf(`"request-%d"`, i)) |
| 85 | + } |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +func TestRuntimeAPILoopSingleConcurrency(t *testing.T) { |
| 90 | + nInvokes := 10 |
| 91 | + |
| 92 | + ts, record := runtimeAPIServer(``, nInvokes) |
| 93 | + defer ts.Close() |
| 94 | + |
| 95 | + var counter atomic.Int32 |
| 96 | + handler := NewHandler(func(ctx context.Context) (string, error) { |
| 97 | + counter.Add(1) |
| 98 | + return "Hello!", nil |
| 99 | + }) |
| 100 | + endpoint := strings.Split(ts.URL, "://")[1] |
| 101 | + expectedError := fmt.Sprintf("failed to GET http://%s/2018-06-01/runtime/invocation/next: got unexpected status code: 410", endpoint) |
| 102 | + assert.EqualError(t, startRuntimeAPILoopWithConcurrency(endpoint, handler, 1), expectedError) |
| 103 | + assert.Equal(t, nInvokes+1, record.nGets) |
| 104 | + assert.Equal(t, nInvokes, record.nPosts) |
| 105 | + assert.Equal(t, int32(nInvokes), counter.Load()) |
| 106 | +} |
| 107 | + |
| 108 | +func TestRuntimeAPILoopWithConcurrencyPanic(t *testing.T) { |
| 109 | + concurrency := 3 |
| 110 | + |
| 111 | + ts, record := runtimeAPIServer(``, 100) |
| 112 | + defer ts.Close() |
| 113 | + |
| 114 | + var logBuf bytes.Buffer |
| 115 | + log.SetOutput(&logBuf) |
| 116 | + defer log.SetOutput(os.Stderr) |
| 117 | + |
| 118 | + var counter atomic.Int32 |
| 119 | + handler := NewHandler(func() error { |
| 120 | + n := counter.Add(1) |
| 121 | + time.Sleep(time.Duration(n) * 10 * time.Millisecond) |
| 122 | + panic(fmt.Errorf("panic %d", n)) |
| 123 | + }) |
| 124 | + endpoint := strings.Split(ts.URL, "://")[1] |
| 125 | + err := startRuntimeAPILoopWithConcurrency(endpoint, handler, concurrency) |
| 126 | + require.Error(t, err) |
| 127 | + assert.Contains(t, err.Error(), "calling the handler function resulted in a panic, the process should exit") |
| 128 | + assert.Equal(t, concurrency, record.nGets) |
| 129 | + assert.Equal(t, concurrency, record.nPosts) |
| 130 | + assert.Equal(t, int32(concurrency), counter.Load()) |
| 131 | + assert.Contains(t, string(record.responses[0]), "panic 1") |
| 132 | + logs := logBuf.String() |
| 133 | + idx1 := strings.Index(logs, "panic 1") |
| 134 | + idx2 := strings.Index(logs, "panic 2") |
| 135 | + idx3 := strings.Index(logs, "panic 3") |
| 136 | + assert.Greater(t, idx1, -1) |
| 137 | + assert.Greater(t, idx2, idx1) |
| 138 | + assert.Greater(t, idx3, idx2) |
| 139 | +} |
| 140 | + |
| 141 | +func TestConcurrencyWithRIE(t *testing.T) { |
| 142 | + containerCmd := "" |
| 143 | + if _, err := exec.LookPath("finch"); err == nil { |
| 144 | + containerCmd = "finch" |
| 145 | + } else if _, err := exec.LookPath("docker"); err == nil { |
| 146 | + containerCmd = "docker" |
| 147 | + } else { |
| 148 | + t.Skip("finch or docker required") |
| 149 | + } |
| 150 | + |
| 151 | + testDir := t.TempDir() |
| 152 | + handlerBuild := exec.Command("go", "build", "-o", filepath.Join(testDir, "bootstrap"), "./testdata/sleep.go") |
| 153 | + handlerBuild.Env = append(os.Environ(), "GOOS=linux") |
| 154 | + require.NoError(t, handlerBuild.Run()) |
| 155 | + |
| 156 | + nInvokes := 10 |
| 157 | + concurrency := 3 |
| 158 | + sleepMs := 1000 |
| 159 | + batches := int(math.Ceil(float64(nInvokes) / float64(concurrency))) |
| 160 | + expectedMaxDuration := time.Duration(float64(batches*sleepMs)*1.1) * time.Millisecond // 10% margin for retries, network overhead, scheduling |
| 161 | + |
| 162 | + // Find an available port |
| 163 | + listener, err := net.Listen("tcp", "127.0.0.1:0") |
| 164 | + require.NoError(t, err) |
| 165 | + port := listener.Addr().(*net.TCPAddr).Port |
| 166 | + listener.Close() |
| 167 | + |
| 168 | + cmd := exec.Command(containerCmd, "run", "--rm", |
| 169 | + "-v", testDir+":/var/runtime:ro,delegated", |
| 170 | + "-p", fmt.Sprintf("%d:8080", port), |
| 171 | + "-e", fmt.Sprintf("AWS_LAMBDA_MAX_CONCURRENCY=%d", concurrency), |
| 172 | + "public.ecr.aws/lambda/provided:al2023", |
| 173 | + "bootstrap") |
| 174 | + stdout, err := cmd.StdoutPipe() |
| 175 | + require.NoError(t, err) |
| 176 | + stderr, err := cmd.StderrPipe() |
| 177 | + require.NoError(t, err) |
| 178 | + |
| 179 | + var logBuf strings.Builder |
| 180 | + logDone := make(chan struct{}) |
| 181 | + go func() { |
| 182 | + _, _ = io.Copy(io.MultiWriter(os.Stderr, &logBuf), io.MultiReader(stdout, stderr)) |
| 183 | + close(logDone) |
| 184 | + |
| 185 | + }() |
| 186 | + |
| 187 | + require.NoError(t, cmd.Start()) |
| 188 | + t.Cleanup(func() { _ = cmd.Process.Kill() }) |
| 189 | + |
| 190 | + time.Sleep(5 * time.Second) // Wait for container to start and pull image if needed |
| 191 | + |
| 192 | + client := &http.Client{Timeout: 15 * time.Second} |
| 193 | + invokeURL := fmt.Sprintf("http://127.0.0.1:%d/2015-03-31/functions/function/invocations", port) |
| 194 | + |
| 195 | + start := time.Now() |
| 196 | + var wg sync.WaitGroup |
| 197 | + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) |
| 198 | + defer cancel() |
| 199 | + for range nInvokes { |
| 200 | + wg.Add(1) |
| 201 | + go func() { |
| 202 | + defer wg.Done() |
| 203 | + for { |
| 204 | + select { |
| 205 | + case <-ctx.Done(): |
| 206 | + return |
| 207 | + default: |
| 208 | + } |
| 209 | + time.Sleep(50 * time.Millisecond) |
| 210 | + body := strings.NewReader(fmt.Sprintf(`{"sleep_ms":%d}`, sleepMs)) |
| 211 | + resp, err := client.Post(invokeURL, "application/json", body) |
| 212 | + if err != nil { |
| 213 | + continue |
| 214 | + } |
| 215 | + _, _ = io.Copy(io.Discard, resp.Body) |
| 216 | + _ = resp.Body.Close() |
| 217 | + if resp.StatusCode == 400 { |
| 218 | + continue |
| 219 | + } |
| 220 | + return |
| 221 | + } |
| 222 | + }() |
| 223 | + } |
| 224 | + wg.Wait() |
| 225 | + duration := time.Since(start) |
| 226 | + |
| 227 | + t.Logf("Completed %d invocations in %v", nInvokes, duration) |
| 228 | + |
| 229 | + _ = cmd.Process.Kill() |
| 230 | + _ = cmd.Wait() |
| 231 | + <-logDone |
| 232 | + |
| 233 | + logs := logBuf.String() |
| 234 | + processingCount := strings.Count(logs, "processing") |
| 235 | + completedCount := strings.Count(logs, "completed") |
| 236 | + |
| 237 | + assert.Equal(t, nInvokes, processingCount, "expected %d processing logs", nInvokes) |
| 238 | + assert.Equal(t, nInvokes, completedCount, "expected %d completed logs", nInvokes) |
| 239 | + assert.Less(t, duration, expectedMaxDuration, "concurrent execution should complete faster than sequential") |
| 240 | + |
| 241 | +} |
0 commit comments