From 102b9fb5dbbe5576088c04e50564c6cc7ce3f439 Mon Sep 17 00:00:00 2001 From: Dane H Lim Date: Thu, 7 Sep 2023 11:21:36 -0700 Subject: [PATCH] Move ACS session to ecs-agent module and refactor --- agent/acs/handler/acs_handler.go | 402 ----- agent/acs/handler/acs_handler_test.go | 1541 ----------------- agent/app/agent.go | 82 +- agent/go.mod | 2 +- .../ecs-agent/acs/session/session.go | 394 +++++ ecs-agent/acs/session/session.go | 394 +++++ ecs-agent/acs/session/session_test.go | 1329 ++++++++++++++ 7 files changed, 2181 insertions(+), 1963 deletions(-) delete mode 100644 agent/acs/handler/acs_handler.go delete mode 100644 agent/acs/handler/acs_handler_test.go create mode 100644 ecs-agent/acs/session/session_test.go diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go deleted file mode 100644 index c8ecd7c3e20..00000000000 --- a/agent/acs/handler/acs_handler.go +++ /dev/null @@ -1,402 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -// Package handler deals with appropriately reacting to all ACS messages as well -// as maintaining the connection to ACS. 
-package handler - -import ( - "context" - "io" - "net/url" - "strconv" - "strings" - "time" - - "github.com/aws/amazon-ecs-agent/agent/api" - "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/agent/data" - "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi" - "github.com/aws/amazon-ecs-agent/agent/engine" - "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" - "github.com/aws/amazon-ecs-agent/agent/eventhandler" - "github.com/aws/amazon-ecs-agent/agent/version" - acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" - rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" - "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" - "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" - "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" - "github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime" - "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" - - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/cihub/seelog" -) - -const ( - // heartbeatTimeout is the maximum time to wait between heartbeats - // without disconnecting - heartbeatTimeout = 1 * time.Minute - heartbeatJitter = 1 * time.Minute - // wsRWTimeout is the duration of read and write deadline for the - // websocket connection - wsRWTimeout = 2*heartbeatTimeout + heartbeatJitter - - inactiveInstanceReconnectDelay = 1 * time.Hour - - connectionBackoffMin = 250 * time.Millisecond - connectionBackoffMax = 2 * time.Minute - connectionBackoffJitter = 0.2 - connectionBackoffMultiplier = 1.5 - // sendCredentialsURLParameterName is the name of the URL parameter - // in the ACS URL that is used to indicate if ACS should send - // credentials for all tasks on establishing the connection - sendCredentialsURLParameterName = "sendCredentials" - inactiveInstanceExceptionPrefix = "InactiveInstanceException:" - // ACS protocol version spec: - // 1: default protocol version - // 2: 
ACS will proactively close the connection when heartbeat acks are missing - acsProtocolVersion = 2 -) - -// Session defines an interface for handler's long-lived connection with ACS. -type Session interface { - Start() error -} - -// session encapsulates all arguments needed by the handler to connect to ACS -// and to handle messages received by ACS. The Session.Start() method can be used -// to start processing messages from ACS. -type session struct { - containerInstanceARN string - credentialsProvider *credentials.Credentials - agentConfig *config.Config - deregisterInstanceEventStream *eventstream.EventStream - taskEngine engine.TaskEngine - dockerClient dockerapi.DockerClient - ecsClient api.ECSClient - state dockerstate.TaskEngineState - dataClient data.Client - credentialsManager rolecredentials.Manager - taskHandler *eventhandler.TaskHandler - ctx context.Context - cancel context.CancelFunc - backoff retry.Backoff - metricsFactory metrics.EntryFactory - clientFactory wsclient.ClientFactory - sendCredentials bool - latestSeqNumTaskManifest *int64 - doctor *doctor.Doctor - addUpdateRequestHandlers func(wsclient.ClientServer) - _heartbeatTimeout time.Duration - _heartbeatJitter time.Duration - connectionTime time.Duration - connectionJitter time.Duration - _inactiveInstanceReconnectDelay time.Duration -} - -// NewSession creates a new Session object -func NewSession( - ctx context.Context, - config *config.Config, - deregisterInstanceEventStream *eventstream.EventStream, - containerInstanceARN string, - credentialsProvider *credentials.Credentials, - dockerClient dockerapi.DockerClient, - ecsClient api.ECSClient, - taskEngineState dockerstate.TaskEngineState, - dataClient data.Client, - taskEngine engine.TaskEngine, - credentialsManager rolecredentials.Manager, - taskHandler *eventhandler.TaskHandler, - latestSeqNumTaskManifest *int64, - doctor *doctor.Doctor, - clientFactory wsclient.ClientFactory, - addUpdateRequestHandlers func(wsclient.ClientServer), - 
metricsFactory metrics.EntryFactory, -) Session { - backoff := retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, - connectionBackoffJitter, connectionBackoffMultiplier) - derivedContext, cancel := context.WithCancel(ctx) - - return &session{ - agentConfig: config, - deregisterInstanceEventStream: deregisterInstanceEventStream, - containerInstanceARN: containerInstanceARN, - credentialsProvider: credentialsProvider, - ecsClient: ecsClient, - dockerClient: dockerClient, - state: taskEngineState, - dataClient: dataClient, - taskEngine: taskEngine, - credentialsManager: credentialsManager, - taskHandler: taskHandler, - ctx: derivedContext, - cancel: cancel, - backoff: backoff, - latestSeqNumTaskManifest: latestSeqNumTaskManifest, - doctor: doctor, - metricsFactory: metricsFactory, - clientFactory: clientFactory, - addUpdateRequestHandlers: addUpdateRequestHandlers, - sendCredentials: true, - _heartbeatTimeout: heartbeatTimeout, - _heartbeatJitter: heartbeatJitter, - connectionTime: wsclient.DisconnectTimeout, - connectionJitter: wsclient.DisconnectJitterMax, - _inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, - } -} - -// Start starts the session. It'll forever keep trying to connect to ACS unless -// the context is cancelled. 
-// -// Returns nil always TODO: consider removing error return value completely -func (acsSession *session) Start() error { - // Loop continuously until context is closed/cancelled - for { - seelog.Debugf("Attempting connect to ACS") - // Start a session with ACS - acsError := acsSession.startSessionOnce() - - // If the session is over check for shutdown first - if err := acsSession.ctx.Err(); err != nil { - return nil - } - - // If ACS closed the connection, reconnect immediately - if shouldReconnectWithoutBackoff(acsError) { - seelog.Infof("ACS Websocket connection closed for a valid reason: %v", acsError) - acsSession.backoff.Reset() - continue - } - - // Session with ACS was stopped with some error, start processing the error - isInactiveInstance := isInactiveInstanceError(acsError) - if isInactiveInstance { - // If the instance was deregistered, send an event to the event stream - // for the same - seelog.Debug("Container instance is deregistered, notifying listeners") - err := acsSession.deregisterInstanceEventStream.WriteToEventStream(struct{}{}) - if err != nil { - seelog.Debugf("Failed to write to deregister container instance event stream, err: %v", err) - } - } - - // Disconnected unexpectedly from ACS, compute backoff duration to - // reconnect - reconnectDelay := acsSession.computeReconnectDelay(isInactiveInstance) - seelog.Infof("Reconnecting to ACS in: %s", reconnectDelay.String()) - waitComplete := acsSession.waitForDuration(reconnectDelay) - if !waitComplete { - // Wait was interrupted. We expect the session to close as canceling - // the session context is the only way to end up here. 
Print a message - // to indicate the same - seelog.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close") - return nil - } - - // If the context was not cancelled and we've waited for the - // wait duration without any errors, reconnect to ACS - seelog.Info("Done waiting; reconnecting to ACS") - } -} - -// startSessionOnce creates a session with ACS and handles requests using the passed -// in arguments -func (acsSession *session) startSessionOnce() error { - minAgentCfg := &wsclient.WSClientMinAgentConfig{ - AcceptInsecureCert: acsSession.agentConfig.AcceptInsecureCert, - AWSRegion: acsSession.agentConfig.AWSRegion, - } - - acsEndpoint, err := acsSession.ecsClient.DiscoverPollEndpoint(acsSession.containerInstanceARN) - if err != nil { - seelog.Errorf("acs: unable to discover poll endpoint, err: %v", err) - return err - } - - url := acsSession.acsURL(acsEndpoint) - client := acsSession.clientFactory.New( - url, - acsSession.credentialsProvider, - wsRWTimeout, - minAgentCfg, - acsSession.metricsFactory) - defer client.Close() - - return acsSession.startACSSession(client) -} - -// startACSSession starts a session with ACS. It adds request handlers for various -// kinds of messages expected from ACS. 
It returns on server disconnection or when -// the context is cancelled -func (acsSession *session) startACSSession(client wsclient.ClientServer) error { - payloadMsgHandler := NewPayloadMessageHandler(acsSession.taskEngine, acsSession.ecsClient, acsSession.dataClient, - acsSession.taskHandler, acsSession.credentialsManager, acsSession.latestSeqNumTaskManifest) - - credsMetadataSetter := NewCredentialsMetadataSetter(acsSession.taskEngine) - - eniHandler := NewENIHandler(acsSession.state, acsSession.dataClient) - - manifestMessageIDAccessor := NewManifestMessageIDAccessor() - - sequenceNumberAccessor := NewSequenceNumberAccessor(acsSession.latestSeqNumTaskManifest, acsSession.dataClient) - taskComparer := NewTaskComparer(acsSession.taskEngine) - - taskStopper := NewTaskStopper(acsSession.taskEngine) - - responseSender := func(response interface{}) error { - return client.MakeRequest(response) - } - responders := []wsclient.RequestResponder{ - acssession.NewPayloadResponder(payloadMsgHandler, responseSender), - acssession.NewRefreshCredentialsResponder(acsSession.credentialsManager, credsMetadataSetter, acsSession.metricsFactory, - responseSender), - acssession.NewAttachTaskENIResponder(eniHandler, responseSender), - acssession.NewAttachInstanceENIResponder(eniHandler, responseSender), - acssession.NewHeartbeatResponder(acsSession.doctor, responseSender), - acssession.NewTaskManifestResponder(taskComparer, sequenceNumberAccessor, manifestMessageIDAccessor, - acsSession.metricsFactory, responseSender), - acssession.NewTaskStopVerificationACKResponder(taskStopper, manifestMessageIDAccessor, acsSession.metricsFactory), - } - for _, r := range responders { - client.AddRequestHandler(r.HandlerFunc()) - } - - if acsSession.addUpdateRequestHandlers != nil { - acsSession.addUpdateRequestHandlers(client) - } - - disconnectTimer, err := client.Connect(metrics.ACSDisconnectTimeoutMetricName, - acsSession.connectionTime, - acsSession.connectionJitter) - if err != nil { - 
seelog.Errorf("Error connecting to ACS: %v", err) - return err - } - - defer disconnectTimer.Stop() - - seelog.Info("Connected to ACS endpoint") - - // Start a heartbeat timer for closing the connection - heartbeatTimer := newHeartbeatTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()) - // Any message from the server resets the heartbeat timer - client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client)) - defer heartbeatTimer.Stop() - - // Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence - // and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS. - acsSession.sendCredentials = false - - backoffResetTimer := time.AfterFunc( - retry.AddJitter(acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()), func() { - // If we do not have an error connecting and remain connected for at - // least 1 or so minutes, reset the backoff. This prevents disconnect - // errors that only happen infrequently from damaging the reconnect - // delay as significantly. 
- acsSession.backoff.Reset() - }) - defer backoffResetTimer.Stop() - - return client.Serve(acsSession.ctx) -} - -// newHeartbeatTimer creates a new time object, with a callback to -// disconnect from ACS on inactivity -func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer { - timer := time.AfterFunc(retry.AddJitter(timeout, jitter), func() { - seelog.Warn("ACS Connection hasn't had any activity for too long; closing connection") - if err := client.Close(); err != nil { - seelog.Warnf("Error disconnecting: %v", err) - } - seelog.Info("Disconnected from ACS") - }) - - return timer -} - -func (acsSession *session) computeReconnectDelay(isInactiveInstance bool) time.Duration { - if isInactiveInstance { - return acsSession._inactiveInstanceReconnectDelay - } - - return acsSession.backoff.Duration() -} - -// waitForDuration waits for the specified duration of time. If the wait is interrupted, -// it returns a false value. Else, it returns true, indicating completion of wait time. 
-func (acsSession *session) waitForDuration(delay time.Duration) bool { - reconnectTimer := time.NewTimer(delay) - select { - case <-reconnectTimer.C: - return true - case <-acsSession.ctx.Done(): - reconnectTimer.Stop() - return false - } -} - -func (acsSession *session) heartbeatTimeout() time.Duration { - return acsSession._heartbeatTimeout -} - -func (acsSession *session) heartbeatJitter() time.Duration { - return acsSession._heartbeatJitter -} - -// acsURL returns the websocket url for ACS given the endpoint -func (acsSession *session) acsURL(endpoint string) string { - acsURL := endpoint - if endpoint[len(endpoint)-1] != '/' { - acsURL += "/" - } - acsURL += "ws" - query := url.Values{} - query.Set("clusterArn", acsSession.agentConfig.Cluster) - query.Set("containerInstanceArn", acsSession.containerInstanceARN) - query.Set("agentHash", version.GitHashString()) - query.Set("agentVersion", version.Version) - query.Set("seqNum", "1") - query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion)) - if dockerVersion, err := acsSession.taskEngine.Version(); err == nil { - query.Set("dockerVersion", "DockerVersion: "+dockerVersion) - } - query.Set(sendCredentialsURLParameterName, strconv.FormatBool(acsSession.sendCredentials)) - return acsURL + "?" + query.Encode() -} - -// anyMessageHandler handles any server message. 
Any server message means the -// connection is active and thus the heartbeat disconnect should not occur -func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) { - return func(interface{}) { - seelog.Debug("ACS activity occurred") - // Reset read deadline as there's activity on the channel - if err := client.SetReadDeadline(time.Now().Add(wsRWTimeout)); err != nil { - seelog.Warnf("Unable to extend read deadline for ACS connection: %v", err) - } - - // Reset heartbeat timer - timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) - } -} - -func shouldReconnectWithoutBackoff(acsError error) bool { - return acsError == nil || acsError == io.EOF -} - -func isInactiveInstanceError(acsError error) bool { - return acsError != nil && strings.HasPrefix(acsError.Error(), inactiveInstanceExceptionPrefix) -} diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go deleted file mode 100644 index f93f9e75e7b..00000000000 --- a/agent/acs/handler/acs_handler_test.go +++ /dev/null @@ -1,1541 +0,0 @@ -//go:build unit -// +build unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. 
-package handler - -import ( - "context" - "fmt" - "io" - "net/http" - "net/http/httptest" - "net/url" - "os" - "reflect" - "runtime" - "runtime/pprof" - "strconv" - "sync" - "testing" - "time" - - "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" - - apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" - apitask "github.com/aws/amazon-ecs-agent/agent/api/task" - "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/agent/data" - mock_dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi/mocks" - "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" - mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" - "github.com/aws/amazon-ecs-agent/agent/eventhandler" - "github.com/aws/amazon-ecs-agent/agent/version" - acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" - rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" - mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" - "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" - "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" - mock_retry "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock" - "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" - mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/golang/mock/gomock" - "github.com/gorilla/websocket" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -const ( - samplePayloadMessage = ` -{ - "type": "PayloadMessage", - "message": { - "messageId": "123", - "tasks": [ - { - "taskDefinitionAccountId": "123", - "containers": [ - { - "environment": {}, - "name": "name", - "cpu": 1, - "essential": true, - "memory": 1, - "portMappings": [], - "overrides": "{}", - "image": "i", - 
"mountPoints": [], - "volumesFrom": [] - } - ], - "elasticNetworkInterfaces":[{ - "attachmentArn": "eni_attach_arn", - "ec2Id": "eni_id", - "ipv4Addresses":[{ - "primary": true, - "privateAddress": "ipv4" - }], - "ipv6Addresses": [{ - "address": "ipv6" - }], - "subnetGatewayIpv4Address": "ipv4/20", - "macAddress": "mac" - }], - "roleCredentials": { - "credentialsId": "credsId", - "accessKeyId": "accessKeyId", - "expiration": "2016-03-25T06:17:19.318+0000", - "roleArn": "r1", - "secretAccessKey": "secretAccessKey", - "sessionToken": "token" - }, - "version": "3", - "volumes": [], - "family": "f", - "arn": "arn", - "desiredStatus": "RUNNING" - } - ], - "generatedAt": 1, - "clusterArn": "1", - "containerInstanceArn": "1", - "seqNum": 1 - } -} -` - sampleRefreshCredentialsMessage = ` -{ - "type": "IAMRoleCredentialsMessage", - "message": { - "messageId": "123", - "clusterArn": "default", - "taskArn": "t1", - "roleType": "TaskApplication", - "roleCredentials": { - "credentialsId": "credsId", - "accessKeyId": "newakid", - "expiration": "later", - "roleArn": "r1", - "secretAccessKey": "newskid", - "sessionToken": "newstkn" - } - } -} -` - acsURL = "http://endpoint.tld" -) - -var testConfig = &config.Config{ - Cluster: "someCluster", - AcceptInsecureCert: true, -} - -var testCreds = credentials.NewStaticCredentials("test-id", "test-secret", "test-token") - -// TestACSURL tests if the URL is constructed correctly when connecting to ACS -func TestACSURL(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - - taskEngine.EXPECT().Version().Return("Docker version result", nil) - - acsSession := session{ - taskEngine: taskEngine, - sendCredentials: true, - agentConfig: testConfig, - containerInstanceARN: "myContainerInstance", - } - wsurl := acsSession.acsURL(acsURL) - - parsed, err := url.Parse(wsurl) - assert.NoError(t, err, "should be able to parse URL") - assert.Equal(t, "/ws", parsed.Path, "wrong path") 
- assert.Equal(t, "someCluster", parsed.Query().Get("clusterArn"), "wrong cluster") - assert.Equal(t, "myContainerInstance", parsed.Query().Get("containerInstanceArn"), "wrong container instance") - assert.Equal(t, version.Version, parsed.Query().Get("agentVersion"), "wrong agent version") - assert.Equal(t, version.GitHashString(), parsed.Query().Get("agentHash"), "wrong agent hash") - assert.Equal(t, "DockerVersion: Docker version result", parsed.Query().Get("dockerVersion"), "wrong docker version") - assert.Equalf(t, "true", parsed.Query().Get(sendCredentialsURLParameterName), "Wrong value set for: %s", sendCredentialsURLParameterName) - assert.Equal(t, "1", parsed.Query().Get("seqNum"), "wrong seqNum") - protocolVersion, _ := strconv.Atoi(parsed.Query().Get("protocolVersion")) - assert.True(t, protocolVersion > 1, "ACS protocol version should be greater than 1") -} - -// TestHandlerReconnectsOnConnectErrors tests if handler reconnects retries -// to establish the session with ACS when ClientServer.Connect() returns errors -func TestHandlerReconnectsOnConnectErrors(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Serve(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - gomock.InOrder( - // Connect fails 10 times - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, io.EOF).Times(10), - // Cancel trying to connect to ACS on the 11th attempt - // Failure to retry on Connect() errors should cause the - // test to time out as the context is never cancelled - - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - cancel() - }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).MinTimes(1), - ) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// TestIsInactiveInstanceErrorReturnsTrueForInactiveInstance tests if the 'InactiveInstance' -// exception is identified correctly by the handler -func TestIsInactiveInstanceErrorReturnsTrueForInactiveInstance(t *testing.T) { - assert.True(t, isInactiveInstanceError(fmt.Errorf("InactiveInstanceException: ")), - "inactive instance 
exception message parsed incorrectly") -} - -// TestIsInactiveInstanceErrorReturnsFalseForActiveInstance tests if non 'InactiveInstance' -// exceptions are identified correctly by the handler -func TestIsInactiveInstanceErrorReturnsFalseForActiveInstance(t *testing.T) { - assert.False(t, isInactiveInstanceError(io.EOF), - "inactive instance exception message parsed incorrectly") -} - -// TestComputeReconnectDelayForInactiveInstance tests if the reconnect delay is computed -// correctly for an inactive instance -func TestComputeReconnectDelayForInactiveInstance(t *testing.T) { - acsSession := session{_inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay} - assert.Equal(t, inactiveInstanceReconnectDelay, acsSession.computeReconnectDelay(true), - "Reconnect delay doesn't match expected value for inactive instance") -} - -// TestComputeReconnectDelayForActiveInstance tests if the reconnect delay is computed -// correctly for an active instance -func TestComputeReconnectDelayForActiveInstance(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockBackoff := mock_retry.NewMockBackoff(ctrl) - mockBackoff.EXPECT().Duration().Return(connectionBackoffMax) - - acsSession := session{backoff: mockBackoff} - assert.Equal(t, connectionBackoffMax, acsSession.computeReconnectDelay(false), - "Reconnect delay doesn't match expected value for active instance") -} - -// TestWaitForDurationReturnsTrueWhenContextNotCancelled tests if the -// waitForDurationOrCancelledSession method behaves correctly when the session context -// is not cancelled -func TestWaitForDurationReturnsTrueWhenContextNotCancelled(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - acsSession := session{ - ctx: ctx, - cancel: cancel, - } - - assert.True(t, acsSession.waitForDuration(time.Millisecond), - "WaitForDuration should return true when uninterrupted") -} - -// TestWaitForDurationReturnsFalseWhenContextCancelled tests if the -// 
waitForDurationOrCancelledSession method behaves correctly when the session contexnt -// is cancelled -func TestWaitForDurationReturnsFalseWhenContextCancelled(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - acsSession := session{ - ctx: ctx, - cancel: cancel, - } - cancel() - - assert.False(t, acsSession.waitForDuration(time.Millisecond), - "WaitForDuration should return false when interrupted") -} - -func TestShouldReconnectWithoutBackoffReturnsTrueForEOF(t *testing.T) { - assert.True(t, shouldReconnectWithoutBackoff(io.EOF), - "Reconnect without backoff should return true when connection is closed") -} - -func TestShouldReconnectWithoutBackoffReturnsFalseForNonEOF(t *testing.T) { - assert.False(t, shouldReconnectWithoutBackoff(fmt.Errorf("not EOF")), - "Reconnect without backoff should return false for non io.EOF error") -} - -// TestHandlerReconnectsWithoutBackoffOnEOFError tests if the session handler reconnects -// to ACS without any delay when the connection is closed with the io.EOF error -func TestHandlerReconnectsWithoutBackoffOnEOFError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - deregisterInstanceEventStream.StartListening() - - mockBackoff := mock_retry.NewMockBackoff(ctrl) - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - gomock.InOrder( - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, io.EOF), - // The backoff.Reset() method is expected to be invoked when the connection - // is closed with io.EOF - mockBackoff.EXPECT().Reset(), - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - // cancel the context on the 2nd connect attempt, which should stop - // the test - cancel() - }).Return(nil, io.EOF), - mockBackoff.EXPECT().Reset().AnyTimes(), - ) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - deregisterInstanceEventStream: deregisterInstanceEventStream, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: mockBackoff, - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - metricsFactory: metrics.NewNopEntryFactory(), - latestSeqNumTaskManifest: aws.Int64(10), - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - _inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, - } - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// TestHandlerReconnectsWithoutBackoffOnEOFError tests if the session handler reconnects -// to ACS after a backoff duration when the connection is closed with non io.EOF error -func TestHandlerReconnectsWithBackoffOnNonEOFError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - taskEngine := 
mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - deregisterInstanceEventStream.StartListening() - - mockBackoff := mock_retry.NewMockBackoff(ctrl) - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - gomock.InOrder( - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("not EOF")), - // The backoff.Duration() method is expected to be invoked when - // the connection is closed with a non-EOF error code to compute - // the backoff. Also, no calls to backoff.Reset() are expected - // in this code path. 
- mockBackoff.EXPECT().Duration().Return(time.Millisecond), - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - cancel() - }).Return(nil, io.EOF), - mockBackoff.EXPECT().Reset().AnyTimes(), - ) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - deregisterInstanceEventStream: deregisterInstanceEventStream, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: mockBackoff, - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// TestHandlerGeneratesDeregisteredInstanceEvent tests if the session handler generates -// an event into the deregister instance event stream when the acs connection is closed -// with inactive instance error -func TestHandlerGeneratesDeregisteredInstanceEvent(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - - // receiverFunc cancels the context when invoked. Any event on the deregister - // instance even stream would trigger this. 
- receiverFunc := func(...interface{}) error { - cancel() - return nil - } - err := deregisterInstanceEventStream.Subscribe("DeregisterContainerInstance", receiverFunc) - assert.NoError(t, err, "Error adding deregister instance event stream subscriber") - deregisterInstanceEventStream.StartListening() - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("InactiveInstanceException:")) - inactiveInstanceReconnectDelay := 200 * time.Millisecond - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - deregisterInstanceEventStream: deregisterInstanceEventStream, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - _inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, - metricsFactory: metrics.NewNopEntryFactory(), - } - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// TestHandlerReconnectDelayForInactiveInstanceError tests if the session 
handler applies -// the proper reconnect delay with ACS when ClientServer.Connect() returns the -// InstanceInactive error -func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - // Don't start to ensure an error doesn't affect reconnect - // deregisterInstanceEventStream.StartListening() - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - var firstConnectionAttemptTime time.Time - inactiveInstanceReconnectDelay := 200 * time.Millisecond - gomock.InOrder( - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - firstConnectionAttemptTime = time.Now() - }).Return(nil, fmt.Errorf("InactiveInstanceException:")), - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - reconnectDelay := time.Now().Sub(firstConnectionAttemptTime) - reconnectDelayTime := time.Now() - t.Logf("Delay between successive connections: %v", reconnectDelay) - timeSubFuncSlopAllowed := 2 * time.Millisecond - if reconnectDelay < inactiveInstanceReconnectDelay { - // On windows platform, we found issue with time.Now().Sub(...) reporting 199.9989 even - // after the code has already waited for time.NewTimer(200)ms. 
- assert.WithinDuration(t, reconnectDelayTime, firstConnectionAttemptTime.Add(inactiveInstanceReconnectDelay), timeSubFuncSlopAllowed) - } - cancel() - }).Return(nil, io.EOF), - ) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - deregisterInstanceEventStream: deregisterInstanceEventStream, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - _inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, - metricsFactory: metrics.NewNopEntryFactory(), - } - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// TestHandlerReconnectsOnServeErrors tests if the handler retries to -// establish the session with ACS when ClientServer.Serve() returns errors -func TestHandlerReconnectsOnServeErrors(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - gomock.InOrder( - // Serve fails 10 times - mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).Times(10), - // Cancel trying to Serve ACS requests on the 11th attempt - // Failure to retry on Serve() errors should cause the - // test to time out as the context is never cancelled - mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { - cancel() - }), - ) - - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// TestHandlerStopsWhenContextIsCancelled tests if the session's Start() method returns -// when session context is cancelled -func TestHandlerStopsWhenContextIsCancelled(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := 
mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - gomock.InOrder( - mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF), - mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { - cancel() - }).Return(errors.New("InactiveInstanceException")), - ) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - - // The session error channel would have an event when the Start() method returns - // Cancelling the context should trigger this - sessionError := make(chan error) - go func() { - sessionError <- 
acsSession.Start() - }() - response := <-sessionError - assert.Nil(t, response) -} - -// TestHandlerStopsWhenContextIsError tests if the session's Start() method returns -// when session context is in error -func TestHandlerStopsWhenContextIsError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { - time.Sleep(5 * time.Millisecond) - }).Return(io.EOF).AnyTimes() - - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 
20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - - // The session error channel would have an event when the Start() method returns - // Cancelling the context should trigger this - sessionError := make(chan error) - go func() { - sessionError <- acsSession.Start() - }() - response := <-sessionError - assert.Nil(t, response) -} - -// TestHandlerStopsWhenContextIsErrorReconnectDelay tests if the session's Start() method returns -// when session context is in error -func TestHandlerStopsWhenContextIsErrorReconnectDelay(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Serve(gomock.Any()).Return(errors.New("InactiveInstanceException")).AnyTimes() - - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - _inactiveInstanceReconnectDelay: 1 * time.Hour, - metricsFactory: metrics.NewNopEntryFactory(), - } - - // The session error channel would have an event when the Start() method returns - // Cancelling the context should trigger this - sessionError := make(chan error) - go func() { - sessionError <- acsSession.Start() - }() - response := <-sessionError - assert.Nil(t, response) -} - -// TestHandlerReconnectsOnDiscoverPollEndpointError tests if handler retries -// to establish the session with ACS on DiscoverPollEndpoint errors -func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := 
eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Serve(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - // Serve() cancels the context - cancel() - }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).MinTimes(1) - - gomock.InOrder( - // DiscoverPollEndpoint returns an error on its first invocation - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return("", fmt.Errorf("oops")).Times(1), - // Second invocation returns a success - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).Times(1), - ) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - ctx: ctx, - cancel: cancel, - clientFactory: mockClientFactory, - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - go func() { - acsSession.Start() - }() - start := time.Now() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } - 
- // Measure the duration between retries - timeSinceStart := time.Since(start) - if timeSinceStart < connectionBackoffMin { - t.Errorf("Duration since start is less than minimum threshold for backoff: %s", timeSinceStart.String()) - } - - // The upper limit here should really be connectionBackoffMin + (connectionBackoffMin * jitter) - // But, it can be off by a few milliseconds to account for execution of other instructions - // In any case, it should never be higher than 4*connectionBackoffMin - if timeSinceStart > 4*connectionBackoffMin { - t.Errorf("Duration since start is greater than maximum anticipated wait time: %v", timeSinceStart.String()) - } -} - -// TestConnectionIsClosedOnIdle tests if the connection to ACS is closed -// when the channel is idle -func TestConnectionIsClosedOnIdle(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - defer cancel() - - wait := sync.WaitGroup{} - wait.Add(1) - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil) - mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { - wait.Done() - // Pretend as if the maximum heartbeatTimeout duration has - // been breached while Serving requests - time.Sleep(30 * time.Millisecond) - }).Return(io.EOF) - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - connectionClosed := make(chan bool) - 
mockWsClient.EXPECT().Close().Do(func() { - wait.Wait() - // Record connection closed - connectionClosed <- true - }).Return(nil) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - ctx: context.Background(), - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - go acsSession.startACSSession(mockWsClient) - - // Wait for connection to be closed. If the connection is not closed - // due to inactivity, the test will time out - <-connectionClosed -} - -// TestConnectionIsClosedAfterTimeIsUp tests if the connection to ACS is closed -// when the session's connection time is expired. 
-func TestConnectionIsClosedAfterTimeIsUp(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - defer cancel() - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil) - mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { - // pretend as if the connectionTime has elapsed - time.Sleep(30 * time.Millisecond) - cancel() - }).Return(io.EOF) - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - - // set connectionTime to a value lower than the heartbeatTimeout to avoid - // closing the connection due to the heartbeatTimer's callback func - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - ctx: context.Background(), - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - _heartbeatTimeout: 50 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 20 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - metricsFactory: metrics.NewNopEntryFactory(), - } - - go func() { - messageError := make(chan error) - messageError <- 
acsSession.startACSSession(mockWsClient) - assert.EqualError(t, <-messageError, io.EOF.Error()) - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -func TestHandlerDoesntLeakGoroutines(t *testing.T) { - // Skip this test on "windows" platform as we have observed this to - // fail often after upgrading the windows builds to golang v1.17. - if runtime.GOOS == "windows" { - t.Skip() - } - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - closeWS := make(chan bool) - server, serverIn, requests, errs, err := startMockAcsServer(t, closeWS) - if err != nil { - t.Fatal(err) - } - go func() { - for { - select { - case <-requests: - case <-errs: - case <-ctx.Done(): - return - } - } - }() - - timesConnected := 0 - ecsClient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).AnyTimes().Do(func(_ interface{}) { - timesConnected++ - }) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - taskEngine.EXPECT().AddTask(gomock.Any()).AnyTimes() - dockerClient.EXPECT().SystemPing(gomock.Any(), gomock.Any()).AnyTimes() - - emptyHealthchecksList := []doctor.Healthcheck{} - emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "test-cluster", "this:is:an:instance:arn") - - ended := make(chan bool, 1) - go func() { - - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - dockerClient: dockerClient, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - ctx: ctx, - clientFactory: acsclient.NewACSClientFactory(), - _heartbeatTimeout: 1 * time.Second, - backoff: 
retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - credentialsManager: rolecredentials.NewManager(), - latestSeqNumTaskManifest: aws.Int64(12), - doctor: emptyDoctor, - metricsFactory: metrics.NewNopEntryFactory(), - } - acsSession.Start() - ended <- true - }() - // Warm it up - serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}` - serverIn <- samplePayloadMessage - - beforeGoroutines := runtime.NumGoroutine() - for i := 0; i < 40; i++ { - serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}` - serverIn <- samplePayloadMessage - closeWS <- true - } - - cancel() - <-ended - - afterGoroutines := runtime.NumGoroutine() - - t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines) - - if timesConnected < 20 { - t.Fatal("Expected times connected to be a large number, was ", timesConnected) - } - if afterGoroutines > beforeGoroutines+2 { - t.Error("Goroutine leak, oh no!") - pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) - } - -} - -// TestStartSessionHandlesRefreshCredentialsMessages tests the agent restart -// scenario where the payload to refresh credentials is processed immediately on -// connection establishment with ACS -func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - closeWS := make(chan bool) - server, serverIn, requestsChan, errChan, err := startMockAcsServer(t, closeWS) - if err != nil { - t.Fatal(err) - } - defer close(serverIn) - - go func() { - for { - select { - case <-requestsChan: - // Cancel the context when we get the ack request - 
cancel() - } - } - }() - - // DiscoverPollEndpoint returns the URL for the server that we started - ecsClient.EXPECT().DiscoverPollEndpoint("myArn").Return(server.URL, nil).Times(1) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - credentialsManager := mock_credentials.NewMockManager(ctrl) - dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - - emptyHealthchecksList := []doctor.Healthcheck{} - emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "test-cluster", "this:is:a:container:arn") - - latestSeqNumberTaskManifest := int64(10) - ended := make(chan bool, 1) - go func() { - acsSession := NewSession(ctx, - testConfig, - nil, - "myArn", - testCreds, - dockerClient, - ecsClient, - dockerstate.NewTaskEngineState(), - data.NewNoopClient(), - taskEngine, - credentialsManager, - taskHandler, - &latestSeqNumberTaskManifest, - emptyDoctor, - acsclient.NewACSClientFactory(), - nil, - metrics.NewNopEntryFactory(), - ) - acsSession.Start() - // StartSession should never return unless the context is canceled - ended <- true - }() - - updatedCredentials := rolecredentials.TaskIAMRoleCredentials{} - taskFromEngine := &apitask.Task{} - credentialsIdInRefreshMessage := "credsId" - // Ensure that credentials manager interface methods are invoked in the - // correct order, with expected arguments - gomock.InOrder( - // The last invocation of SetCredentials is to update - // credentials when a refresh message is received by the handler - credentialsManager.EXPECT().SetTaskCredentials(gomock.Any()).Do(func(creds *rolecredentials.TaskIAMRoleCredentials) { - updatedCredentials = *creds - // Validate parsed credentials after the update - expectedCreds := rolecredentials.TaskIAMRoleCredentials{ - ARN: "t1", - IAMRoleCredentials: rolecredentials.IAMRoleCredentials{ - RoleArn: "r1", - AccessKeyID: "newakid", - SecretAccessKey: "newskid", - SessionToken: "newstkn", - Expiration: "later", - CredentialsID: credentialsIdInRefreshMessage, - RoleType: 
"TaskApplication", - }, - } - if !reflect.DeepEqual(updatedCredentials, expectedCreds) { - t.Errorf("Mismatch between expected and credentials expected: %v, added: %v", expectedCreds, updatedCredentials) - } - }).Return(nil), - // Return a task from the engine for GetTaskByArn - taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true), - ) - serverIn <- sampleRefreshCredentialsMessage - - select { - case err := <-errChan: - t.Fatal("Error should not have been returned from server", err) - case <-ctx.Done(): - // Context is canceled when requestsChan receives an ack - } - - // Validate that the correct credentialsId is set for the task - credentialsIdFromTask := taskFromEngine.GetCredentialsID() - if credentialsIdFromTask != credentialsIdInRefreshMessage { - t.Errorf("Mismatch between expected and added credentials id for task, expected: %s, added: %s", credentialsIdInRefreshMessage, credentialsIdFromTask) - } - - server.Close() - // Cancel context should close the session - <-ended -} - -// TestHandlerCorrectlySetsSendCredentials tests if 'sendCredentials' -// is set correctly for successive invocations of startACSSession -func TestHandlerCorrectlySetsSendCredentials(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - deregisterInstanceEventStream.StartListening() - dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - emptyHealthchecksList := []doctor.Healthcheck{} - emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "test-cluster", "this:is:an:instance:arn") - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - 
mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes() - - acsSession := NewSession( - ctx, - testConfig, - deregisterInstanceEventStream, - "myArn", - testCreds, - dockerClient, - ecsClient, - dockerstate.NewTaskEngineState(), - data.NewNoopClient(), - taskEngine, - rolecredentials.NewManager(), - taskHandler, - aws.Int64(10), - emptyDoctor, - mockClientFactory, - nil, - metrics.NewNopEntryFactory()) - acsSession.(*session)._heartbeatTimeout = 20 * time.Millisecond - acsSession.(*session)._heartbeatJitter = 10 * time.Millisecond - acsSession.(*session).connectionTime = 30 * time.Millisecond - acsSession.(*session).connectionJitter = 10 * time.Millisecond - gomock.InOrder( - // When the websocket client connects to ACS for the first - // time, 'sendCredentials' should be set to true - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - assert.Equal(t, true, acsSession.(*session).sendCredentials) - }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil), - // For all subsequent connections to ACS, 'sendCredentials' - // should be set to false - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - assert.Equal(t, false, acsSession.(*session).sendCredentials) - }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes(), - ) - - go func() { - for i := 0; i < 10; i++ { - acsSession.(*session).startACSSession(mockWsClient) - } - cancel() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// 
TestHandlerReconnectCorrectlySetsAcsUrl tests if the ACS URL -// is set correctly for the initial connection and subsequent connections -func TestHandlerReconnectCorrectlySetsAcsUrl(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - dockerVerStr := "1.5.0" - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return(fmt.Sprintf("Docker: %s", dockerVerStr), nil).AnyTimes() - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - deregisterInstanceEventStream.StartListening() - dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - emptyHealthchecksList := []doctor.Healthcheck{} - emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "test-cluster", "this:is:an:instance:arn") - - mockBackoff := mock_retry.NewMockBackoff(ctrl) - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes() - - // On the initial connection, sendCredentials must be true because Agent forces ACS to send credentials. 
- initialAcsURL := fmt.Sprintf( - "http://endpoint.tld/ws?agentHash=%s&agentVersion=%s&clusterArn=%s&containerInstanceArn=%s&"+ - "dockerVersion=DockerVersion%%3A+Docker%%3A+%s&protocolVersion=%v&sendCredentials=true&seqNum=1", - version.GitShortHash, version.Version, testConfig.Cluster, "myArn", dockerVerStr, acsProtocolVersion) - - // But after that, ACS sends credentials at ACS's own cadence, so sendCredentials must be false. - subsequentAcsURL := fmt.Sprintf( - "http://endpoint.tld/ws?agentHash=%s&agentVersion=%s&clusterArn=%s&containerInstanceArn=%s&"+ - "dockerVersion=DockerVersion%%3A+Docker%%3A+%s&protocolVersion=%v&sendCredentials=false&seqNum=1", - version.GitShortHash, version.Version, testConfig.Cluster, "myArn", dockerVerStr, acsProtocolVersion) - - gomock.InOrder( - mockClientFactory.EXPECT(). - New(initialAcsURL, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient), - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil), - mockBackoff.EXPECT().Reset(), - mockClientFactory.EXPECT(). - New(subsequentAcsURL, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- Return(mockWsClient), - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, interface{}, interface{}) { - cancel() - }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil), - ) - acsSession := NewSession( - ctx, - testConfig, - deregisterInstanceEventStream, - "myArn", - testCreds, - dockerClient, - ecsClient, - dockerstate.NewTaskEngineState(), - data.NewNoopClient(), - taskEngine, - rolecredentials.NewManager(), - taskHandler, - aws.Int64(10), - emptyDoctor, - mockClientFactory, - nil, - metrics.NewNopEntryFactory()) - acsSession.(*session).backoff = mockBackoff - acsSession.(*session)._heartbeatTimeout = 20 * time.Millisecond - acsSession.(*session)._heartbeatJitter = 10 * time.Millisecond - acsSession.(*session).connectionTime = 30 * time.Millisecond - acsSession.(*session).connectionJitter = 10 * time.Millisecond - - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } -} - -// TestHandlerCallsAddUpdateRequestHandlers tests that the session handler calls the function -// contained in session struct field addUpdateRequestHandlers is called if it is not nil -func TestHandlerCallsAddUpdateRequestHandlers(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - addUpdateRequestHandlersCalled := false - addUpdateRequestHandlers := func(cs wsclient.ClientServer) { - addUpdateRequestHandlersCalled = true - } - - taskEngine := mock_engine.NewMockTaskEngine(ctrl) - taskEngine.EXPECT().Version().Return("Docker: 1.5.0", nil).AnyTimes() - - ecsClient := mock_api.NewMockECSClient(ctrl) - ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() - - ctx, cancel := context.WithCancel(context.Background()) - taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) - - deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) - 
deregisterInstanceEventStream.StartListening() - - mockWsClient := mock_wsclient.NewMockClientServer(ctrl) - mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) - mockClientFactory.EXPECT(). - New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(mockWsClient).AnyTimes() - mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() - mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { - if addUpdateRequestHandlersCalled { - cancel() - } - }) - mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() - mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - - acsSession := NewSession( - ctx, - testConfig, - deregisterInstanceEventStream, - "myArn", - testCreds, - nil, - ecsClient, - nil, - data.NewNoopClient(), - taskEngine, - nil, - taskHandler, - nil, - nil, - mockClientFactory, - addUpdateRequestHandlers, - metrics.NewNopEntryFactory(), - ) - - go func() { - acsSession.Start() - }() - - // Wait for context to be cancelled - select { - case <-ctx.Done(): - } - - assert.True(t, addUpdateRequestHandlersCalled) -} - -// TODO: replace with gomock -func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { - serverChan := make(chan string, 1) - requestsChan := make(chan string, 1) - errChan := make(chan error, 1) - - upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ws, err := upgrader.Upgrade(w, r, nil) - - if err != nil { - errChan <- err - } - - go func() { - _, msg, err := ws.ReadMessage() - if err != nil { - errChan <- err - } else { - requestsChan <- string(msg) - } - }() - for { - select { - case str := <-serverChan: - 
err := ws.WriteMessage(websocket.TextMessage, []byte(str)) - if err != nil { - errChan <- err - } - - case <-closeWS: - ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - ws.Close() - errChan <- io.EOF - // Quit listening to serverChan if we've been closed - return - } - - } - }) - - server := httptest.NewTLSServer(handler) - return server, serverChan, requestsChan, errChan, nil -} - -// validateAddedTask validates fields in addedTask for expected values -// It returns an error if there's a mismatch -func validateAddedTask(expectedTask apitask.Task, addedTask apitask.Task) error { - // The ecsacs.Task -> apitask.Task conversion initializes all fields in apitask.Task - // with empty objects. So, we create a new object to compare with only those - // fields that we are intrested in for comparison - taskToCompareFromAdded := apitask.Task{ - Arn: addedTask.Arn, - Family: addedTask.Family, - Version: addedTask.Version, - DesiredStatusUnsafe: addedTask.GetDesiredStatus(), - } - - if !reflect.DeepEqual(expectedTask, taskToCompareFromAdded) { - return fmt.Errorf("Mismatch between added and expected task: expected: %v, added: %v", expectedTask, taskToCompareFromAdded) - } - - return nil -} - -// validateAddedContainer validates fields in addedContainer for expected values -// It returns an error if there's a mismatch -func validateAddedContainer(expectedContainer *apicontainer.Container, addedContainer *apicontainer.Container) error { - // The ecsacs.Task -> apitask.Task conversion initializes all fields in apicontainer.Container - // with empty objects. 
So, we create a new object to compare with only those - // fields that we are intrested in for comparison - containerToCompareFromAdded := &apicontainer.Container{ - Name: addedContainer.Name, - CPU: addedContainer.CPU, - Essential: addedContainer.Essential, - Memory: addedContainer.Memory, - Image: addedContainer.Image, - } - if !reflect.DeepEqual(expectedContainer, containerToCompareFromAdded) { - return fmt.Errorf("Mismatch between added and expected container: expected: %v, added: %v", expectedContainer, containerToCompareFromAdded) - } - return nil -} diff --git a/agent/app/agent.go b/agent/app/agent.go index 192ca59b3f3..2208bd48111 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -56,6 +56,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/utils/mobypkgwrapper" "github.com/aws/amazon-ecs-agent/agent/version" acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" @@ -64,9 +65,10 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" md "github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon" - ecs_agent_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + metricsfactory "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" "github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" aws_credentials "github.com/aws/aws-sdk-go/aws/credentials" @@ -997,29 +999,71 @@ func (agent *ecsAgent) startACSSession( taskHandler *eventhandler.TaskHandler, doctor *doctor.Doctor) int { - acsSession := acshandler.NewSession( - agent.ctx, - agent.cfg, - 
deregisterInstanceEventStream, - agent.containerInstanceARN, - agent.credentialProvider, - agent.dockerClient, + inactiveInstanceCB := func() { + // If the instance is inactive (i.e., was deregistered), send an event to the event stream + // for the same. + err := deregisterInstanceEventStream.WriteToEventStream(struct{}{}) + if err != nil { + logger.Debug("Failed to write to deregister container instance event stream", logger.Fields{ + field.Error: err, + }) + } + } + + dockerVersion, err := taskEngine.Version() + if err != nil { + if err != nil { + logger.Warn("Failed to get docker version from task engine", logger.Fields{ + field.Error: err, + }) + } + } + + minAgentCfg := &wsclient.WSClientMinAgentConfig{ + AcceptInsecureCert: agent.cfg.AcceptInsecureCert, + AWSRegion: agent.cfg.AWSRegion, + DockerEndpoint: agent.cfg.DockerEndpoint, + IsDocker: true, + } + + payloadMessageHandler := acshandler.NewPayloadMessageHandler(taskEngine, client, agent.dataClient, taskHandler, + credentialsManager, agent.latestSeqNumberTaskManifest) + credsMetadataSetter := acshandler.NewCredentialsMetadataSetter(taskEngine) + eniHandler := acshandler.NewENIHandler(state, agent.dataClient) + manifestMessageIDAccessor := acshandler.NewManifestMessageIDAccessor() + sequenceNumberAccessor := acshandler.NewSequenceNumberAccessor(agent.latestSeqNumberTaskManifest, agent.dataClient) + taskComparer := acshandler.NewTaskComparer(taskEngine) + taskStopper := acshandler.NewTaskStopper(taskEngine) + + acsSession := session.NewSession(agent.containerInstanceARN, + agent.cfg.Cluster, client, - state, - agent.dataClient, - taskEngine, + agent.credentialProvider, + inactiveInstanceCB, + acsclient.NewACSClientFactory(), + metricsfactory.NewNopEntryFactory(), + version.Version, + version.GitHashString(), + dockerVersion, + minAgentCfg, + payloadMessageHandler, credentialsManager, - taskHandler, - agent.latestSeqNumberTaskManifest, + credsMetadataSetter, doctor, - acsclient.NewACSClientFactory(), + 
eniHandler, + manifestMessageIDAccessor, + taskComparer, + sequenceNumberAccessor, + taskStopper, + nil, updater.NewUpdater(agent.cfg, state, agent.dataClient, taskEngine).AddAgentUpdateHandlers, - ecs_agent_metrics.NewNopEntryFactory(), ) - seelog.Info("Beginning Polling for updates") - err := acsSession.Start() - if err != nil { - seelog.Criticalf("Unretriable error starting communicating with ACS: %v", err) + logger.Info("Beginning Polling for updates") + sessionEndReason := acsSession.Start(agent.ctx) + if sessionEndReason == nil { + // Agent somehow exited without a reason. + // We don't expect this condition to ever be reached, but log a critical error just in case it is. + logger.Critical("ACS session ended for unknown reason") return exitcodes.ExitTerminal } return exitcodes.ExitSuccess diff --git a/agent/go.mod b/agent/go.mod index 9e3ee012966..5a1d30c6944 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -17,7 +17,6 @@ require ( github.com/fsnotify/fsnotify v1.6.0 github.com/golang/mock v1.4.1 github.com/gorilla/mux v1.8.0 - github.com/gorilla/websocket v1.5.0 github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 github.com/opencontainers/runtime-spec v1.0.3-0.20210326190908-1c3f411f0417 @@ -49,6 +48,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/uuid v1.3.0 // indirect + github.com/gorilla/websocket v1.5.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/moby/sys/mount v0.3.3 // indirect diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go index 0eae5c3f1c5..37bbf68d0b9 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go +++ 
b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go @@ -11,13 +11,341 @@ // express or implied. See the License for the specific language governing // permissions and limitations under the License. +// Package session deals with appropriately reacting to all ACS messages as well +// as maintaining the connection to ACS. package session import ( + "context" + "io" + "net/url" + "strconv" + "strings" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/api" + rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime" "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +const ( + // heartbeatTimeout is the maximum time to wait between heartbeats + // without disconnecting. + heartbeatTimeout = 1 * time.Minute + heartbeatJitter = 1 * time.Minute + + // wsRWTimeout is the duration of read and write deadline for the + // websocket connection. + wsRWTimeout = 2*heartbeatTimeout + heartbeatJitter + + inactiveInstanceReconnectDelay = 1 * time.Hour + + connectionBackoffMin = 250 * time.Millisecond + connectionBackoffMax = 2 * time.Minute + connectionBackoffJitter = 0.2 + connectionBackoffMultiplier = 1.5 + + inactiveInstanceExceptionPrefix = "InactiveInstanceException" + + // ACS protocol version spec: + // 1: default protocol version + // 2: ACS will proactively close the connection when heartbeat ACKs are missing + acsProtocolVersion = 2 ) +// Session defines an interface for Agent's long-lived connection with ACS. +// The Session.Start() method can be used to start processing messages from ACS. 
+type Session interface { + Start(context.Context) error +} + +// session encapsulates all arguments needed to connect to ACS and to handle messages received by ACS. +type session struct { + containerInstanceARN string + cluster string + credentialsProvider *credentials.Credentials + discoverEndpointClient api.ECSDiscoverEndpointSDK + inactiveInstanceCB func() + agentVersion string + agentHash string + dockerVersion string + payloadMessageHandler PayloadMessageHandler + credentialsManager rolecredentials.Manager + credentialsMetadataSetter CredentialsMetadataSetter + doctor *doctor.Doctor + eniHandler ENIHandler + manifestMessageIDAccessor ManifestMessageIDAccessor + taskComparer TaskComparer + sequenceNumberAccessor SequenceNumberAccessor + taskStopper TaskStopper + resourceHandler ResourceHandler + backoff retry.Backoff + sendCredentials bool + clientFactory wsclient.ClientFactory + metricsFactory metrics.EntryFactory + minAgentConfig *wsclient.WSClientMinAgentConfig + addUpdateRequestHandlers func(wsclient.ClientServer) + heartbeatTimeout time.Duration + heartbeatJitter time.Duration + disconnectTimeout time.Duration + disconnectJitter time.Duration + inactiveInstanceReconnectDelay time.Duration +} + +// NewSession creates a new Session. 
+func NewSession(containerInstanceARN string, + cluster string, + discoverEndpointClient api.ECSDiscoverEndpointSDK, + credentialsProvider *credentials.Credentials, + inactiveInstanceCB func(), + clientFactory wsclient.ClientFactory, + metricsFactory metrics.EntryFactory, + agentVersion string, + agentHash string, + dockerVersion string, + minAgentConfig *wsclient.WSClientMinAgentConfig, + payloadMessageHandler PayloadMessageHandler, + credentialsManager rolecredentials.Manager, + credentialsMetadataSetter CredentialsMetadataSetter, + doctor *doctor.Doctor, + eniHandler ENIHandler, + manifestMessageIDAccessor ManifestMessageIDAccessor, + taskComparer TaskComparer, + sequenceNumberAccessor SequenceNumberAccessor, + taskStopper TaskStopper, + resourceHandler ResourceHandler, + addUpdateRequestHandlers func(wsclient.ClientServer), +) Session { + backoff := retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier) + return &session{ + containerInstanceARN: containerInstanceARN, + cluster: cluster, + discoverEndpointClient: discoverEndpointClient, + credentialsProvider: credentialsProvider, + inactiveInstanceCB: inactiveInstanceCB, + clientFactory: clientFactory, + metricsFactory: metricsFactory, + agentVersion: agentVersion, + agentHash: agentHash, + dockerVersion: dockerVersion, + minAgentConfig: minAgentConfig, + payloadMessageHandler: payloadMessageHandler, + credentialsManager: credentialsManager, + credentialsMetadataSetter: credentialsMetadataSetter, + doctor: doctor, + eniHandler: eniHandler, + manifestMessageIDAccessor: manifestMessageIDAccessor, + taskComparer: taskComparer, + sequenceNumberAccessor: sequenceNumberAccessor, + taskStopper: taskStopper, + resourceHandler: resourceHandler, + addUpdateRequestHandlers: addUpdateRequestHandlers, + backoff: backoff, + sendCredentials: true, + heartbeatTimeout: heartbeatTimeout, + heartbeatJitter: heartbeatJitter, + disconnectTimeout: 
wsclient.DisconnectTimeout, + disconnectJitter: wsclient.DisconnectJitterMax, + inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + } +} + +// Start starts the session. It'll forever keep trying to connect to ACS unless +// the context is closed. +// +// If the context is closed, Start() would return with the error code returned +// by the context. +func (s *session) Start(ctx context.Context) error { + // connectToACS channel is used to indicate the intent to connect to ACS + // It's processed by the select loop to connect to ACS. + connectToACS := make(chan struct{}) + + // The below is required to trigger the first connection to ACS. + sendEmptyMessageOnChannel(connectToACS) + + // Loop continuously until context is closed/canceled. + for { + select { + case <-connectToACS: + logger.Debug("Received connect to ACS message. Attempting connect to ACS") + + // Start a session with ACS. + acsError := s.startSessionOnce(ctx) + + // Session with ACS was stopped with some error, start processing the error. + reconnectDelay, ok := s.reconnectDelay(acsError) + + if ok { + logger.Info("Waiting before reconnecting to ACS", logger.Fields{ + "reconnectDelay": reconnectDelay.String(), + }) + waitComplete := waitForDuration(ctx, reconnectDelay) + if waitComplete { + // If the context was not canceled and we've waited for the + // wait duration without any errors, send the message to the channel + // to reconnect to ACS. + logger.Info("Done waiting; reconnecting to ACS") + sendEmptyMessageOnChannel(connectToACS) + } else { + // Wait was interrupted. We expect the session to close as canceling + // the session context is the only way to end up here. Print a message + // to indicate the same. + logger.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close") + } + } else { + // No need to delay reconnect - reconnect immediately. 
+ logger.Info("Reconnecting to ACS immediately without waiting") + sendEmptyMessageOnChannel(connectToACS) + } + case <-ctx.Done(): + logger.Info("ACS session ended (context closed)", logger.Fields{ + field.Reason: ctx.Err(), + }) + return ctx.Err() + } + } +} + +// startSessionOnce creates a session with ACS and handles requests using the passed +// in arguments. +func (s *session) startSessionOnce(ctx context.Context) error { + acsEndpoint, err := s.discoverEndpointClient.DiscoverPollEndpoint(s.containerInstanceARN) + if err != nil { + logger.Error("ACS: Unable to discover poll endpoint", logger.Fields{ + field.Error: err, + }) + return err + } + + client := s.clientFactory.New( + s.acsURL(acsEndpoint), + s.credentialsProvider, + wsRWTimeout, + s.minAgentConfig, + s.metricsFactory) + defer client.Close() + + // Invoke Connect method as soon as we create client. This will ensure all the + // request handlers to be associated with this client have a valid connection. + disconnectTimer, err := client.Connect(metrics.ACSDisconnectTimeoutMetricName, s.disconnectTimeout, + s.disconnectJitter) + if err != nil { + logger.Error("Failed to connect to ACS", logger.Fields{ + field.Error: err, + }) + return err + } + defer disconnectTimer.Stop() + + // Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence + // and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS. + logger.Info("Connected to ACS endpoint") + s.sendCredentials = false + + return s.startACSSession(ctx, client) +} + +// startACSSession starts a session with ACS. It adds request handlers for various +// kinds of messages expected from ACS. It returns on server disconnection or when +// the context is canceled. 
+func (s *session) startACSSession(ctx context.Context, client wsclient.ClientServer) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + responseSender := func(response interface{}) error { + return client.MakeRequest(response) + } + responders := []wsclient.RequestResponder{ + NewPayloadResponder(s.payloadMessageHandler, responseSender), + NewRefreshCredentialsResponder(s.credentialsManager, s.credentialsMetadataSetter, s.metricsFactory, + responseSender), + NewAttachTaskENIResponder(s.eniHandler, responseSender), + NewAttachInstanceENIResponder(s.eniHandler, responseSender), + NewHeartbeatResponder(s.doctor, responseSender), + NewTaskManifestResponder(s.taskComparer, s.sequenceNumberAccessor, s.manifestMessageIDAccessor, + s.metricsFactory, responseSender), + NewTaskStopVerificationACKResponder(s.taskStopper, s.manifestMessageIDAccessor, s.metricsFactory), + } + for _, r := range responders { + client.AddRequestHandler(r.HandlerFunc()) + } + + if s.dockerVersion == "containerd" && s.resourceHandler != nil { + client.AddRequestHandler(NewAttachResourceResponder(s.resourceHandler, s.metricsFactory, + responseSender).HandlerFunc()) + } + + if s.dockerVersion != "containerd" && s.addUpdateRequestHandlers != nil { + s.addUpdateRequestHandlers(client) + } + + // Start a heartbeat timer for closing the connection. + heartbeatTimer := newHeartbeatTimer(client, s.heartbeatTimeout, s.heartbeatJitter) + // Any message from the server resets the heartbeat timer. + client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client)) + defer heartbeatTimer.Stop() + + backoffResetTimer := time.AfterFunc( + retry.AddJitter(s.heartbeatTimeout, s.heartbeatJitter), func() { + // If we do not have an error connecting and remain connected for at + // least 1 or so minutes, reset the backoff. This prevents disconnect + // errors that only happen infrequently from damaging the reconnect + // delay as significantly. 
+ s.backoff.Reset() + }) + defer backoffResetTimer.Stop() + + return client.Serve(ctx) +} + +func (s *session) reconnectDelay(acsError error) (time.Duration, bool) { + if isInactiveInstanceError(acsError) { + logger.Info("Container instance is deregistered") + s.inactiveInstanceCB() + return s.inactiveInstanceReconnectDelay, true + } + if shouldReconnectWithoutBackoff(acsError) { + // ACS has closed the connection for valid reasons. Example: periodic disconnect. + // No need to wait/backoff to reconnect. + logger.Info("ACS WebSocket connection closed for a valid reason") + s.backoff.Reset() + return 0, false + + } + // Disconnected unexpectedly from ACS, compute backoff duration to reconnect. + return s.backoff.Duration(), true +} + +// acsURL returns the websocket url for ACS given the endpoint. +func (s *session) acsURL(endpoint string) string { + wsURL := endpoint + if endpoint[len(endpoint)-1] != '/' { + wsURL += "/" + } + wsURL += "ws" + query := url.Values{} + query.Set("clusterArn", s.cluster) + query.Set("containerInstanceArn", s.containerInstanceARN) + query.Set("agentHash", s.agentHash) + query.Set("agentVersion", s.agentVersion) + query.Set("seqNum", "1") + query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion)) + if s.dockerVersion != "" { + query.Set("dockerVersion", formatDockerVersion(s.dockerVersion)) + } + // Below indicates if ACS should send credentials for all tasks upon establishing the connection. + query.Set("sendCredentials", strconv.FormatBool(s.sendCredentials)) + return wsURL + "?" + query.Encode() +} + // ResponseToACSSender returns a wsclient.RespondFunc that a responder can invoke in response to receiving and // processing specific websocket request messages from ACS. The returned wsclient.RespondFunc: // 1. 
logs the response to be sent, as well as the name of the invoking responder @@ -31,3 +359,69 @@ func ResponseToACSSender(responderName string, responseSender wsclient.RespondFu return responseSender(response) } } + +// newHeartbeatTimer creates a new time object, with a callback to +// disconnect from ACS on inactivity (i.e., after timeout + jitter). +func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer { + timer := time.AfterFunc(retry.AddJitter(timeout, jitter), func() { + logger.Warn("ACS Connection hasn't had any activity for too long; closing connection") + if err := client.Close(); err != nil { + logger.Warn("Error disconnecting from ACS", logger.Fields{ + field.Error: err, + }) + } + logger.Info("Disconnected from ACS") + }) + + return timer +} + +// anyMessageHandler handles any server message. Any server message means the +// connection is active and thus the heartbeat disconnect should not occur. +func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) { + return func(interface{}) { + logger.Debug("ACS activity occurred") + // Reset read deadline as there's activity on the channel. + if err := client.SetReadDeadline(time.Now().Add(wsRWTimeout)); err != nil { + logger.Warn("Unable to extend read deadline for ACS connection", logger.Fields{ + field.Error: err, + }) + } + + // Reset heartbeat timer. + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// waitForDuration waits for the specified duration of time. It returns true if the wait time has completed. +// Else, it returns false. +func waitForDuration(ctx context.Context, duration time.Duration) bool { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(duration)) + defer cancel() + <-ctx.Done() + err := ctx.Err() + return err == context.DeadlineExceeded +} + +// sendEmptyMessageOnChannel sends an empty message using a goroutine on the +// specified channel. 
+func sendEmptyMessageOnChannel(channel chan<- struct{}) { + go func() { + channel <- struct{}{} + }() +} + +func shouldReconnectWithoutBackoff(acsError error) bool { + return acsError == nil || acsError == io.EOF +} + +func isInactiveInstanceError(acsError error) bool { + return acsError != nil && strings.Contains(acsError.Error(), inactiveInstanceExceptionPrefix) +} + +func formatDockerVersion(dockerVersionValue string) string { + if dockerVersionValue != "containerd" { + return "DockerVersion: " + dockerVersionValue + } + return dockerVersionValue +} diff --git a/ecs-agent/acs/session/session.go b/ecs-agent/acs/session/session.go index 0eae5c3f1c5..37bbf68d0b9 100644 --- a/ecs-agent/acs/session/session.go +++ b/ecs-agent/acs/session/session.go @@ -11,13 +11,341 @@ // express or implied. See the License for the specific language governing // permissions and limitations under the License. +// Package session deals with appropriately reacting to all ACS messages as well +// as maintaining the connection to ACS. package session import ( + "context" + "io" + "net/url" + "strconv" + "strings" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/api" + rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime" "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +const ( + // heartbeatTimeout is the maximum time to wait between heartbeats + // without disconnecting. + heartbeatTimeout = 1 * time.Minute + heartbeatJitter = 1 * time.Minute + + // wsRWTimeout is the duration of read and write deadline for the + // websocket connection. 
+ wsRWTimeout = 2*heartbeatTimeout + heartbeatJitter + + inactiveInstanceReconnectDelay = 1 * time.Hour + + connectionBackoffMin = 250 * time.Millisecond + connectionBackoffMax = 2 * time.Minute + connectionBackoffJitter = 0.2 + connectionBackoffMultiplier = 1.5 + + inactiveInstanceExceptionPrefix = "InactiveInstanceException" + + // ACS protocol version spec: + // 1: default protocol version + // 2: ACS will proactively close the connection when heartbeat ACKs are missing + acsProtocolVersion = 2 ) +// Session defines an interface for Agent's long-lived connection with ACS. +// The Session.Start() method can be used to start processing messages from ACS. +type Session interface { + Start(context.Context) error +} + +// session encapsulates all arguments needed to connect to ACS and to handle messages received by ACS. +type session struct { + containerInstanceARN string + cluster string + credentialsProvider *credentials.Credentials + discoverEndpointClient api.ECSDiscoverEndpointSDK + inactiveInstanceCB func() + agentVersion string + agentHash string + dockerVersion string + payloadMessageHandler PayloadMessageHandler + credentialsManager rolecredentials.Manager + credentialsMetadataSetter CredentialsMetadataSetter + doctor *doctor.Doctor + eniHandler ENIHandler + manifestMessageIDAccessor ManifestMessageIDAccessor + taskComparer TaskComparer + sequenceNumberAccessor SequenceNumberAccessor + taskStopper TaskStopper + resourceHandler ResourceHandler + backoff retry.Backoff + sendCredentials bool + clientFactory wsclient.ClientFactory + metricsFactory metrics.EntryFactory + minAgentConfig *wsclient.WSClientMinAgentConfig + addUpdateRequestHandlers func(wsclient.ClientServer) + heartbeatTimeout time.Duration + heartbeatJitter time.Duration + disconnectTimeout time.Duration + disconnectJitter time.Duration + inactiveInstanceReconnectDelay time.Duration +} + +// NewSession creates a new Session. 
+func NewSession(containerInstanceARN string, + cluster string, + discoverEndpointClient api.ECSDiscoverEndpointSDK, + credentialsProvider *credentials.Credentials, + inactiveInstanceCB func(), + clientFactory wsclient.ClientFactory, + metricsFactory metrics.EntryFactory, + agentVersion string, + agentHash string, + dockerVersion string, + minAgentConfig *wsclient.WSClientMinAgentConfig, + payloadMessageHandler PayloadMessageHandler, + credentialsManager rolecredentials.Manager, + credentialsMetadataSetter CredentialsMetadataSetter, + doctor *doctor.Doctor, + eniHandler ENIHandler, + manifestMessageIDAccessor ManifestMessageIDAccessor, + taskComparer TaskComparer, + sequenceNumberAccessor SequenceNumberAccessor, + taskStopper TaskStopper, + resourceHandler ResourceHandler, + addUpdateRequestHandlers func(wsclient.ClientServer), +) Session { + backoff := retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier) + return &session{ + containerInstanceARN: containerInstanceARN, + cluster: cluster, + discoverEndpointClient: discoverEndpointClient, + credentialsProvider: credentialsProvider, + inactiveInstanceCB: inactiveInstanceCB, + clientFactory: clientFactory, + metricsFactory: metricsFactory, + agentVersion: agentVersion, + agentHash: agentHash, + dockerVersion: dockerVersion, + minAgentConfig: minAgentConfig, + payloadMessageHandler: payloadMessageHandler, + credentialsManager: credentialsManager, + credentialsMetadataSetter: credentialsMetadataSetter, + doctor: doctor, + eniHandler: eniHandler, + manifestMessageIDAccessor: manifestMessageIDAccessor, + taskComparer: taskComparer, + sequenceNumberAccessor: sequenceNumberAccessor, + taskStopper: taskStopper, + resourceHandler: resourceHandler, + addUpdateRequestHandlers: addUpdateRequestHandlers, + backoff: backoff, + sendCredentials: true, + heartbeatTimeout: heartbeatTimeout, + heartbeatJitter: heartbeatJitter, + disconnectTimeout: 
wsclient.DisconnectTimeout, + disconnectJitter: wsclient.DisconnectJitterMax, + inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + } +} + +// Start starts the session. It'll forever keep trying to connect to ACS unless +// the context is closed. +// +// If the context is closed, Start() would return with the error code returned +// by the context. +func (s *session) Start(ctx context.Context) error { + // connectToACS channel is used to indicate the intent to connect to ACS + // It's processed by the select loop to connect to ACS. + connectToACS := make(chan struct{}) + + // The below is required to trigger the first connection to ACS. + sendEmptyMessageOnChannel(connectToACS) + + // Loop continuously until context is closed/canceled. + for { + select { + case <-connectToACS: + logger.Debug("Received connect to ACS message. Attempting connect to ACS") + + // Start a session with ACS. + acsError := s.startSessionOnce(ctx) + + // Session with ACS was stopped with some error, start processing the error. + reconnectDelay, ok := s.reconnectDelay(acsError) + + if ok { + logger.Info("Waiting before reconnecting to ACS", logger.Fields{ + "reconnectDelay": reconnectDelay.String(), + }) + waitComplete := waitForDuration(ctx, reconnectDelay) + if waitComplete { + // If the context was not canceled and we've waited for the + // wait duration without any errors, send the message to the channel + // to reconnect to ACS. + logger.Info("Done waiting; reconnecting to ACS") + sendEmptyMessageOnChannel(connectToACS) + } else { + // Wait was interrupted. We expect the session to close as canceling + // the session context is the only way to end up here. Print a message + // to indicate the same. + logger.Info("Interrupted waiting for reconnect delay to elapse; Expect session to close") + } + } else { + // No need to delay reconnect - reconnect immediately. 
+ logger.Info("Reconnecting to ACS immediately without waiting") + sendEmptyMessageOnChannel(connectToACS) + } + case <-ctx.Done(): + logger.Info("ACS session ended (context closed)", logger.Fields{ + field.Reason: ctx.Err(), + }) + return ctx.Err() + } + } +} + +// startSessionOnce creates a session with ACS and handles requests using the passed +// in arguments. +func (s *session) startSessionOnce(ctx context.Context) error { + acsEndpoint, err := s.discoverEndpointClient.DiscoverPollEndpoint(s.containerInstanceARN) + if err != nil { + logger.Error("ACS: Unable to discover poll endpoint", logger.Fields{ + field.Error: err, + }) + return err + } + + client := s.clientFactory.New( + s.acsURL(acsEndpoint), + s.credentialsProvider, + wsRWTimeout, + s.minAgentConfig, + s.metricsFactory) + defer client.Close() + + // Invoke Connect method as soon as we create client. This will ensure all the + // request handlers to be associated with this client have a valid connection. + disconnectTimer, err := client.Connect(metrics.ACSDisconnectTimeoutMetricName, s.disconnectTimeout, + s.disconnectJitter) + if err != nil { + logger.Error("Failed to connect to ACS", logger.Fields{ + field.Error: err, + }) + return err + } + defer disconnectTimer.Stop() + + // Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence + // and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS. + logger.Info("Connected to ACS endpoint") + s.sendCredentials = false + + return s.startACSSession(ctx, client) +} + +// startACSSession starts a session with ACS. It adds request handlers for various +// kinds of messages expected from ACS. It returns on server disconnection or when +// the context is canceled. 
+func (s *session) startACSSession(ctx context.Context, client wsclient.ClientServer) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + responseSender := func(response interface{}) error { + return client.MakeRequest(response) + } + responders := []wsclient.RequestResponder{ + NewPayloadResponder(s.payloadMessageHandler, responseSender), + NewRefreshCredentialsResponder(s.credentialsManager, s.credentialsMetadataSetter, s.metricsFactory, + responseSender), + NewAttachTaskENIResponder(s.eniHandler, responseSender), + NewAttachInstanceENIResponder(s.eniHandler, responseSender), + NewHeartbeatResponder(s.doctor, responseSender), + NewTaskManifestResponder(s.taskComparer, s.sequenceNumberAccessor, s.manifestMessageIDAccessor, + s.metricsFactory, responseSender), + NewTaskStopVerificationACKResponder(s.taskStopper, s.manifestMessageIDAccessor, s.metricsFactory), + } + for _, r := range responders { + client.AddRequestHandler(r.HandlerFunc()) + } + + if s.dockerVersion == "containerd" && s.resourceHandler != nil { + client.AddRequestHandler(NewAttachResourceResponder(s.resourceHandler, s.metricsFactory, + responseSender).HandlerFunc()) + } + + if s.dockerVersion != "containerd" && s.addUpdateRequestHandlers != nil { + s.addUpdateRequestHandlers(client) + } + + // Start a heartbeat timer for closing the connection. + heartbeatTimer := newHeartbeatTimer(client, s.heartbeatTimeout, s.heartbeatJitter) + // Any message from the server resets the heartbeat timer. + client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client)) + defer heartbeatTimer.Stop() + + backoffResetTimer := time.AfterFunc( + retry.AddJitter(s.heartbeatTimeout, s.heartbeatJitter), func() { + // If we do not have an error connecting and remain connected for at + // least 1 or so minutes, reset the backoff. This prevents disconnect + // errors that only happen infrequently from damaging the reconnect + // delay as significantly. 
+ s.backoff.Reset() + }) + defer backoffResetTimer.Stop() + + return client.Serve(ctx) +} + +func (s *session) reconnectDelay(acsError error) (time.Duration, bool) { + if isInactiveInstanceError(acsError) { + logger.Info("Container instance is deregistered") + s.inactiveInstanceCB() + return s.inactiveInstanceReconnectDelay, true + } + if shouldReconnectWithoutBackoff(acsError) { + // ACS has closed the connection for valid reasons. Example: periodic disconnect. + // No need to wait/backoff to reconnect. + logger.Info("ACS WebSocket connection closed for a valid reason") + s.backoff.Reset() + return 0, false + + } + // Disconnected unexpectedly from ACS, compute backoff duration to reconnect. + return s.backoff.Duration(), true +} + +// acsURL returns the websocket url for ACS given the endpoint. +func (s *session) acsURL(endpoint string) string { + wsURL := endpoint + if endpoint[len(endpoint)-1] != '/' { + wsURL += "/" + } + wsURL += "ws" + query := url.Values{} + query.Set("clusterArn", s.cluster) + query.Set("containerInstanceArn", s.containerInstanceARN) + query.Set("agentHash", s.agentHash) + query.Set("agentVersion", s.agentVersion) + query.Set("seqNum", "1") + query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion)) + if s.dockerVersion != "" { + query.Set("dockerVersion", formatDockerVersion(s.dockerVersion)) + } + // Below indicates if ACS should send credentials for all tasks upon establishing the connection. + query.Set("sendCredentials", strconv.FormatBool(s.sendCredentials)) + return wsURL + "?" + query.Encode() +} + // ResponseToACSSender returns a wsclient.RespondFunc that a responder can invoke in response to receiving and // processing specific websocket request messages from ACS. The returned wsclient.RespondFunc: // 1. 
logs the response to be sent, as well as the name of the invoking responder @@ -31,3 +359,69 @@ func ResponseToACSSender(responderName string, responseSender wsclient.RespondFu return responseSender(response) } } + +// newHeartbeatTimer creates a new time object, with a callback to +// disconnect from ACS on inactivity (i.e., after timeout + jitter). +func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer { + timer := time.AfterFunc(retry.AddJitter(timeout, jitter), func() { + logger.Warn("ACS Connection hasn't had any activity for too long; closing connection") + if err := client.Close(); err != nil { + logger.Warn("Error disconnecting from ACS", logger.Fields{ + field.Error: err, + }) + } + logger.Info("Disconnected from ACS") + }) + + return timer +} + +// anyMessageHandler handles any server message. Any server message means the +// connection is active and thus the heartbeat disconnect should not occur. +func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) { + return func(interface{}) { + logger.Debug("ACS activity occurred") + // Reset read deadline as there's activity on the channel. + if err := client.SetReadDeadline(time.Now().Add(wsRWTimeout)); err != nil { + logger.Warn("Unable to extend read deadline for ACS connection", logger.Fields{ + field.Error: err, + }) + } + + // Reset heartbeat timer. + timer.Reset(retry.AddJitter(heartbeatTimeout, heartbeatJitter)) + } +} + +// waitForDuration waits for the specified duration of time. It returns true if the wait time has completed. +// Else, it returns false. +func waitForDuration(ctx context.Context, duration time.Duration) bool { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(duration)) + defer cancel() + <-ctx.Done() + err := ctx.Err() + return err == context.DeadlineExceeded +} + +// sendEmptyMessageOnChannel sends an empty message using a goroutine on the +// specified channel. 
+func sendEmptyMessageOnChannel(channel chan<- struct{}) { + go func() { + channel <- struct{}{} + }() +} + +func shouldReconnectWithoutBackoff(acsError error) bool { + return acsError == nil || acsError == io.EOF +} + +func isInactiveInstanceError(acsError error) bool { + return acsError != nil && strings.Contains(acsError.Error(), inactiveInstanceExceptionPrefix) +} + +func formatDockerVersion(dockerVersionValue string) string { + if dockerVersionValue != "containerd" { + return "DockerVersion: " + dockerVersionValue + } + return dockerVersionValue +} diff --git a/ecs-agent/acs/session/session_test.go b/ecs-agent/acs/session/session_test.go new file mode 100644 index 00000000000..aa344a643a5 --- /dev/null +++ b/ecs-agent/acs/session/session_test.go @@ -0,0 +1,1329 @@ +//go:build unit +// +build unit + +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. 
+package session + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "runtime" + "runtime/pprof" + "strconv" + "sync" + "testing" + "time" + + acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" + mock_session "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" + mock_api "github.com/aws/amazon-ecs-agent/ecs-agent/api/mocks" + rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" + metricsfactory "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + mock_retry "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/golang/mock/gomock" + "github.com/gorilla/websocket" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +const ( + samplePayloadMessage = ` +{ + "type": "PayloadMessage", + "message": { + "messageId": "123", + "tasks": [ + { + "taskDefinitionAccountId": "123", + "containers": [ + { + "environment": {}, + "name": "name", + "cpu": 1, + "essential": true, + "memory": 1, + "portMappings": [], + "overrides": "{}", + "image": "i", + "mountPoints": [], + "volumesFrom": [] + } + ], + "elasticNetworkInterfaces":[{ + "attachmentArn": "eni_attach_arn", + "ec2Id": "eni_id", + "ipv4Addresses":[{ + "primary": true, + "privateAddress": "ipv4" + }], + "ipv6Addresses": [{ + "address": "ipv6" + }], + "subnetGatewayIpv4Address": "ipv4/20", + "macAddress": "mac" + }], + "roleCredentials": { + "credentialsId": "credsId", + "accessKeyId": "accessKeyId", + "expiration": 
"2016-03-25T06:17:19.318+0000", + "roleArn": "r1", + "secretAccessKey": "secretAccessKey", + "sessionToken": "token" + }, + "version": "3", + "volumes": [], + "family": "f", + "arn": "arn", + "desiredStatus": "RUNNING" + } + ], + "generatedAt": 1, + "clusterArn": "1", + "containerInstanceArn": "1", + "seqNum": 1 + } +} +` + sampleRefreshCredentialsMessage = ` +{ + "type": "IAMRoleCredentialsMessage", + "message": { + "messageId": "123", + "clusterArn": "default", + "taskArn": "t1", + "roleType": "TaskApplication", + "roleCredentials": { + "credentialsId": "credsId", + "accessKeyId": "newakid", + "expiration": "later", + "roleArn": "r1", + "secretAccessKey": "newskid", + "sessionToken": "newstkn" + } + } +} +` + sampleAttachResourceMessage = ` +{ + "type": "ConfirmAttachmentMessage", + "message": { + "messageId": "123", + "clusterArn": "arn:aws:ecs:us-west-2:123456789012:cluster/a1b2c3d4-5678-90ab-cdef-11111EXAMPLE", + "containerInstanceArn": "arn:aws:ecs:us-west-2:123456789012:container-instance/a1b2c3d4-5678-90ab-cdef-11111EXAMPLE", + "taskArn": "arn:aws:ecs:us-west-2:1234567890:task/test-cluster/abc", + "waitTimeoutMs": 1000, + "attachment": { + "attachmentArn": "arn:aws:ecs:us-west-2:123456789012:ephemeral-storage/a1b2c3d4-5678-90ab-cdef-11111EXAMPLE", + "attachmentProperties": [ + { + "name": "resourceID", + "value": "id1" + }, + { + "name": "volumeID", + "value": "id1" + }, + { + "name": "volumeSizeInGiB", + "value": "size1" + }, + { + "name": "requestedSizeInGiB", + "value": "size1" + }, + { + "name": "resourceType", + "value": "EphemeralStorage" + }, + { + "name": "deviceName", + "value": "device1" + } + ] + } + } +} +` + agentVersion = "1.23.4" + agentGitShortHash = "ffffffff" + dockerVersion = "1.2.3" + acsURL = "http://endpoint.tld" +) + +var inactiveInstanceError = errors.New("InactiveInstanceException") +var noopFunc = func() {} +var testCreds = credentials.NewStaticCredentials("test-id", "test-secret", "test-token") +var testMinAgentConfig = 
&wsclient.WSClientMinAgentConfig{ + AcceptInsecureCert: true, + AWSRegion: "us-west-2", + DockerEndpoint: "unix:///var/run/docker.sock", + IsDocker: true, +} + +// TestACSURL tests that the URL is constructed correctly when connecting to ACS. +func TestACSURL(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + acsSession := session{ + sendCredentials: true, + containerInstanceARN: testconst.ContainerInstanceARN, + cluster: testconst.ClusterName, + agentVersion: agentVersion, + agentHash: agentGitShortHash, + dockerVersion: dockerVersion, + } + wsurl := acsSession.acsURL(acsURL) + + parsed, err := url.Parse(wsurl) + assert.NoError(t, err, "should be able to parse URL") + assert.Equal(t, "/ws", parsed.Path, "wrong path") + assert.Equal(t, testconst.ClusterName, parsed.Query().Get("clusterArn"), "wrong cluster") + assert.Equal(t, testconst.ContainerInstanceARN, parsed.Query().Get("containerInstanceArn"), + "wrong container instance") + assert.Equal(t, agentVersion, parsed.Query().Get("agentVersion"), "wrong agent version") + assert.Equal(t, agentGitShortHash, parsed.Query().Get("agentHash"), "wrong agent hash") + assert.Equal(t, "DockerVersion: "+dockerVersion, parsed.Query().Get("dockerVersion"), "wrong docker version") + assert.Equalf(t, "true", parsed.Query().Get("sendCredentials"), + "Wrong value set for: sendCredentials") + assert.Equal(t, "1", parsed.Query().Get("seqNum"), "wrong seqNum") + protocolVersion, _ := strconv.Atoi(parsed.Query().Get("protocolVersion")) + assert.True(t, protocolVersion > 1, "ACS protocol version should be greater than 1") +} + +// TestSessionReconnectsOnConnectErrors tests that Session retries reconnecting +// to establish the session with ACS when ClientServer.Connect() returns errors. 
+func TestSessionReconnectsOnConnectErrors(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + gomock.InOrder( + // Connect fails 10 times. + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, io.EOF).Times(10), + // Cancel trying to connect to ACS on the 11th attempt. + // Failure to retry on Connect() errors should cause the test to time out as the context is never canceled. 
+ mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + cancel() + }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).MinTimes(1), + ) + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestIsInactiveInstanceErrorReturnsTrueForInactiveInstance tests that 'InactiveInstance' +// exception is identified correctly. +func TestIsInactiveInstanceErrorReturnsTrueForInactiveInstance(t *testing.T) { + assert.True(t, isInactiveInstanceError(inactiveInstanceError), + "inactive instance exception message parsed incorrectly") +} + +// TestIsInactiveInstanceErrorReturnsFalseForActiveInstance tests that non-'InactiveInstance' +// exceptions are identified correctly. +func TestIsInactiveInstanceErrorReturnsFalseForActiveInstance(t *testing.T) { + assert.False(t, isInactiveInstanceError(io.EOF), + "inactive instance exception message parsed incorrectly") +} + +// TestReconnectDelayForInactiveInstance tests that the reconnect delay is computed +// correctly for an inactive instance. 
+func TestReconnectDelayForInactiveInstance(t *testing.T) { + acsSession := session{ + inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + inactiveInstanceCB: noopFunc, + } + delay, ok := acsSession.reconnectDelay(inactiveInstanceError) + assert.True(t, ok, "Delaying reconnect should be OK for inactive instance") + assert.Equal(t, inactiveInstanceReconnectDelay, delay, + "Reconnect delay doesn't match expected value for inactive instance") +} + +// TestReconnectDelayForActiveInstanceOnUnexpectedDisconnect tests that the reconnect delay is computed +// correctly for an active instance. +func TestReconnectDelayForActiveInstanceOnUnexpectedDisconnect(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackoff := mock_retry.NewMockBackoff(ctrl) + mockBackoff.EXPECT().Duration().Return(connectionBackoffMax) + + acsSession := session{backoff: mockBackoff} + delay, ok := acsSession.reconnectDelay(fmt.Errorf("unexpcted disconnect error")) + + assert.True(t, ok, "Delaying reconnect should be OK for active instance on unexpected disconnect") + assert.Equal(t, connectionBackoffMax, delay, + "Reconnect delay doesn't match expected value for active instance on unexpected disconnect") +} + +// TestWaitForDurationReturnsTrueWhenContextNotCanceled tests that the waitForDuration function behaves correctly when +// the passed in context is not canceled. +func TestWaitForDurationReturnsTrueWhenContextNotCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + assert.True(t, waitForDuration(ctx, time.Millisecond), + "waitForDuration should return true when uninterrupted") +} + +// TestWaitForDurationReturnsFalseWhenContextCanceled tests that the waitForDuration function behaves correctly when +// the passed in context is canceled. 
+func TestWaitForDurationReturnsFalseWhenContextCanceled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + assert.False(t, waitForDuration(ctx, time.Millisecond), + "waitForDuration should return false when interrupted") +} + +func TestShouldReconnectWithoutBackoffReturnsTrueForEOF(t *testing.T) { + assert.True(t, shouldReconnectWithoutBackoff(io.EOF), + "Reconnect without backoff should return true when connection is closed") +} + +func TestShouldReconnectWithoutBackoffReturnsFalseForNonEOF(t *testing.T) { + assert.False(t, shouldReconnectWithoutBackoff(fmt.Errorf("not EOF")), + "Reconnect without backoff should return false for non io.EOF error") +} + +// TestSessionReconnectsWithoutBackoffOnEOFError tests that the Session reconnects +// to ACS without any delay when the connection is closed with the io.EOF error. +func TestSessionReconnectsWithoutBackoffOnEOFError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + + mockBackoff := mock_retry.NewMockBackoff(ctrl) + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + gomock.InOrder( + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, io.EOF), + // The backoff.Reset() method is expected to be invoked when the connection is closed with io.EOF. + mockBackoff.EXPECT().Reset(), + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + // Cancel the context on the 2nd connect attempt, which should stop the test. + cancel() + }).Return(nil, io.EOF), + mockBackoff.EXPECT().Reset().AnyTimes(), + ) + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + backoff: mockBackoff, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + } + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestSessionReconnectsWithBackoffOnNonEOFError tests that the Session reconnects +// to ACS after a backoff duration when the connection is closed with a non io.EOF error. 
+func TestSessionReconnectsWithBackoffOnNonEOFError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + + mockBackoff := mock_retry.NewMockBackoff(ctrl) + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + gomock.InOrder( + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, + fmt.Errorf("not EOF")), + // The backoff.Duration() method is expected to be invoked when the connection is closed with a non-EOF error + // code to compute the backoff. Also, no calls to backoff.Reset() are expected in this code path. 
+ mockBackoff.EXPECT().Duration().Return(time.Millisecond), + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + cancel() + }).Return(nil, io.EOF), + mockBackoff.EXPECT().Reset().AnyTimes(), + ) + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + backoff: mockBackoff, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + } + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestSessionCallsInactiveInstanceCB tests that the Session calls its inactiveInstanceCB func (which in this test +// generates an event into a deregister instance event stream) when the ACS connection is closed +// with inactive instance error. +func TestSessionCallsInactiveInstanceCB(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + + deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) + + // receiverFunc cancels the context when invoked. + // Any event on the deregister instance event stream would trigger this. 
+ receiverFunc := func(...interface{}) error { + cancel() + return nil + } + err := deregisterInstanceEventStream.Subscribe("DeregisterContainerInstance", receiverFunc) + assert.NoError(t, err, "Error adding deregister instance event stream subscriber") + deregisterInstanceEventStream.StartListening() + inactiveInstanceCB := func() { + err := deregisterInstanceEventStream.WriteToEventStream(struct{}{}) + assert.NoError(t, err, "Error writing to deregister container instance event stream") + } + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, inactiveInstanceError) + inactiveInstanceReconnectDelay := 200 * time.Millisecond + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: inactiveInstanceCB, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + + err = acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestSessionReconnectDelayForInactiveInstanceError tests that the Session applies the proper reconnect delay with ACS +// when 
ClientServer.Connect() returns the InstanceInactive error. +func TestSessionReconnectDelayForInactiveInstanceError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + var firstConnectionAttemptTime time.Time + inactiveInstanceReconnectDelay := 200 * time.Millisecond + gomock.InOrder( + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + firstConnectionAttemptTime = time.Now() + }).Return(nil, inactiveInstanceError), + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + reconnectDelay := time.Now().Sub(firstConnectionAttemptTime) + reconnectDelayTime := time.Now() + t.Logf("Delay between successive connections: %v", reconnectDelay) + timeSubFuncSlopAllowed := 2 * time.Millisecond + if reconnectDelay < inactiveInstanceReconnectDelay { + // On windows platform, we found issue with time.Now().Sub(...) reporting 199.9989 even + // after the code has already waited for time.NewTimer(200)ms. 
+ assert.WithinDuration(t, reconnectDelayTime, + firstConnectionAttemptTime.Add(inactiveInstanceReconnectDelay), timeSubFuncSlopAllowed) + } + cancel() + }).Return(nil, io.EOF), + ) + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + inactiveInstanceReconnectDelay: inactiveInstanceReconnectDelay, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestSessionReconnectsOnServeErrors tests that the Session retries to establish the connection with ACS when +// ClientServer.Serve() returns errors. +func TestSessionReconnectsOnServeErrors(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + gomock.InOrder( + // Serve fails 10 times. + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).Times(10), + // Cancel trying to Serve ACS requests on the 11th attempt. + // Failure to retry on Serve() errors should cause the test to time out as the context is never canceled. + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { + cancel() + }), + ) + + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestSessionStopsWhenContextIsCanceled tests that the Session's Start() method returns +// when its context is canceled. +func TestSessionStopsWhenContextIsCanceled(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()). + Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + gomock.InOrder( + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF), + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { + cancel() + }).Return(inactiveInstanceError), + ) + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestSessionStopsWhenContextIsErrorDueToTimeout tests that Session's Start() method returns +// when its context is in error due to timeout on reconnect delay. +func TestSessionStopsWhenContextIsErrorDueToTimeout(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) + defer cancel() + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). 
+ New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()). + Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).Return(inactiveInstanceError).AnyTimes() + + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + inactiveInstanceReconnectDelay: 1 * time.Hour, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + + err := acsSession.Start(ctx) + assert.Equal(t, context.DeadlineExceeded, err) +} + +// TestSessionReconnectsOnDiscoverPollEndpointError tests that the Session retries to establish the connection with ACS +// on DiscoverPollEndpoint errors. +func TestSessionReconnectsOnDiscoverPollEndpointError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + // Connect method being called means preceding DiscoverPollEndpoint did not return an error (i.e., was successful). + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + // Serve() cancels the context. + cancel() + }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).MinTimes(1) + + gomock.InOrder( + // DiscoverPollEndpoint returns an error on its first invocation. + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return("", fmt.Errorf("oops")), + // Second invocation returns a success. + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil), + ) + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + + start := time.Now() + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) + + // Measure the duration between retries. 
+ timeSinceStart := time.Since(start) + assert.GreaterOrEqual(t, timeSinceStart, connectionBackoffMin, + "Duration since start is less than minimum threshold for backoff: %s", timeSinceStart.String()) + + // The upper limit here should really be connectionBackoffMin + (connectionBackoffMin * jitter), + // but it can be off by a few milliseconds to account for execution of other instructions. + // In any case, it should never be higher than 4*connectionBackoffMin. + assert.LessOrEqual(t, timeSinceStart, 4*connectionBackoffMin, + "Duration since start is greater than maximum anticipated wait time: %v", timeSinceStart.String()) +} + +// TestConnectionIsClosedOnIdle tests that the connection to ACS is closed when the connection is idle. +func TestConnectionIsClosedOnIdle(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()). + Return(time.NewTimer(wsclient.DisconnectTimeout), nil) + connectionInactive := false + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { + // Pretend as if the maximum heartbeatTimeout duration has been breached while Serving requests. 
+ time.Sleep(30 * time.Millisecond) + connectionInactive = true + }).Return(io.EOF) + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).MinTimes(1) + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + // Wait for connection to be closed. If the connection is not closed due to inactivity, the test will time out. + err := acsSession.startSessionOnce(ctx) + assert.EqualError(t, err, io.EOF.Error()) + assert.True(t, connectionInactive) +} + +func TestSessionDoesntLeakGoroutines(t *testing.T) { + // Skip this test on "windows" platform as we have observed this to + // fail often after upgrading the windows builds to golang v1.17. + if runtime.GOOS == "windows" { + t.Skip() + } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + payloadMessageHandler := mock_session.NewMockPayloadMessageHandler(ctrl) + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ctx, cancel := context.WithCancel(context.Background()) + + closeWS := make(chan bool) + fakeServer, serverIn, requests, errs, err := startFakeACSServer(closeWS) + if err != nil { + t.Fatal(err) + } + go func() { + for { + select { + case <-requests: + case <-errs: + case <-ctx.Done(): + return + } + } + }() + + timesConnected := 0 + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil). 
+ AnyTimes().Do(func(_ interface{}) { + timesConnected++ + }) + payloadMessageHandler.EXPECT().ProcessMessage(gomock.Any(), gomock.Any()).AnyTimes(). + Do(func(interface{}, interface{}) { + go func() { + time.Sleep(5 * time.Millisecond) // do some work + }() + }) + + emptyDoctor, _ := doctor.NewDoctor([]doctor.Healthcheck{}, testconst.ClusterName, testconst.ContainerInstanceARN) + + ended := make(chan bool, 1) + go func() { + acsSession := session{ + containerInstanceARN: testconst.ContainerInstanceARN, + credentialsProvider: testCreds, + dockerVersion: dockerVersion, + minAgentConfig: testMinAgentConfig, + discoverEndpointClient: discoverEndpointClient, + inactiveInstanceCB: noopFunc, + clientFactory: acsclient.NewACSClientFactory(), + metricsFactory: metricsfactory.NewNopEntryFactory(), + payloadMessageHandler: payloadMessageHandler, + heartbeatTimeout: 1 * time.Second, + doctor: emptyDoctor, + backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, + connectionBackoffJitter, connectionBackoffMultiplier), + } + acsSession.Start(ctx) + ended <- true + }() + // Warm it up. 
+ serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}` + serverIn <- samplePayloadMessage + + beforeGoroutines := runtime.NumGoroutine() + for i := 0; i < 40; i++ { + serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}` + serverIn <- samplePayloadMessage + closeWS <- true + } + + cancel() + <-ended + + afterGoroutines := runtime.NumGoroutine() + + t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines) + + if timesConnected < 20 { + t.Fatal("Expected times connected to be a large number, was ", timesConnected) + } + if afterGoroutines > beforeGoroutines+2 { + t.Error("Goroutine leak, oh no!") + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + } +} + +// TestStartSessionHandlesRefreshCredentialsMessages tests the scenario where a refresh credentials message is +// processed immediately on connection establishment with ACS. +func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + credentialsMetadataSetter := mock_session.NewMockCredentialsMetadataSetter(ctrl) + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ctx, cancel := context.WithCancel(context.Background()) + closeWS := make(chan bool) + fakeServer, serverIn, requestsChan, errChan, err := startFakeACSServer(closeWS) + if err != nil { + t.Fatal(err) + } + defer close(serverIn) + + go func() { + for { + select { + case <-requestsChan: + // Cancel the context when we get the ACK request. + cancel() + } + } + }() + + // DiscoverPollEndpoint returns the URL for the server that we started. 
+ discoverEndpointClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil) + + credentialsManager := mock_credentials.NewMockManager(ctrl) + + ended := make(chan bool, 1) + go func() { + acsSession := NewSession(testconst.ContainerInstanceARN, + testconst.ClusterName, + discoverEndpointClient, + testCreds, + noopFunc, + acsclient.NewACSClientFactory(), + metricsfactory.NewNopEntryFactory(), + agentVersion, + agentGitShortHash, + dockerVersion, + testMinAgentConfig, + nil, + credentialsManager, + credentialsMetadataSetter, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + acsSession.Start(ctx) + // Start should never return unless the context is canceled. + ended <- true + }() + + updatedCredentials := rolecredentials.TaskIAMRoleCredentials{} + credentialsIdInRefreshMessage := "credsId" + // Ensure that credentials manager interface methods are invoked in the correct order, with expected arguments. + gomock.InOrder( + // The last invocation of SetCredentials is to update + // credentials when a refresh message is received by the handler + credentialsManager.EXPECT().SetTaskCredentials(gomock.Any()). 
+ Do(func(creds *rolecredentials.TaskIAMRoleCredentials) { + updatedCredentials = *creds + // Validate parsed credentials after the update + expectedCreds := rolecredentials.TaskIAMRoleCredentials{ + ARN: "t1", + IAMRoleCredentials: rolecredentials.IAMRoleCredentials{ + RoleArn: "r1", + AccessKeyID: "newakid", + SecretAccessKey: "newskid", + SessionToken: "newstkn", + Expiration: "later", + CredentialsID: credentialsIdInRefreshMessage, + RoleType: rolecredentials.ApplicationRoleType, + }, + } + assert.Equal(t, expectedCreds, updatedCredentials, "Mismatch between expected and updated credentials") + }).Return(nil), + credentialsMetadataSetter.EXPECT().SetTaskRoleCredentialsMetadata(gomock.Any()).Return(nil), + ) + serverIn <- sampleRefreshCredentialsMessage + + select { + case err := <-errChan: + t.Fatal("Error should not have been returned from server", err) + case <-ctx.Done(): + // Context is canceled when requestsChan receives an ACK. + } + + fakeServer.Close() + // Cancel context should close the session. + <-ended +} + +// TestSessionCorrectlySetsSendCredentials tests that the Session's 'sendCredentials' field +// is set correctly for successive invocations of startSessionOnce. +func TestSessionCorrectlySetsSendCredentials(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + const numInvocations = 10 + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes() + + acsSession := NewSession(testconst.ContainerInstanceARN, + testconst.ClusterName, + discoverEndpointClient, + nil, + noopFunc, + mockClientFactory, + metricsfactory.NewNopEntryFactory(), + agentVersion, + agentGitShortHash, + dockerVersion, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + acsSession.(*session).heartbeatTimeout = 20 * time.Millisecond + acsSession.(*session).heartbeatJitter = 10 * time.Millisecond + acsSession.(*session).disconnectTimeout = 30 * time.Millisecond + acsSession.(*session).disconnectJitter = 10 * time.Millisecond + gomock.InOrder( + // When the websocket client connects to ACS for the first time, 'sendCredentials' should be set to true. + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + assert.Equal(t, true, acsSession.(*session).sendCredentials) + }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil), + // For all subsequent connections to ACS, 'sendCredentials' should be set to false. + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + assert.Equal(t, false, acsSession.(*session).sendCredentials) + }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).Times(numInvocations-1), + ) + + go func() { + for i := 0; i < numInvocations; i++ { + acsSession.(*session).startSessionOnce(ctx) + } + cancel() + }() + + // Wait for context to be canceled. 
+ select { + case <-ctx.Done(): + } +} + +// TestSessionReconnectCorrectlySetsAcsUrl tests that the ACS URL is set correctly for the Session's initial connection +// and subsequent connections with ACS. +func TestSessionReconnectCorrectlySetsAcsUrl(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ctx, cancel := context.WithCancel(context.Background()) + + mockBackoff := mock_retry.NewMockBackoff(ctrl) + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockBackoff.EXPECT().Reset().AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes() + + // On the initial connection, sendCredentials must be true because Agent forces ACS to send credentials. + initialAcsURL := fmt.Sprintf( + "http://endpoint.tld/ws?agentHash=%s&agentVersion=%s&clusterArn=%s&containerInstanceArn=%s&"+ + "dockerVersion=DockerVersion%%3A+%s&protocolVersion=%v&sendCredentials=true&seqNum=1", + agentGitShortHash, agentVersion, url.QueryEscape(testconst.ClusterName), + url.QueryEscape(testconst.ContainerInstanceARN), dockerVersion, acsProtocolVersion) + + // But after that, ACS sends credentials at ACS's own cadence, so sendCredentials must be false. 
+ subsequentAcsURL := fmt.Sprintf( + "http://endpoint.tld/ws?agentHash=%s&agentVersion=%s&clusterArn=%s&containerInstanceArn=%s&"+ + "dockerVersion=DockerVersion%%3A+%s&protocolVersion=%v&sendCredentials=false&seqNum=1", + agentGitShortHash, agentVersion, url.QueryEscape(testconst.ClusterName), + url.QueryEscape(testconst.ContainerInstanceARN), dockerVersion, acsProtocolVersion) + + gomock.InOrder( + mockClientFactory.EXPECT(). + New(initialAcsURL, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient), + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()). + Return(time.NewTimer(wsclient.DisconnectTimeout), nil), + mockClientFactory.EXPECT(). + New(subsequentAcsURL, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient), + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(interface{}, + interface{}, interface{}) { + cancel() + }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil), + ) + acsSession := NewSession(testconst.ContainerInstanceARN, + testconst.ClusterName, + discoverEndpointClient, + nil, + noopFunc, + mockClientFactory, + metricsfactory.NewNopEntryFactory(), + agentVersion, + agentGitShortHash, + dockerVersion, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + acsSession.(*session).backoff = mockBackoff + acsSession.(*session).heartbeatTimeout = 20 * time.Millisecond + acsSession.(*session).heartbeatJitter = 10 * time.Millisecond + acsSession.(*session).disconnectTimeout = 30 * time.Millisecond + acsSession.(*session).disconnectJitter = 10 * time.Millisecond + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) +} + +// TestStartSessionHandlesAttachResourceMessages tests that the Session is able to handle attach +// resource messages when the Session's resourceHandler is not nil and its dockerVersion is "containerd". 
+func TestStartSessionHandlesAttachResourceMessages(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + resourceHandler := mock_session.NewMockResourceHandler(ctrl) + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ctx, cancel := context.WithCancel(context.Background()) + closeWS := make(chan bool) + fakeServer, serverIn, requestsChan, errChan, err := startFakeACSServer(closeWS) + if err != nil { + t.Fatal(err) + } + defer close(serverIn) + + go func() { + for { + select { + case <-requestsChan: + // Cancel the context when we get the ACK request. + cancel() + } + } + }() + + // DiscoverPollEndpoint returns the URL for the server that we started. + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil) + + ended := make(chan bool, 1) + go func() { + acsSession := NewSession(testconst.ContainerInstanceARN, + testconst.ClusterName, + discoverEndpointClient, + testCreds, + noopFunc, + acsclient.NewACSClientFactory(), + metricsfactory.NewNopEntryFactory(), + agentVersion, + agentGitShortHash, + "containerd", + testMinAgentConfig, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + resourceHandler, + nil, + ) + + acsSession.Start(ctx) + // Start should never return unless the context is canceled. + ended <- true + }() + + // WaitGroup is necessary to wait for HandleResourceAttachment to be called in separate goroutine before exiting + // the test. + wg := sync.WaitGroup{} + wg.Add(1) + resourceHandler.EXPECT().HandleResourceAttachment(gomock.Any()).Do(func(arg0 interface{}) { + defer wg.Done() // decrement WaitGroup counter now that HandleResourceAttachment function has been called + }) + + serverIn <- sampleAttachResourceMessage + + select { + case err := <-errChan: + t.Fatal("Error should not have been returned from server", err) + case <-ctx.Done(): + // Context is canceled when requestsChan receives an ACK. 
+ } + + wg.Wait() + + fakeServer.Close() + // Cancel context should close the session. + <-ended +} + +// TestSessionCallsAddUpdateRequestHandlers tests that the Session calls the function contained in its struct field +// 'addUpdateRequestHandlers' is called if it is not nil and the Session's dockerVersion is not "containerd". +func TestSessionCallsAddUpdateRequestHandlers(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + addUpdateRequestHandlersCalled := false + addUpdateRequestHandlers := func(cs wsclient.ClientServer) { + addUpdateRequestHandlersCalled = true + } + + discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ctx, cancel := context.WithCancel(context.Background()) + + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().Connect(gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(time.NewTimer(wsclient.DisconnectTimeout), nil).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { + if addUpdateRequestHandlersCalled { + cancel() + } + }) + mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + + acsSession := NewSession(testconst.ContainerInstanceARN, + testconst.ClusterName, + discoverEndpointClient, + nil, + noopFunc, + mockClientFactory, + metricsfactory.NewNopEntryFactory(), + agentVersion, + agentGitShortHash, + dockerVersion, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + addUpdateRequestHandlers, + ) + + err := acsSession.Start(ctx) + assert.Equal(t, context.Canceled, err) + assert.True(t, addUpdateRequestHandlersCalled) +} + +func startFakeACSServer(closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { + serverChan := make(chan string, 1) + requestsChan := make(chan string, 1) + errChan := make(chan error, 1) + + upgrader := websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := upgrader.Upgrade(w, r, nil) + + if err != nil { + errChan <- err + } + + go func() { + _, msg, err := ws.ReadMessage() + if err != nil { + errChan <- err + } else { + requestsChan <- string(msg) + } + }() + for { + select { + case str := <-serverChan: + err := ws.WriteMessage(websocket.TextMessage, []byte(str)) + if err != nil { + errChan <- err + } + + case <-closeWS: + ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, + "")) + ws.Close() + errChan <- io.EOF + // Quit listening to serverChan if we've been closed. + return + } + + } + }) + + fakeServer := httptest.NewTLSServer(handler) + return fakeServer, serverChan, requestsChan, errChan, nil +}