From d534247caf3f9bc93fefed52626048e9f23eb1ee Mon Sep 17 00:00:00 2001 From: Dane H Lim Date: Thu, 11 May 2023 20:32:33 -0700 Subject: [PATCH] Move ACS client to ecs-agent module and refactor --- agent/acs/handler/acs_handler.go | 126 +--- agent/acs/handler/acs_handler_test.go | 292 ++++++--- .../acs/handler/attach_eni_handler_common.go | 2 +- .../handler/attach_instance_eni_handler.go | 2 +- .../attach_instance_eni_handler_test.go | 2 +- agent/acs/handler/attach_task_eni_handler.go | 2 +- .../handler/attach_task_eni_handler_test.go | 2 +- agent/acs/handler/heartbeat_handler.go | 2 +- agent/acs/handler/heartbeat_handler_test.go | 2 +- agent/acs/handler/payload_handler.go | 2 +- agent/acs/handler/payload_handler_test.go | 2 +- .../handler/refresh_credentials_handler.go | 2 +- .../refresh_credentials_handler_test.go | 2 +- agent/acs/handler/task_manifest_handler.go | 2 +- .../acs/handler/task_manifest_handler_test.go | 2 +- agent/acs/update_handler/updater.go | 2 +- agent/acs/update_handler/updater_test.go | 2 +- agent/app/agent.go | 2 + .../ecs-agent}/acs/client/acs_client.go | 24 +- .../ecs-agent}/acs/client/acs_client_types.go | 3 +- .../ecs-agent}/acs/client/acs_error.go | 2 +- .../ecs-agent/wsclient/client.go | 574 ++++++++++++++++++ .../ecs-agent/wsclient/client_factory.go | 13 + .../ecs-agent/wsclient/decode.go | 101 +++ .../ecs-agent/wsclient/error.go | 120 ++++ .../ecs-agent/wsclient/generate_mocks.go | 3 + .../ecs-agent/wsclient/mock/client.go | 319 ++++++++++ .../ecs-agent/wsclient/types.go | 59 ++ .../ecs-agent/wsclient/wsconn/conn.go | 27 + .../wsclient/wsconn/generate_mocks.go | 16 + agent/vendor/modules.txt | 4 + ecs-agent/acs/client/acs_client.go | 73 +++ .../acs/client/acs_client_test.go | 28 +- ecs-agent/acs/client/acs_client_types.go | 60 ++ ecs-agent/acs/client/acs_error.go | 52 ++ .../acs/client/acs_error_test.go | 0 36 files changed, 1694 insertions(+), 234 deletions(-) rename agent/{ => vendor/github.com/aws/amazon-ecs-agent/ecs-agent}/acs/client/acs_client.go (73%) rename agent/{ => vendor/github.com/aws/amazon-ecs-agent/ecs-agent}/acs/client/acs_client_types.go (95%) rename agent/{ => vendor/github.com/aws/amazon-ecs-agent/ecs-agent}/acs/client/acs_error.go (97%) create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client_factory.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/decode.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/error.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/generate_mocks.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock/client.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/types.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/conn.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/generate_mocks.go create mode 100644 ecs-agent/acs/client/acs_client.go rename {agent => ecs-agent}/acs/client/acs_client_test.go (94%) create mode 100644 ecs-agent/acs/client/acs_client_types.go create mode 100644 ecs-agent/acs/client/acs_error.go rename {agent => ecs-agent}/acs/client/acs_error_test.go (100%) diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index d74fdb98621..df205b6f4c2 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -24,7 +24,6 @@ import ( "sync" "time" - acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client" updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler" "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/config" @@ -35,11 +34,11 @@ import ( "github.com/aws/amazon-ecs-agent/agent/eventhandler" "github.com/aws/amazon-ecs-agent/agent/eventstream" "github.com/aws/amazon-ecs-agent/agent/version" - "github.com/aws/amazon-ecs-agent/agent/wsclient" rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/cihub/seelog" @@ -105,7 +104,8 @@ type session struct { ctx context.Context cancel context.CancelFunc backoff retry.Backoff - resources sessionResources + clientFactory wsclient.ClientFactory + sendCredentials bool latestSeqNumTaskManifest *int64 doctor *doctor.Doctor _heartbeatTimeout time.Duration @@ -115,43 +115,6 @@ type session struct { _inactiveInstanceReconnectDelay time.Duration } -// sessionResources defines the resource creator interface for starting -// a session with ACS. This interface is intended to define methods -// that create resources used to establish the connection to ACS -// It is confined to just the createACSClient() method for now. It can be -// extended to include the acsWsURL() and newHeartbeatTimer() methods -// when needed -// The goal is to make it easier to test and inject dependencies -type sessionResources interface { - // createACSClient creates a new websocket client - createACSClient(url string, cfg *config.Config) wsclient.ClientServer - sessionState -} - -// acsSessionResources implements resource creator and session state interfaces -// to create resources needed to connect to ACS and to record session state -// for the same -type acsSessionResources struct { - credentialsProvider *credentials.Credentials - // sendCredentials is used to set the 'sendCredentials' URL parameter - // used to connect to ACS - // It is set to 'true' for the very first successful connection on - // agent start. It is set to false for all successive connections - sendCredentials bool -} - -// sessionState defines state recorder interface for the -// session established with ACS. It can be used to record and -// retrieve data shared across multiple connections to ACS -type sessionState interface { - // connectedToACS callback indicates that the client has - // connected to ACS - connectedToACS() - // getSendCredentialsURLParameter retrieves the value for - // the 'sendCredentials' URL parameter - getSendCredentialsURLParameter() string -} - // NewSession creates a new Session object func NewSession( ctx context.Context, @@ -168,8 +131,8 @@ func NewSession( taskHandler *eventhandler.TaskHandler, latestSeqNumTaskManifest *int64, doctor *doctor.Doctor, + clientFactory wsclient.ClientFactory, ) Session { - resources := newSessionResources(credentialsProvider) backoff := retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier) derivedContext, cancel := context.WithCancel(ctx) @@ -189,9 +152,10 @@ func NewSession( ctx: derivedContext, cancel: cancel, backoff: backoff, - resources: resources, latestSeqNumTaskManifest: latestSeqNumTaskManifest, doctor: doctor, + clientFactory: clientFactory, + sendCredentials: true, _heartbeatTimeout: heartbeatTimeout, _heartbeatJitter: heartbeatJitter, connectionTime: connectionTime, @@ -257,14 +221,23 @@ func (acsSession *session) Start() error { // startSessionOnce creates a session with ACS and handles requests using the passed // in arguments func (acsSession *session) startSessionOnce() error { + minAgentCfg := &wsclient.WSClientMinAgentConfig{ + AcceptInsecureCert: acsSession.agentConfig.AcceptInsecureCert, + AWSRegion: acsSession.agentConfig.AWSRegion, + } + acsEndpoint, err := acsSession.ecsClient.DiscoverPollEndpoint(acsSession.containerInstanceARN) if err != nil { seelog.Errorf("acs: unable to discover poll endpoint, err: %v", err) return err } - url := acsWsURL(acsEndpoint, acsSession.agentConfig.Cluster, acsSession.containerInstanceARN, acsSession.taskEngine, acsSession.resources) - client := acsSession.resources.createACSClient(url, acsSession.agentConfig) + url := acsSession.acsURL(acsEndpoint) + client := acsSession.clientFactory.New( + url, + acsSession.credentialsProvider, + wsRWTimeout, + minAgentCfg) defer client.Close() return acsSession.startACSSession(client) @@ -371,7 +344,9 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client)) defer heartbeatTimer.Stop() - acsSession.resources.connectedToACS() + // Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence + // and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS. + acsSession.sendCredentials = false backoffResetTimer := time.AfterFunc( retry.AddJitter(acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()), func() { @@ -383,30 +358,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error { }) defer backoffResetTimer.Stop() - serveErr := make(chan error, 1) - go func() { - serveErr <- client.Serve() - }() - - for { - select { - case <-acsSession.ctx.Done(): - // Stop receiving and sending messages from and to ACS when - // the context received from the main function is canceled - seelog.Infof("ACS session exited cleanly.") - return acsSession.ctx.Err() - case err := <-serveErr: - // Stop receiving and sending messages from and to ACS when - // client.Serve returns an error. This can happen when the - // connection is closed by ACS or the agent - if err == nil || err == io.EOF { - seelog.Info("ACS Websocket connection closed for a valid reason") - } else { - seelog.Errorf("Error: lost websocket connection with Agent Communication Service (ACS): %v", err) - } - return err - } - } + return client.Serve(acsSession.ctx) } func (acsSession *session) computeReconnectDelay(isInactiveInstance bool) time.Duration { @@ -438,48 +390,24 @@ func (acsSession *session) heartbeatJitter() time.Duration { return acsSession._heartbeatJitter } -// createACSClient creates the ACS Client using the specified URL -func (acsResources *acsSessionResources) createACSClient(url string, cfg *config.Config) wsclient.ClientServer { - return acsclient.New(url, cfg, acsResources.credentialsProvider, wsRWTimeout) -} - -// connectedToACS records a successful connection to ACS -// It sets sendCredentials to false on such an event -func (acsResources *acsSessionResources) connectedToACS() { - acsResources.sendCredentials = false -} - -// getSendCredentialsURLParameter gets the value to be set for the -// 'sendCredentials' URL parameter -func (acsResources *acsSessionResources) getSendCredentialsURLParameter() string { - return strconv.FormatBool(acsResources.sendCredentials) -} - -func newSessionResources(credentialsProvider *credentials.Credentials) sessionResources { - return &acsSessionResources{ - credentialsProvider: credentialsProvider, - sendCredentials: true, - } -} - -// acsWsURL returns the websocket url for ACS given the endpoint -func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.TaskEngine, acsSessionState sessionState) string { +// acsURL returns the websocket url for ACS given the endpoint +func (acsSession *session) acsURL(endpoint string) string { acsURL := endpoint if endpoint[len(endpoint)-1] != '/' { acsURL += "/" } acsURL += "ws" query := url.Values{} - query.Set("clusterArn", cluster) - query.Set("containerInstanceArn", containerInstanceArn) + query.Set("clusterArn", acsSession.agentConfig.Cluster) + query.Set("containerInstanceArn", acsSession.containerInstanceARN) query.Set("agentHash", version.GitHashString()) query.Set("agentVersion", version.Version) query.Set("seqNum", "1") query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion)) - if dockerVersion, err := taskEngine.Version(); err == nil { + if dockerVersion, err := acsSession.taskEngine.Version(); err == nil { query.Set("dockerVersion", "DockerVersion: "+dockerVersion) } - query.Set(sendCredentialsURLParameterName, acsSessionState.getSendCredentialsURLParameter()) + query.Set(sendCredentialsURLParameterName, strconv.FormatBool(acsSession.sendCredentials)) return acsURL + "?" + query.Encode() } diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index 03b93beacce..8f1d31f4294 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -42,13 +42,13 @@ import ( "github.com/aws/amazon-ecs-agent/agent/eventhandler" "github.com/aws/amazon-ecs-agent/agent/eventstream" "github.com/aws/amazon-ecs-agent/agent/version" - "github.com/aws/amazon-ecs-agent/agent/wsclient" - mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" + acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" mock_retry "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -145,35 +145,26 @@ var testConfig = &config.Config{ var testCreds = credentials.NewStaticCredentials("test-id", "test-secret", "test-token") -type mockSessionResources struct { - client wsclient.ClientServer -} - -func (m *mockSessionResources) createACSClient(url string, cfg *config.Config) wsclient.ClientServer { - return m.client -} - -func (m *mockSessionResources) connectedToACS() { -} - -func (m *mockSessionResources) getSendCredentialsURLParameter() string { - return "true" -} - -// TestACSWSURL tests if the URL is constructed correctly when connecting to ACS -func TestACSWSURL(t *testing.T) { +// TestACSURL tests if the URL is constructed correctly when connecting to ACS +func TestACSURL(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := mock_engine.NewMockTaskEngine(ctrl) taskEngine.EXPECT().Version().Return("Docker version result", nil) - wsurl := acsWsURL(acsURL, "myCluster", "myContainerInstance", taskEngine, &mockSessionResources{}) + acsSession := session{ + taskEngine: taskEngine, + sendCredentials: true, + agentConfig: testConfig, + containerInstanceARN: "myContainerInstance", + } + wsurl := acsSession.acsURL(acsURL) parsed, err := url.Parse(wsurl) assert.NoError(t, err, "should be able to parse URL") assert.Equal(t, "/ws", parsed.Path, "wrong path") - assert.Equal(t, "myCluster", parsed.Query().Get("clusterArn"), "wrong cluster") + assert.Equal(t, "someCluster", parsed.Query().Get("clusterArn"), "wrong cluster") assert.Equal(t, "myContainerInstance", parsed.Query().Get("containerInstanceArn"), "wrong container instance") assert.Equal(t, version.Version, parsed.Query().Get("agentVersion"), "wrong agent version") assert.Equal(t, version.GitHashString(), parsed.Query().Get("agentHash"), "wrong agent hash") @@ -199,9 +190,13 @@ func TestHandlerReconnectsOnConnectErrors(t *testing.T) { taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Serve().AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() gomock.InOrder( @@ -225,7 +220,7 @@ func TestHandlerReconnectsOnConnectErrors(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -338,6 +333,10 @@ func TestHandlerReconnectsWithoutBackoffOnEOFError(t *testing.T) { mockBackoff := mock_retry.NewMockBackoff(ctrl) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() @@ -366,7 +365,7 @@ func TestHandlerReconnectsWithoutBackoffOnEOFError(t *testing.T) { backoff: mockBackoff, ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, latestSeqNumTaskManifest: aws.Int64(10), _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, @@ -404,6 +403,10 @@ func TestHandlerReconnectsWithBackoffOnNonEOFError(t *testing.T) { mockBackoff := mock_retry.NewMockBackoff(ctrl) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() @@ -432,7 +435,7 @@ func TestHandlerReconnectsWithBackoffOnNonEOFError(t *testing.T) { backoff: mockBackoff, ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -475,6 +478,10 @@ func TestHandlerGeneratesDeregisteredInstanceEvent(t *testing.T) { assert.NoError(t, err, "Error adding deregister instance event stream subscriber") deregisterInstanceEventStream.StartListening() mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() @@ -493,7 +500,7 @@ func TestHandlerGeneratesDeregisteredInstanceEvent(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -530,6 +537,10 @@ func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) { // deregisterInstanceEventStream.StartListening() mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() @@ -565,7 +576,7 @@ func TestHandlerReconnectDelayForInactiveInstanceError(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -597,6 +608,10 @@ func TestHandlerReconnectsOnServeErrors(t *testing.T) { taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() @@ -604,11 +619,11 @@ func TestHandlerReconnectsOnServeErrors(t *testing.T) { mockWsClient.EXPECT().Close().Return(nil).AnyTimes() gomock.InOrder( // Serve fails 10 times - mockWsClient.EXPECT().Serve().Return(io.EOF).Times(10), + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).Times(10), // Cancel trying to Serve ACS requests on the 11th attempt // Failure to retry on Serve() errors should cause the // test to time out as the context is never cancelled - mockWsClient.EXPECT().Serve().Do(func() { + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { cancel() }), ) @@ -624,7 +639,7 @@ func TestHandlerReconnectsOnServeErrors(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -655,14 +670,18 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) { taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() gomock.InOrder( - mockWsClient.EXPECT().Serve().Return(io.EOF), - mockWsClient.EXPECT().Serve().Do(func() { + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF), + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { cancel() }).Return(errors.New("InactiveInstanceException")), ) @@ -677,7 +696,7 @@ func TestHandlerStopsWhenContextIsCancelled(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -709,12 +728,16 @@ func TestHandlerStopsWhenContextIsError(t *testing.T) { taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Serve().Do(func() { + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { time.Sleep(5 * time.Millisecond) }).Return(io.EOF).AnyTimes() @@ -729,7 +752,7 @@ func TestHandlerStopsWhenContextIsError(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, } @@ -759,12 +782,16 @@ func TestHandlerStopsWhenContextIsErrorReconnectDelay(t *testing.T) { taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Serve().Return(errors.New("InactiveInstanceException")).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).Return(errors.New("InactiveInstanceException")).AnyTimes() acsSession := session{ containerInstanceARN: "myArn", @@ -777,7 +804,7 @@ func TestHandlerStopsWhenContextIsErrorReconnectDelay(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, _inactiveInstanceReconnectDelay: 1 * time.Hour, @@ -806,9 +833,13 @@ func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) { taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() - mockWsClient.EXPECT().Serve().AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() mockWsClient.EXPECT().Connect().Do(func() { @@ -833,7 +864,7 @@ func TestHandlerReconnectsOnDiscoverPollEndpointError(t *testing.T) { backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), ctx: ctx, cancel: cancel, - resources: &mockSessionResources{mockWsClient}, + clientFactory: mockClientFactory, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -882,7 +913,7 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil) - mockWsClient.EXPECT().Serve().Do(func() { + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { wait.Done() // Pretend as if the maximum heartbeatTimeout duration has // been breached while Serving requests @@ -905,7 +936,6 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { taskHandler: taskHandler, ctx: context.Background(), backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - resources: &mockSessionResources{}, _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 30 * time.Millisecond, @@ -935,7 +965,7 @@ func TestConnectionIsClosedAfterTimeIsUp(t *testing.T) { mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).Do(func(v interface{}) {}).AnyTimes() mockWsClient.EXPECT().Connect().Return(nil) - mockWsClient.EXPECT().Serve().Do(func() { + mockWsClient.EXPECT().Serve(gomock.Any()).Do(func(interface{}) { // pretend as if the connectionTime has elapsed time.Sleep(30 * time.Millisecond) cancel() @@ -955,7 +985,6 @@ func TestConnectionIsClosedAfterTimeIsUp(t *testing.T) { taskHandler: taskHandler, ctx: context.Background(), backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - resources: &mockSessionResources{}, _heartbeatTimeout: 50 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, connectionTime: 20 * time.Millisecond, @@ -1028,9 +1057,9 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) { dataClient: data.NewNoopClient(), taskHandler: taskHandler, ctx: ctx, + clientFactory: acsclient.NewACSClientFactory(), _heartbeatTimeout: 1 * time.Second, backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - resources: newSessionResources(testCreds), credentialsManager: rolecredentials.NewManager(), latestSeqNumTaskManifest: aws.Int64(12), doctor: emptyDoctor, @@ -1120,6 +1149,7 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { taskHandler, &latestSeqNumberTaskManifest, emptyDoctor, + acsclient.NewACSClientFactory(), ) acsSession.Start() // StartSession should never return unless the context is canceled @@ -1176,77 +1206,68 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { <-ended } -// TestACSSessionResourcesCorrectlySetsSendCredentials tests if acsSessionResources -// struct correctly sets 'sendCredentials' -func TestACSSessionResourcesCorrectlySetsSendCredentials(t *testing.T) { - acsResources := newSessionResources(nil) - // Validate that 'sendCredentials' is set to true on create - sendCredentials := acsResources.getSendCredentialsURLParameter() - if sendCredentials != "true" { - t.Errorf("Mismatch in sendCredentials URL parameter value, expected: 'true', got: %s", sendCredentials) - } - // Simulate a successful connection to ACS - acsResources.connectedToACS() - // On successful connection to ACS, 'sendCredentials' should be set to false - sendCredentials = acsResources.getSendCredentialsURLParameter() - if sendCredentials != "false" { - t.Errorf("Mismatch in sendCredentials URL parameter value, expected: 'false', got: %s", sendCredentials) - } -} - -// TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter tests if -// the 'sendCredentials' URL parameter is set correctly for successive -// invocations of startACSSession -func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) { +// TestHandlerCorrectlySetsSendCredentials tests if 'sendCredentials' +// is set correctly for successive invocations of startACSSession +func TestHandlerCorrectlySetsSendCredentials(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() taskEngine := mock_engine.NewMockTaskEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) + deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) + deregisterInstanceEventStream.StartListening() + dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) + emptyHealthchecksList := []doctor.Healthcheck{} + emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "test-cluster", "this:is:an:instance:arn") mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockClientFactory.EXPECT(). + New(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient).AnyTimes() mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() mockWsClient.EXPECT().Close().Return(nil).AnyTimes() - mockWsClient.EXPECT().Serve().Return(io.EOF).AnyTimes() - - dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - resources := newSessionResources(testCreds) + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes() + + acsSession := NewSession( + ctx, + testConfig, + deregisterInstanceEventStream, + "myArn", + testCreds, + dockerClient, + ecsClient, + dockerstate.NewTaskEngineState(), + data.NewNoopClient(), + taskEngine, + rolecredentials.NewManager(), + taskHandler, + aws.Int64(10), + emptyDoctor, + mockClientFactory) + acsSession.(*session)._heartbeatTimeout = 20 * time.Millisecond + acsSession.(*session)._heartbeatJitter = 10 * time.Millisecond + acsSession.(*session).connectionTime = 30 * time.Millisecond + acsSession.(*session).connectionJitter = 10 * time.Millisecond gomock.InOrder( // When the websocket client connects to ACS for the first // time, 'sendCredentials' should be set to true mockWsClient.EXPECT().Connect().Do(func() { - validateSendCredentialsInSession(t, resources, "true") + assert.Equal(t, true, acsSession.(*session).sendCredentials) }).Return(nil), // For all subsequent connections to ACS, 'sendCredentials' // should be set to false mockWsClient.EXPECT().Connect().Do(func() { - validateSendCredentialsInSession(t, resources, "false") + assert.Equal(t, false, acsSession.(*session).sendCredentials) }).Return(nil).AnyTimes(), ) - acsSession := session{ - containerInstanceARN: "myArn", - credentialsProvider: testCreds, - agentConfig: testConfig, - taskEngine: taskEngine, - dockerClient: dockerClient, - ecsClient: ecsClient, - dataClient: data.NewNoopClient(), - taskHandler: taskHandler, - ctx: ctx, - resources: resources, - backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), - _heartbeatTimeout: 20 * time.Millisecond, - _heartbeatJitter: 10 * time.Millisecond, - connectionTime: 30 * time.Millisecond, - connectionJitter: 10 * time.Millisecond, - } go func() { for i := 0; i < 10; i++ { - acsSession.startACSSession(mockWsClient) + acsSession.(*session).startACSSession(mockWsClient) } cancel() }() @@ -1257,6 +1278,90 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) } } +// TestHandlerReconnectCorrectlySetsAcsUrl tests if the ACS URL +// is set correctly for the initial connection and subsequent connections +func TestHandlerReconnectCorrectlySetsAcsUrl(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + dockerVerStr := "1.5.0" + taskEngine := mock_engine.NewMockTaskEngine(ctrl) + taskEngine.EXPECT().Version().Return(fmt.Sprintf("Docker: %s", dockerVerStr), nil).AnyTimes() + ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ctx, cancel := context.WithCancel(context.Background()) + taskHandler := eventhandler.NewTaskHandler(ctx, data.NewNoopClient(), nil, nil) + deregisterInstanceEventStream := eventstream.NewEventStream("DeregisterContainerInstance", ctx) + deregisterInstanceEventStream.StartListening() + dockerClient := mock_dockerapi.NewMockDockerClient(ctrl) + emptyHealthchecksList := []doctor.Healthcheck{} + emptyDoctor, _ := doctor.NewDoctor(emptyHealthchecksList, "test-cluster", "this:is:an:instance:arn") + + mockBackoff := mock_retry.NewMockBackoff(ctrl) + mockWsClient := mock_wsclient.NewMockClientServer(ctrl) + mockClientFactory := mock_wsclient.NewMockClientFactory(ctrl) + mockWsClient.EXPECT().SetAnyRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().AddRequestHandler(gomock.Any()).AnyTimes() + mockWsClient.EXPECT().WriteCloseMessage().AnyTimes() + mockWsClient.EXPECT().Close().Return(nil).AnyTimes() + mockWsClient.EXPECT().Serve(gomock.Any()).Return(io.EOF).AnyTimes() + + // On the initial connection, sendCredentials must be true because Agent forces ACS to send credentials. + initialAcsURL := fmt.Sprintf( + "http://endpoint.tld/ws?agentHash=%s&agentVersion=%s&clusterArn=%s&containerInstanceArn=%s&"+ + "dockerVersion=DockerVersion%%3A+Docker%%3A+%s&protocolVersion=%v&sendCredentials=true&seqNum=1", + version.GitShortHash, version.Version, testConfig.Cluster, "myArn", dockerVerStr, acsProtocolVersion) + + // But after that, ACS sends credentials at ACS's own cadence, so sendCredentials must be false. + subsequentAcsURL := fmt.Sprintf( + "http://endpoint.tld/ws?agentHash=%s&agentVersion=%s&clusterArn=%s&containerInstanceArn=%s&"+ + "dockerVersion=DockerVersion%%3A+Docker%%3A+%s&protocolVersion=%v&sendCredentials=false&seqNum=1", + version.GitShortHash, version.Version, testConfig.Cluster, "myArn", dockerVerStr, acsProtocolVersion) + + gomock.InOrder( + mockClientFactory.EXPECT(). + New(initialAcsURL, gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient), + mockWsClient.EXPECT().Connect().Return(nil), + mockBackoff.EXPECT().Reset(), + mockClientFactory.EXPECT(). + New(subsequentAcsURL, gomock.Any(), gomock.Any(), gomock.Any()). + Return(mockWsClient), + mockWsClient.EXPECT().Connect().Do(func() { + cancel() + }).Return(nil), + ) + acsSession := NewSession( + ctx, + testConfig, + deregisterInstanceEventStream, + "myArn", + testCreds, + dockerClient, + ecsClient, + dockerstate.NewTaskEngineState(), + data.NewNoopClient(), + taskEngine, + rolecredentials.NewManager(), + taskHandler, + aws.Int64(10), + emptyDoctor, + mockClientFactory) + acsSession.(*session).backoff = mockBackoff + acsSession.(*session)._heartbeatTimeout = 20 * time.Millisecond + acsSession.(*session)._heartbeatJitter = 10 * time.Millisecond + acsSession.(*session).connectionTime = 30 * time.Millisecond + acsSession.(*session).connectionJitter = 10 * time.Millisecond + + go func() { + acsSession.Start() + }() + + // Wait for context to be cancelled + select { + case <-ctx.Done(): + } +} + // TODO: replace with gomock func startMockAcsServer(t *testing.T, closeWS <-chan bool) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { serverChan := make(chan string, 1) @@ -1341,10 +1446,3 @@ func validateAddedContainer(expectedContainer *apicontainer.Container, addedCont } return nil } - -func validateSendCredentialsInSession(t *testing.T, state sessionState, expected string) { - sendCredentials := state.getSendCredentialsURLParameter() - if sendCredentials != expected { - t.Errorf("Incorrect value set for sendCredentials, expected: %s, got: %s", expected, sendCredentials) - } -} diff --git a/agent/acs/handler/attach_eni_handler_common.go b/agent/acs/handler/attach_eni_handler_common.go index cf025fa9cb4..656555d17e3 100644 --- a/agent/acs/handler/attach_eni_handler_common.go +++ b/agent/acs/handler/attach_eni_handler_common.go @@ -22,10 +22,10 @@ import ( "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" "github.com/aws/amazon-ecs-agent/agent/utils" - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/arn" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" diff --git a/agent/acs/handler/attach_instance_eni_handler.go b/agent/acs/handler/attach_instance_eni_handler.go index bdda2872926..3c1019f0a31 100644 --- a/agent/acs/handler/attach_instance_eni_handler.go +++ b/agent/acs/handler/attach_instance_eni_handler.go @@ -16,12 +16,12 @@ package handler import ( "time" - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" "github.com/aws/amazon-ecs-agent/ecs-agent/api/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" diff --git a/agent/acs/handler/attach_instance_eni_handler_test.go b/agent/acs/handler/attach_instance_eni_handler_test.go index 387fd2c221c..0ade8f44c0d 100644 --- a/agent/acs/handler/attach_instance_eni_handler_test.go +++ b/agent/acs/handler/attach_instance_eni_handler_test.go @@ -25,10 +25,10 @@ import ( "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" - mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" diff --git a/agent/acs/handler/attach_task_eni_handler.go b/agent/acs/handler/attach_task_eni_handler.go index ced295d4ca0..5c85a32c6fd 100644 --- a/agent/acs/handler/attach_task_eni_handler.go +++ b/agent/acs/handler/attach_task_eni_handler.go @@ -16,12 +16,12 @@ package handler import ( "time" - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" "github.com/aws/amazon-ecs-agent/ecs-agent/api/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" diff --git a/agent/acs/handler/attach_task_eni_handler_test.go b/agent/acs/handler/attach_task_eni_handler_test.go index b6a94774f57..075dbc2f801 100644 --- a/agent/acs/handler/attach_task_eni_handler_test.go +++ b/agent/acs/handler/attach_task_eni_handler_test.go @@ -25,10 +25,10 @@ import ( "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" - mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" diff --git a/agent/acs/handler/heartbeat_handler.go b/agent/acs/handler/heartbeat_handler.go index ad82cdc1790..c6d76789180 100644 --- a/agent/acs/handler/heartbeat_handler.go +++ b/agent/acs/handler/heartbeat_handler.go @@ -14,9 +14,9 @@ package handler import ( - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" ) diff --git a/agent/acs/handler/heartbeat_handler_test.go b/agent/acs/handler/heartbeat_handler_test.go index 3dc2bbe518c..8dfa86980d3 100644 --- a/agent/acs/handler/heartbeat_handler_test.go +++ b/agent/acs/handler/heartbeat_handler_test.go @@ -19,9 +19,9 @@ package handler import ( "testing" - mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" diff --git a/agent/acs/handler/payload_handler.go b/agent/acs/handler/payload_handler.go index 513677f081a..d606a5c71d1 100644 --- a/agent/acs/handler/payload_handler.go +++ b/agent/acs/handler/payload_handler.go @@ -26,10 +26,10 @@ import ( "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine" "github.com/aws/amazon-ecs-agent/agent/eventhandler" - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" diff --git a/agent/acs/handler/payload_handler_test.go b/agent/acs/handler/payload_handler_test.go index c44d943be5b..4f3b0ae2a96 100644 --- a/agent/acs/handler/payload_handler_test.go +++ b/agent/acs/handler/payload_handler_test.go @@ -33,10 +33,10 @@ import ( mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" "github.com/aws/amazon-ecs-agent/agent/eventhandler" "github.com/aws/amazon-ecs-agent/agent/taskresource" - mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" diff --git a/agent/acs/handler/refresh_credentials_handler.go b/agent/acs/handler/refresh_credentials_handler.go index ab9a4add851..db49c8acc65 100644 --- a/agent/acs/handler/refresh_credentials_handler.go +++ b/agent/acs/handler/refresh_credentials_handler.go @@ -18,9 +18,9 @@ import ( "context" "github.com/aws/amazon-ecs-agent/agent/engine" - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" ) diff --git a/agent/acs/handler/refresh_credentials_handler_test.go b/agent/acs/handler/refresh_credentials_handler_test.go index 37163f743cb..fbb18c507ff 100644 --- a/agent/acs/handler/refresh_credentials_handler_test.go +++ b/agent/acs/handler/refresh_credentials_handler_test.go @@ -24,9 +24,9 @@ import ( apitask "github.com/aws/amazon-ecs-agent/agent/api/task" mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" - mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" diff --git a/agent/acs/handler/task_manifest_handler.go b/agent/acs/handler/task_manifest_handler.go index 2e6209f4705..58787495f66 100644 --- a/agent/acs/handler/task_manifest_handler.go +++ b/agent/acs/handler/task_manifest_handler.go @@ -20,9 +20,9 @@ import ( apitaskstatus "github.com/aws/amazon-ecs-agent/agent/api/task/status" "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine" - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" ) diff --git a/agent/acs/handler/task_manifest_handler_test.go b/agent/acs/handler/task_manifest_handler_test.go index 697bba3600b..225822e85a4 100644 --- a/agent/acs/handler/task_manifest_handler_test.go +++ b/agent/acs/handler/task_manifest_handler_test.go @@ -26,8 +26,8 @@ import ( apitaskstatus "github.com/aws/amazon-ecs-agent/agent/api/task/status" "github.com/aws/amazon-ecs-agent/agent/data" mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" - mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + mock_wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" diff --git a/agent/acs/update_handler/updater.go b/agent/acs/update_handler/updater.go index 110e5236b6d..05830bedac5 100644 --- a/agent/acs/update_handler/updater.go +++ b/agent/acs/update_handler/updater.go @@ -34,9 +34,9 @@ import ( "github.com/aws/amazon-ecs-agent/agent/sighandlers" "github.com/aws/amazon-ecs-agent/agent/sighandlers/exitcodes" "github.com/aws/amazon-ecs-agent/agent/utils" - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" diff --git a/agent/acs/update_handler/updater_test.go b/agent/acs/update_handler/updater_test.go index 68ef2c2e128..b93a2d74cff 100644 --- a/agent/acs/update_handler/updater_test.go +++ b/agent/acs/update_handler/updater_test.go @@ -34,8 +34,8 @@ import ( "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" "github.com/aws/amazon-ecs-agent/agent/httpclient" mock_http "github.com/aws/amazon-ecs-agent/agent/httpclient/mock" - mock_client "github.com/aws/amazon-ecs-agent/agent/wsclient/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + mock_client "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" diff --git a/agent/app/agent.go b/agent/app/agent.go index e785b5007f2..b435d60362e 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -62,6 +62,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/utils/loader" "github.com/aws/amazon-ecs-agent/agent/utils/mobypkgwrapper" "github.com/aws/amazon-ecs-agent/agent/version" + acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" @@ -943,6 +944,7 @@ func (agent *ecsAgent) startACSSession( taskHandler, agent.latestSeqNumberTaskManifest, doctor, + acsclient.NewACSClientFactory(), ) seelog.Info("Beginning Polling for updates") err := acsSession.Start() diff --git a/agent/acs/client/acs_client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_client.go similarity index 73% rename from agent/acs/client/acs_client.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_client.go index 574eb5a557d..35edea2c3a7 100644 --- a/agent/acs/client/acs_client.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_client.go @@ -19,13 +19,13 @@ package acsclient import ( + "context" "errors" "time" - "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/agent/wsclient" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/cihub/seelog" ) // clientServer implements ClientServer for acs. @@ -33,14 +33,22 @@ type clientServer struct { wsclient.ClientServerImpl } +type acsClientFactory struct{} + +// NewACSClientFactory creates a new ACS client factory object. This can be +// used to create new ACS clients. +func NewACSClientFactory() wsclient.ClientFactory { + return &acsClientFactory{} +} + // New returns a client/server to bidirectionally communicate with ACS // The returned struct should have both 'Connect' and 'Serve' called upon it // before being used. -func New(url string, cfg *config.Config, credentialProvider *credentials.Credentials, rwTimeout time.Duration) wsclient.ClientServer { +func (*acsClientFactory) New(url string, credentialProvider *credentials.Credentials, rwTimeout time.Duration, cfg *wsclient.WSClientMinAgentConfig) wsclient.ClientServer { cs := &clientServer{} cs.URL = url cs.CredentialProvider = credentialProvider - cs.AgentConfig = cfg + cs.Cfg = cfg cs.ServiceError = &acsError{} cs.RequestHandlers = make(map[string]wsclient.RequestHandler) cs.TypeDecoder = NewACSDecoder() @@ -51,12 +59,12 @@ func New(url string, cfg *config.Config, credentialProvider *credentials.Credent // Serve begins serving requests using previously registered handlers (see // AddRequestHandler). All request handlers should be added prior to making this // call as unhandled requests will be discarded. -func (cs *clientServer) Serve() error { - seelog.Debug("ACS client starting websocket poll loop") +func (cs *clientServer) Serve(ctx context.Context) error { + logger.Debug("ACS client starting websocket poll loop") if !cs.IsReady() { return errors.New("acs client: websocket not ready for connections") } - return cs.ConsumeMessages() + return cs.ConsumeMessages(ctx) } // Close closes the underlying connection diff --git a/agent/acs/client/acs_client_types.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_client_types.go similarity index 95% rename from agent/acs/client/acs_client_types.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_client_types.go index 2bfb5a6dc68..9f86ae16047 100644 --- a/agent/acs/client/acs_client_types.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_client_types.go @@ -14,8 +14,8 @@ package acsclient import ( - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" ) var acsRecognizedTypes []interface{} @@ -48,6 +48,7 @@ func init() { ecsacs.ErrorMessage{}, ecsacs.AttachTaskNetworkInterfacesMessage{}, ecsacs.AttachInstanceNetworkInterfacesMessage{}, + ecsacs.ConfirmAttachmentMessage{}, ecsacs.TaskManifestMessage{}, ecsacs.TaskStopVerificationAck{}, ecsacs.TaskStopVerificationMessage{}, diff --git a/agent/acs/client/acs_error.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_error.go similarity index 97% rename from agent/acs/client/acs_error.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_error.go index 2a1790bb076..99d7072c6d9 100644 --- a/agent/acs/client/acs_error.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/client/acs_error.go @@ -14,8 +14,8 @@ package acsclient import ( - "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" ) const errType = "ACSError" diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go new file mode 100644 index 00000000000..3e029a8a404 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go @@ -0,0 +1,574 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +// Package wsclient wraps the generated aws-sdk-go client to provide marshalling +// and unmarshalling of data over a websocket connection in the format expected +// by backend. It allows for bidirectional communication and acts as both a +// client-and-server in terms of requests, but only as a client in terms of +// connecting. +package wsclient + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "reflect" + "strings" + "sync" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/cipher" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/httpproxy" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" + "github.com/gorilla/websocket" + "github.com/pkg/errors" +) + +const ( + // ServiceName defines the service name for the agent. This is used to sign messages + // that are sent to the backend. + ServiceName = "ecs" + + // wsConnectTimeout specifies the default connection timeout to the backend. + wsConnectTimeout = 30 * time.Second + + // wsHandshakeTimeout specifies the default handshake timeout for the websocket client + wsHandshakeTimeout = wsConnectTimeout + + // readBufSize is the size of the read buffer for the ws connection. + readBufSize = 4096 + + // writeBufSize is the size of the write buffer for the ws connection. + writeBufSize = 32768 + + // Default NO_PROXY env var IP addresses + defaultNoProxyIP = "169.254.169.254,169.254.170.2" + + errClosed = "use of closed network connection" + + // ExitTerminal indicates the agent run into error that's not recoverable + // no need to restart + ExitTerminal = 5 +) + +// ReceivedMessage is the intermediate message used to unmarshal a +// message from backend +type ReceivedMessage struct { + Type string `json:"type"` + Message json.RawMessage `json:"message"` +} + +// RequestMessage is the intermediate message marshalled to send to backend. +type RequestMessage struct { + Type string `json:"type"` + Message json.RawMessage `json:"message"` +} + +// RequestHandler would be func(*ecsacs.T for T in ecsacs.*) to be more proper, but it needs +// to be interface{} to properly capture that +type RequestHandler interface{} + +// RequestResponder wraps the RequestHandler interface with a Respond() +// method that can be used to Respond to requests read and processed via +// the RequestHandler interface for a particular message type. +// +// Example: +// +// type payloadMessageDispatcher struct { +// respond func(interface{}) error +// dispatcher actor.Dispatcher +// } +// func(d *payloadmessagedispatcher) RegisterResponder(respond func(interface{}) error) error { +// d.respond = respond +// return nil +// } +// func(d *payloadmessagedispatcher) HandlerFunc() RequestHandler { +// return func(payload *ecsacs.PayloadMessage) { +// message := &actor.DispatcherMessage{ +// Payload: payload, +// AckFunc: func() error { +// return d.respond() +// }, +// ... +// } +// d.dispatcher.Send(message) +// } +// } +type RequestResponder interface { + // Name returns the name of the responder. This is used mostly for logging. + Name() string + // RegisterResponder registers a function that can be invoked in response + // to receiving and processing a websocket request message. + RegisterResponder(RespondFunc) + // HandlerFunc returns the RequestHandler callback for a particular + // websocket request message type. + HandlerFunc() RequestHandler +} + +// RespondFunc specifies a function callback that can be used by the +// RequestResponder to respond to requests. +type RespondFunc func(interface{}) error + +// ClientServer is a combined client and server for the backend websocket connection +type ClientServer interface { + AddRequestHandler(RequestHandler) + // SetAnyRequestHandler takes a function with the signature 'func(i + // interface{})' and calls it with every message the server passes down. + // Only a single 'AnyRequestHandler' will be active at a given time for a + // ClientServer + SetAnyRequestHandler(RequestHandler) + MakeRequest(input interface{}) error + WriteMessage(input []byte) error + WriteCloseMessage() error + Connect() error + IsConnected() bool + SetConnection(conn wsconn.WebsocketConn) + Disconnect(...interface{}) error + Serve(ctx context.Context) error + SetReadDeadline(t time.Time) error + io.Closer +} + +// WSClientMinAgentConfig is a subset of agent's config. +type WSClientMinAgentConfig struct { + AWSRegion string + AcceptInsecureCert bool + DockerEndpoint string + IsDocker bool +} + +// ClientServerImpl wraps commonly used methods defined in ClientServer interface. +type ClientServerImpl struct { + // Cfg is the subset of user-specified runtime configuration + Cfg *WSClientMinAgentConfig + // conn holds the underlying low-level websocket connection + conn wsconn.WebsocketConn + // CredentialProvider is used to retrieve AWS credentials + CredentialProvider *credentials.Credentials + // RequestHandlers is a map from message types to handler functions of the + // form: + // "FooMessage": func(message *ecsacs.FooMessage) + RequestHandlers map[string]RequestHandler + // AnyRequestHandler is a request handler that, if set, is called on every + // message with said message. It will be called before a RequestHandler is + // called. It must take a single interface{} argument. + AnyRequestHandler RequestHandler + // MakeRequestHook is an optional callback that, if set, is called on every + // generated request with the raw request body. + MakeRequestHook MakeRequestHookFunc + // URL is the full url to the backend, including path, querystring, and so on. + URL string + // RWTimeout is the duration used for setting read and write deadlines + // for the websocket connection + RWTimeout time.Duration + // writeLock needed to ensure that only one routine is writing to the socket + writeLock sync.RWMutex + ClientServer + ServiceError + TypeDecoder +} + +// MakeRequestHookFunc is a function that is invoked on every generated request +// with the raw request body. MakeRequestHookFunc must return either the body +// to send or an error. +type MakeRequestHookFunc func([]byte) ([]byte, error) + +// Connect opens a connection to the backend and upgrades it to a websocket. Calls to +// 'MakeRequest' can be made after calling this, but responses will not be +// receivable until 'Serve' is also called. +func (cs *ClientServerImpl) Connect() error { + logger.Info("Establishing a Websocket connection", logger.Fields{ + "url": cs.URL, + }) + parsedURL, err := url.Parse(cs.URL) + if err != nil { + return err + } + + wsScheme, err := websocketScheme(parsedURL.Scheme) + if err != nil { + return err + } + parsedURL.Scheme = wsScheme + + // NewRequest never returns an error if the url parses and we just verified + // it did above + request, _ := http.NewRequest("GET", parsedURL.String(), nil) + + // Sign the request; we'll send its headers via the websocket client which includes the signature + err = utils.SignHTTPRequest(request, cs.Cfg.AWSRegion, ServiceName, cs.CredentialProvider, nil) + if err != nil { + return err + } + + timeoutDialer := &net.Dialer{Timeout: wsConnectTimeout} + tlsConfig := &tls.Config{ServerName: parsedURL.Host, InsecureSkipVerify: cs.Cfg.AcceptInsecureCert, MinVersion: tls.VersionTLS12} + + //TODO: In order to get rid of the check - + // 1. Remove the hardcoded cipher suites, and rely on default by tls package + // 2. NO_PROXY should be set as part of config check or in init somewhere. Wsclient is not the right place. + if cs.Cfg.IsDocker { + + cipher.WithSupportedCipherSuites(tlsConfig) + + // Ensure that NO_PROXY gets set + noProxy := os.Getenv("NO_PROXY") + if noProxy == "" { + dockerHost, err := url.Parse(cs.Cfg.DockerEndpoint) + if err == nil { + dockerHost.Scheme = "" + os.Setenv("NO_PROXY", fmt.Sprintf("%s,%s", defaultNoProxyIP, dockerHost.String())) + logger.Info(fmt.Sprintf("NO_PROXY is set: %s", os.Getenv("NO_PROXY"))) + } else { + logger.Error("NO_PROXY unable to be set: the configured Docker endpoint is invalid.") + } + } + } + + dialer := websocket.Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + TLSClientConfig: tlsConfig, + Proxy: httpproxy.Proxy, + NetDial: timeoutDialer.Dial, + HandshakeTimeout: wsHandshakeTimeout, + } + + websocketConn, httpResponse, err := dialer.Dial(parsedURL.String(), request.Header) + if httpResponse != nil { + defer httpResponse.Body.Close() + } + + if err != nil { + var resp []byte + if httpResponse != nil { + var readErr error + resp, readErr = io.ReadAll(httpResponse.Body) + if readErr != nil { + return fmt.Errorf("Unable to read websocket connection: " + readErr.Error() + ", " + err.Error()) + } + // If there's a response, we can try to unmarshal it into one of the + // modeled error types + possibleError, _, decodeErr := DecodeData(resp, cs.TypeDecoder) + if decodeErr == nil { + return cs.NewError(possibleError) + } + } + logger.Warn(fmt.Sprintf("Error creating a websocket client: %v", err)) + return errors.Wrapf(err, "websocket client: unable to dial %s response: %s", + parsedURL.Host, string(resp)) + } + + cs.writeLock.Lock() + defer cs.writeLock.Unlock() + + cs.conn = websocketConn + logger.Debug(fmt.Sprintf("Established a Websocket connection to %s", cs.URL)) + return nil +} + +// IsReady gives a boolean response that informs the caller if the websocket +// connection is fully established. +func (cs *ClientServerImpl) IsReady() bool { + cs.writeLock.RLock() + defer cs.writeLock.RUnlock() + + return cs.conn != nil +} + +// SetConnection passes a websocket connection object into the client. This is used only in +// testing and should be avoided in non-test code. +func (cs *ClientServerImpl) SetConnection(conn wsconn.WebsocketConn) { + cs.conn = conn +} + +// SetReadDeadline sets the read deadline for the websocket connection +// A read timeout results in an io error if there are any outstanding reads +// that exceed the deadline +func (cs *ClientServerImpl) SetReadDeadline(t time.Time) error { + err := cs.conn.SetReadDeadline(t) + if err == nil { + return nil + } + logger.Warn(fmt.Sprintf("Unable to set read deadline for websocket connection: %v for %s", err, cs.URL)) + // If we get connection closed error from SetReadDeadline, break out of the for loop and + // return an error + if opErr, ok := err.(*net.OpError); ok && strings.Contains(opErr.Err.Error(), errClosed) { + logger.Error(fmt.Sprintf("Stopping redundant reads on closed network connection: %s", cs.URL)) + return opErr + } + // An unhandled error has occurred while trying to extend read deadline. + // Try asynchronously closing the connection. We don't want to be blocked on stale connections + // taking too long to close. The flip side is that we might start accumulating stale connections. + // But, that still seems more desirable than waiting for ever for the connection to close + cs.forceCloseConnection() + return err +} + +func (cs *ClientServerImpl) forceCloseConnection() { + closeChan := make(chan error, 1) + go func() { + closeChan <- cs.Close() + }() + ctx, cancel := context.WithTimeout(context.TODO(), wsConnectTimeout) + defer cancel() + select { + case closeErr := <-closeChan: + if closeErr != nil { + logger.Warn(fmt.Sprintf("Unable to close websocket connection: %v for %s", + closeErr, cs.URL)) + } + case <-ctx.Done(): + if ctx.Err() != nil { + logger.Warn(fmt.Sprintf("Context canceled waiting for termination of websocket connection: %v for %s", + ctx.Err(), cs.URL)) + } + } +} + +// Disconnect disconnects the connection +func (cs *ClientServerImpl) Disconnect(...interface{}) error { + cs.writeLock.Lock() + defer cs.writeLock.Unlock() + + if cs.conn == nil { + return fmt.Errorf("websocker client: no connection to close") + } + + // Close() in turn results in a an internal flushFrame() call in gorilla + // as the close frame needs to be sent to the server. Set the deadline + // for that as well. + if err := cs.conn.SetWriteDeadline(time.Now().Add(cs.RWTimeout)); err != nil { + logger.Warn(fmt.Sprintf("Unable to set write deadline for websocket connection: %v for %s", err, cs.URL)) + } + return cs.conn.Close() +} + +// AddRequestHandler adds a request handler to this client. +// A request handler *must* be a function taking a single argument, and that +// argument *must* be a pointer to a recognized 'ecsacs' struct. +// E.g. if you desired to handle messages from acs of type 'FooMessage', you +// would pass the following handler in: +// +// func(message *ecsacs.FooMessage) +// +// This function will cause agent exit if the passed in function does not have one pointer +// argument or the argument is not a recognized type. +// Additionally, the request handler will block processing of further messages +// on this connection so it's important that it return quickly. +func (cs *ClientServerImpl) AddRequestHandler(f RequestHandler) { + firstArg := reflect.TypeOf(f).In(0) + firstArgTypeStr := firstArg.Elem().Name() + recognizedTypes := cs.GetRecognizedTypes() + _, ok := recognizedTypes[firstArgTypeStr] + if !ok { + logger.Error(fmt.Sprintf("Invalid Handler. AddRequestHandler called with invalid function; "+ + "argument type not recognized: %v", firstArgTypeStr)) + os.Exit(ExitTerminal) + } + cs.RequestHandlers[firstArgTypeStr] = f +} + +// SetAnyRequestHandler passes a RequestHandler object into the client. +func (cs *ClientServerImpl) SetAnyRequestHandler(f RequestHandler) { + cs.AnyRequestHandler = f +} + +// MakeRequest makes a request using the given input. Note, the input *MUST* be +// a pointer to a valid backend type that this client recognises +func (cs *ClientServerImpl) MakeRequest(input interface{}) error { + send, err := cs.CreateRequestMessage(input) + if err != nil { + return err + } + + if cs.MakeRequestHook != nil { + send, err = cs.MakeRequestHook(send) + if err != nil { + return err + } + } + + // Over the wire we send something like + // {"type":"AckRequest","message":{"messageId":"xyz"}} + return cs.WriteMessage(send) +} + +// WriteMessage wraps the low level websocket write method with a lock +func (cs *ClientServerImpl) WriteMessage(send []byte) error { + cs.writeLock.Lock() + defer cs.writeLock.Unlock() + + // It is possible that the client implementing may invoke "WriteMessage" before calling a "Connect". + // It would lead to a nil pointer exception as the cs.conn value will not be set. + // Returning error messages in such cases asking the client to Connect. + if cs.conn == nil { + return errors.New("the connection is currently nil. Please connect and try again.") + } + // This is just future proofing. Ignore the error as the gorilla websocket + // library returns 'nil' anyway for SetWriteDeadline + // https://github.com/gorilla/websocket/blob/4201258b820c74ac8e6922fc9e6b52f71fe46f8d/conn.go#L761 + if err := cs.conn.SetWriteDeadline(time.Now().Add(cs.RWTimeout)); err != nil { + logger.Warn(fmt.Sprintf("Unable to set write deadline for websocket connection: %v for %s", + err, cs.URL)) + } + + return cs.conn.WriteMessage(websocket.TextMessage, send) +} + +// WriteCloseMessage wraps the low level websocket WriteControl method with a lock, and sends a message of type +// CloseMessage (Ref: https://github.com/gorilla/websocket/blob/9111bb834a68b893cebbbaed5060bdbc1d9ab7d2/conn.go#L74) +func (cs *ClientServerImpl) WriteCloseMessage() error { + cs.writeLock.Lock() + defer cs.writeLock.Unlock() + + send := websocket.FormatCloseMessage(websocket.CloseNormalClosure, + "ConnectionExpired: Reconnect to continue") + + return cs.conn.WriteControl(websocket.CloseMessage, send, time.Now().Add(cs.RWTimeout)) +} + +// ConsumeMessages reads messages from the websocket connection and handles read +// messages from an active connection. +func (cs *ClientServerImpl) ConsumeMessages(ctx context.Context) error { + // Since ReadMessage is blocking, we don't want to wait for timeout when context gets cancelled + errChan := make(chan error, 1) + go func() { + for { + if err := cs.SetReadDeadline(time.Now().Add(cs.RWTimeout)); err != nil { + errChan <- err + return + } + messageType, message, err := cs.conn.ReadMessage() + + switch { + case err == nil: + if messageType != websocket.TextMessage { + // maybe not fatal though, we'll try to process it anyways + logger.Error(fmt.Sprintf("Unexpected messageType: %v", messageType)) + } + + cs.handleMessage(message) + + case permissibleCloseCode(err): + logger.Debug(fmt.Sprintf("Connection closed for a valid reason: %s", err)) + errChan <- io.EOF + return + + default: + // Unexpected error occurred + logger.Debug(fmt.Sprintf("Error getting message from ws backend: error: [%v], messageType: [%v] ", + err, messageType)) + errChan <- err + return + } + } + }() + + for { + select { + case <-ctx.Done(): + // Close connection and wait for Read goroutine to finish + _ = cs.Disconnect() + <-errChan + return ctx.Err() + case err := <-errChan: + return err + } + } +} + +// CreateRequestMessage creates the request json message using the given input. +// Note, the input *MUST* be a pointer to a valid backend type that this +// client recognises. +func (cs *ClientServerImpl) CreateRequestMessage(input interface{}) ([]byte, error) { + msg := &RequestMessage{} + + recognizedTypes := cs.GetRecognizedTypes() + for typeStr, typeVal := range recognizedTypes { + if reflect.TypeOf(input) == reflect.PtrTo(typeVal) { + msg.Type = typeStr + break + } + } + if msg.Type == "" { + return nil, &UnrecognizedWSRequestType{reflect.TypeOf(input).String()} + } + messageData, err := jsonutil.BuildJSON(input) + if err != nil { + return nil, &NotMarshallableWSRequest{msg.Type, err} + } + msg.Message = json.RawMessage(messageData) + + send, err := json.Marshal(msg) + if err != nil { + return nil, &NotMarshallableWSRequest{msg.Type, err} + } + return send, nil +} + +// handleMessage dispatches a message to the correct 'requestHandler' for its +// type. If no request handler is found, the message is discarded. +func (cs *ClientServerImpl) handleMessage(data []byte) { + typedMessage, typeStr, err := DecodeData(data, cs.TypeDecoder) + if err != nil { + logger.Warn(fmt.Sprintf("Unable to handle message from backend: %v", err)) + return + } + + logger.Debug(fmt.Sprintf("Received message of type: %s", typeStr)) + + if cs.AnyRequestHandler != nil { + reflect.ValueOf(cs.AnyRequestHandler).Call([]reflect.Value{reflect.ValueOf(typedMessage)}) + } + + if handler, ok := cs.RequestHandlers[typeStr]; ok { + reflect.ValueOf(handler).Call([]reflect.Value{reflect.ValueOf(typedMessage)}) + } else { + logger.Info(fmt.Sprintf("No handler for message type: %s %s", typeStr, typedMessage)) + } +} + +func websocketScheme(httpScheme string) (string, error) { + // gorilla/websocket expects the websocket scheme (ws[s]://) + var wsScheme string + switch httpScheme { + case "http": + wsScheme = "ws" + case "https": + wsScheme = "wss" + default: + return "", fmt.Errorf("wsclient: unknown scheme %s", httpScheme) + } + return wsScheme, nil +} + +// See https://github.com/gorilla/websocket/blob/87f6f6a22ebfbc3f89b9ccdc7fddd1b914c095f9/conn.go#L650 +func permissibleCloseCode(err error) bool { + return websocket.IsCloseError(err, + websocket.CloseNormalClosure, // websocket error code 1000 + websocket.CloseAbnormalClosure, // websocket error code 1006 + websocket.CloseGoingAway, // websocket error code 1001 + websocket.CloseInternalServerErr) // websocket error code 1011 +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client_factory.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client_factory.go new file mode 100644 index 00000000000..4e2f5e63af9 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client_factory.go @@ -0,0 +1,13 @@ +package wsclient + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +// ClientFactory interface abstracts the method that creates new ClientServer +// objects. This is helpful when writing unit tests. +type ClientFactory interface { + New(url string, credentialProvider *credentials.Credentials, rwTimeout time.Duration, cfg *WSClientMinAgentConfig) ClientServer +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/decode.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/decode.go new file mode 100644 index 00000000000..b65f827452f --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/decode.go @@ -0,0 +1,101 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +// Package wsclient wraps the generated aws-sdk-go client to provide marshalling +// and unmarshalling of data over a websocket connection in the format expected +// by backend. It allows for bidirectional communication and acts as both a +// client-and-server in terms of requests, but only as a client in terms of +// connecting. +package wsclient + +import ( + "bytes" + "encoding/json" + "reflect" + + "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" +) + +// DecodeData decodes a raw message into its type. E.g. An ACS message of the +// form {"type":"FooMessage","message":{"foo":1}} will be decoded into the +// corresponding *ecsacs.FooMessage type. The type string, "FooMessage", will +// also be returned as a convenience. +func DecodeData(data []byte, dec TypeDecoder) (interface{}, string, error) { + raw := &ReceivedMessage{} + // Delay unmarshal until we know the type + err := json.Unmarshal(data, raw) + if err != nil || raw.Type == "" { + // Unframed messages can be of the {"Type":"Message"} form as well, try + // that. + connErr, connErrType, decodeErr := DecodeConnectionError(data, dec) + if decodeErr == nil && connErrType != "" { + return connErr, connErrType, nil + } + return nil, "", decodeErr + } + + reqMessage, ok := dec.NewOfType(raw.Type) + if !ok { + return nil, raw.Type, &UnrecognizedWSRequestType{raw.Type} + } + err = jsonutil.UnmarshalJSON(reqMessage, bytes.NewReader(raw.Message)) + return reqMessage, raw.Type, err +} + +// DecodeConnectionError decodes some of the connection errors returned by the +// backend. Some differ from the usual ones in that they do not have a 'type' +// and 'message' field, but rather are of the form {"ErrorType":"ErrorMessage"} +func DecodeConnectionError(data []byte, dec TypeDecoder) (interface{}, string, error) { + var acsErr map[string]string + err := json.Unmarshal(data, &acsErr) + if err != nil { + return nil, "", &UndecodableMessage{string(data)} + } + if len(acsErr) != 1 { + return nil, "", &UndecodableMessage{string(data)} + } + var typeStr string + for key := range acsErr { + typeStr = key + } + errType, ok := dec.NewOfType(typeStr) + if !ok { + return nil, typeStr, &UnrecognizedWSRequestType{} + } + + val := reflect.ValueOf(errType) + if val.Kind() != reflect.Ptr { + return nil, typeStr, &UnrecognizedWSRequestType{"Non-pointer kind: " + val.Kind().String()} + } + ret := reflect.New(val.Elem().Type()) + retObj := ret.Elem() + + if retObj.Kind() != reflect.Struct { + return nil, typeStr, &UnrecognizedWSRequestType{"Pointer to non-struct kind: " + retObj.Kind().String()} + } + + msgField := retObj.FieldByName("Message_") + if !msgField.IsValid() { + return nil, typeStr, &UnrecognizedWSRequestType{"Expected error type to have 'Message' field"} + } + if msgField.IsValid() && msgField.CanSet() { + msgStr := acsErr[typeStr] + msgStrVal := reflect.ValueOf(&msgStr) + if !msgStrVal.Type().AssignableTo(msgField.Type()) { + return nil, typeStr, &UnrecognizedWSRequestType{"Type mismatch; 'Message' field must be a *string"} + } + msgField.Set(msgStrVal) + return ret.Interface(), typeStr, nil + } + return nil, typeStr, &UnrecognizedWSRequestType{"Invalid message field; must not be nil"} +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/error.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/error.go new file mode 100644 index 00000000000..3e32178658c --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/error.go @@ -0,0 +1,120 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package wsclient + +import "reflect" + +// UnrecognizedWSRequestType specifies that a given type is not recognized. +// This error is not retriable. +type UnrecognizedWSRequestType struct { + Type string +} + +// Error implements error +func (u *UnrecognizedWSRequestType) Error() string { + return "Could not recognize given argument as a valid type: " + u.Type +} + +// Retry implements Retriable +func (u *UnrecognizedWSRequestType) Retry() bool { + return false +} + +// NotMarshallableWSRequest represents that the given request input could not be +// marshalled +type NotMarshallableWSRequest struct { + Type string + + Err error +} + +// Retry implements Retriable +func (u *NotMarshallableWSRequest) Retry() bool { + return false +} + +// Error implements error +func (u *NotMarshallableWSRequest) Error() string { + ret := "Could not marshal Request" + if u.Type != "" { + ret += " (" + u.Type + ")" + } + return ret + ": " + u.Err.Error() +} + +// UndecodableMessage indicates that a message from the backend could not be decoded +type UndecodableMessage struct { + Msg string +} + +func (u *UndecodableMessage) Error() string { + return "Could not decode message into any expected format: " + u.Msg +} + +// WSUnretriableErrors defines methods to retrieve the list of unretriable +// errors. +type WSUnretriableErrors interface { + Get() []interface{} +} + +// ServiceError defines methods to return new backend specific error objects. +type ServiceError interface { + NewError(err interface{}) *WSError +} + +// WSError wraps all the typed errors that the backend may return +// This will not be needed once the aws-sdk-go generation handles error types +// more cleanly +type WSError struct { + ErrObj interface{} + Type string + WSUnretriableErrors +} + +// Error returns an error string +func (err *WSError) Error() string { + val := reflect.ValueOf(err.ErrObj) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + var typeStr = "Unknown type" + if val.IsValid() { + typeStr = val.Type().Name() + msg := val.FieldByName("Message_") + if msg.IsValid() && msg.CanInterface() { + str, ok := msg.Interface().(*string) + if ok { + if str == nil { + return typeStr + ": null" + } + return typeStr + ": " + *str + } + } + } + + if asErr, ok := err.ErrObj.(error); ok { + return err.Type + ": " + asErr.Error() + } + return err.Type + ": Unknown error (" + typeStr + ")" +} + +// Retry returns true if this error should be considered retriable +func (err *WSError) Retry() bool { + for _, unretriable := range err.Get() { + if reflect.TypeOf(err.ErrObj) == reflect.TypeOf(unretriable) { + return false + } + } + return true +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/generate_mocks.go new file mode 100644 index 00000000000..099243d9962 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/generate_mocks.go @@ -0,0 +1,3 @@ +package wsclient + +//go:generate mockgen -destination=mock/client.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/wsclient ClientServer,RequestResponder,ClientFactory diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock/client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock/client.go new file mode 100644 index 00000000000..c091075eebd --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock/client.go @@ -0,0 +1,319 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/aws/amazon-ecs-agent/ecs-agent/wsclient (interfaces: ClientServer,RequestResponder,ClientFactory) + +// Package mock_wsclient is a generated GoMock package. +package mock_wsclient + +import ( + context "context" + reflect "reflect" + time "time" + + wsclient "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + wsconn "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn" + credentials "github.com/aws/aws-sdk-go/aws/credentials" + gomock "github.com/golang/mock/gomock" +) + +// MockClientServer is a mock of ClientServer interface. +type MockClientServer struct { + ctrl *gomock.Controller + recorder *MockClientServerMockRecorder +} + +// MockClientServerMockRecorder is the mock recorder for MockClientServer. +type MockClientServerMockRecorder struct { + mock *MockClientServer +} + +// NewMockClientServer creates a new mock instance. +func NewMockClientServer(ctrl *gomock.Controller) *MockClientServer { + mock := &MockClientServer{ctrl: ctrl} + mock.recorder = &MockClientServerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientServer) EXPECT() *MockClientServerMockRecorder { + return m.recorder +} + +// AddRequestHandler mocks base method. +func (m *MockClientServer) AddRequestHandler(arg0 wsclient.RequestHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddRequestHandler", arg0) +} + +// AddRequestHandler indicates an expected call of AddRequestHandler. +func (mr *MockClientServerMockRecorder) AddRequestHandler(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRequestHandler", reflect.TypeOf((*MockClientServer)(nil).AddRequestHandler), arg0) +} + +// Close mocks base method. +func (m *MockClientServer) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockClientServerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockClientServer)(nil).Close)) +} + +// Connect mocks base method. +func (m *MockClientServer) Connect() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect") + ret0, _ := ret[0].(error) + return ret0 +} + +// Connect indicates an expected call of Connect. +func (mr *MockClientServerMockRecorder) Connect() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockClientServer)(nil).Connect)) +} + +// Disconnect mocks base method. +func (m *MockClientServer) Disconnect(arg0 ...interface{}) error { + m.ctrl.T.Helper() + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Disconnect", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Disconnect indicates an expected call of Disconnect. +func (mr *MockClientServerMockRecorder) Disconnect(arg0 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockClientServer)(nil).Disconnect), arg0...) +} + +// IsConnected mocks base method. +func (m *MockClientServer) IsConnected() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsConnected") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsConnected indicates an expected call of IsConnected. +func (mr *MockClientServerMockRecorder) IsConnected() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockClientServer)(nil).IsConnected)) +} + +// MakeRequest mocks base method. +func (m *MockClientServer) MakeRequest(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MakeRequest", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// MakeRequest indicates an expected call of MakeRequest. +func (mr *MockClientServerMockRecorder) MakeRequest(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MakeRequest", reflect.TypeOf((*MockClientServer)(nil).MakeRequest), arg0) +} + +// Serve mocks base method. +func (m *MockClientServer) Serve(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Serve", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Serve indicates an expected call of Serve. +func (mr *MockClientServerMockRecorder) Serve(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockClientServer)(nil).Serve), arg0) +} + +// SetAnyRequestHandler mocks base method. +func (m *MockClientServer) SetAnyRequestHandler(arg0 wsclient.RequestHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAnyRequestHandler", arg0) +} + +// SetAnyRequestHandler indicates an expected call of SetAnyRequestHandler. +func (mr *MockClientServerMockRecorder) SetAnyRequestHandler(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAnyRequestHandler", reflect.TypeOf((*MockClientServer)(nil).SetAnyRequestHandler), arg0) +} + +// SetConnection mocks base method. +func (m *MockClientServer) SetConnection(arg0 wsconn.WebsocketConn) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetConnection", arg0) +} + +// SetConnection indicates an expected call of SetConnection. +func (mr *MockClientServerMockRecorder) SetConnection(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetConnection", reflect.TypeOf((*MockClientServer)(nil).SetConnection), arg0) +} + +// SetReadDeadline mocks base method. +func (m *MockClientServer) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline. +func (mr *MockClientServerMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockClientServer)(nil).SetReadDeadline), arg0) +} + +// WriteCloseMessage mocks base method. +func (m *MockClientServer) WriteCloseMessage() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteCloseMessage") + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteCloseMessage indicates an expected call of WriteCloseMessage. +func (mr *MockClientServerMockRecorder) WriteCloseMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteCloseMessage", reflect.TypeOf((*MockClientServer)(nil).WriteCloseMessage)) +} + +// WriteMessage mocks base method. +func (m *MockClientServer) WriteMessage(arg0 []byte) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteMessage", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteMessage indicates an expected call of WriteMessage. +func (mr *MockClientServerMockRecorder) WriteMessage(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteMessage", reflect.TypeOf((*MockClientServer)(nil).WriteMessage), arg0) +} + +// MockRequestResponder is a mock of RequestResponder interface. +type MockRequestResponder struct { + ctrl *gomock.Controller + recorder *MockRequestResponderMockRecorder +} + +// MockRequestResponderMockRecorder is the mock recorder for MockRequestResponder. +type MockRequestResponderMockRecorder struct { + mock *MockRequestResponder +} + +// NewMockRequestResponder creates a new mock instance. +func NewMockRequestResponder(ctrl *gomock.Controller) *MockRequestResponder { + mock := &MockRequestResponder{ctrl: ctrl} + mock.recorder = &MockRequestResponderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRequestResponder) EXPECT() *MockRequestResponderMockRecorder { + return m.recorder +} + +// HandlerFunc mocks base method. +func (m *MockRequestResponder) HandlerFunc() wsclient.RequestHandler { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandlerFunc") + ret0, _ := ret[0].(wsclient.RequestHandler) + return ret0 +} + +// HandlerFunc indicates an expected call of HandlerFunc. +func (mr *MockRequestResponderMockRecorder) HandlerFunc() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandlerFunc", reflect.TypeOf((*MockRequestResponder)(nil).HandlerFunc)) +} + +// Name mocks base method. +func (m *MockRequestResponder) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockRequestResponderMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockRequestResponder)(nil).Name)) +} + +// RegisterResponder mocks base method. +func (m *MockRequestResponder) RegisterResponder(arg0 wsclient.RespondFunc) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterResponder", arg0) +} + +// RegisterResponder indicates an expected call of RegisterResponder. +func (mr *MockRequestResponderMockRecorder) RegisterResponder(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterResponder", reflect.TypeOf((*MockRequestResponder)(nil).RegisterResponder), arg0) +} + +// MockClientFactory is a mock of ClientFactory interface. +type MockClientFactory struct { + ctrl *gomock.Controller + recorder *MockClientFactoryMockRecorder +} + +// MockClientFactoryMockRecorder is the mock recorder for MockClientFactory. +type MockClientFactoryMockRecorder struct { + mock *MockClientFactory +} + +// NewMockClientFactory creates a new mock instance. +func NewMockClientFactory(ctrl *gomock.Controller) *MockClientFactory { + mock := &MockClientFactory{ctrl: ctrl} + mock.recorder = &MockClientFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClientFactory) EXPECT() *MockClientFactoryMockRecorder { + return m.recorder +} + +// New mocks base method. +func (m *MockClientFactory) New(arg0 string, arg1 *credentials.Credentials, arg2 time.Duration, arg3 *wsclient.WSClientMinAgentConfig) wsclient.ClientServer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "New", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(wsclient.ClientServer) + return ret0 +} + +// New indicates an expected call of New. +func (mr *MockClientFactoryMockRecorder) New(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "New", reflect.TypeOf((*MockClientFactory)(nil).New), arg0, arg1, arg2, arg3) +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/types.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/types.go new file mode 100644 index 00000000000..5bb29606792 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/types.go @@ -0,0 +1,59 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package wsclient + +import "reflect" + +// TypeDecoder interface defines methods to decode ecs types. +type TypeDecoder interface { + // NewOfType returns an object of a recognized type for a given type name. + // It additionally returns a boolean value which is set to false for an + // unrecognized type. + NewOfType(string) (interface{}, bool) + + // GetRecognizedTypes returns a map of type-strings (as passed in acs/tcs messages as + // the 'type' field) to a pointer to the corresponding struct type this type should + // be marshalled/unmarshalled to/from. + GetRecognizedTypes() map[string]reflect.Type +} + +// TypeDecoderImpl is an implementation for general use between ACS and +// TCS clients +type TypeDecoderImpl struct { + typeMappings map[string]reflect.Type +} + +// BuildTypeDecoder takes a list of interfaces and stores them internally as a +// list of typeMappings in the format below. +// "MyMessage": TypeOf(ecstcs.MyMessage) +func BuildTypeDecoder(recognizedTypes []interface{}) TypeDecoder { + typeMappings := make(map[string]reflect.Type) + for _, recognizedType := range recognizedTypes { + typeMappings[reflect.TypeOf(recognizedType).Name()] = reflect.TypeOf(recognizedType) + } + + return &TypeDecoderImpl{typeMappings: typeMappings} +} + +func (d *TypeDecoderImpl) NewOfType(typeString string) (interface{}, bool) { + rtype, ok := d.typeMappings[typeString] + if !ok { + return nil, false + } + return reflect.New(rtype).Interface(), true +} + +func (d *TypeDecoderImpl) GetRecognizedTypes() map[string]reflect.Type { + return d.typeMappings +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/conn.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/conn.go new file mode 100644 index 00000000000..af29a55c93c --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/conn.go @@ -0,0 +1,27 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package wsconn + +import "time" + +// WebsocketConn specifies the subset of gorilla/websocket's +// connection's methods that this client uses. +type WebsocketConn interface { + WriteMessage(messageType int, data []byte) error + WriteControl(messageType int, data []byte, deadline time.Time) error + ReadMessage() (messageType int, data []byte, err error) + Close() error + SetWriteDeadline(t time.Time) error + SetReadDeadline(t time.Time) error +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/generate_mocks.go new file mode 100644 index 00000000000..9eb2f4a97f5 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/generate_mocks.go @@ -0,0 +1,16 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package wsconn + +//go:generate mockgen -destination=mock/conn.go -copyright_file=../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn WebsocketConn diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index a41a6c4459c..3f6db267e23 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -7,6 +7,7 @@ github.com/Microsoft/go-winio/pkg/guid github.com/Microsoft/hcsshim/osversion # github.com/aws/amazon-ecs-agent/ecs-agent v0.0.0 => ../ecs-agent ## explicit; go 1.19 +github.com/aws/amazon-ecs-agent/ecs-agent/acs/client github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs github.com/aws/amazon-ecs-agent/ecs-agent/acs/session github.com/aws/amazon-ecs-agent/ecs-agent/api/attachmentinfo @@ -37,6 +38,9 @@ github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime/mocks +github.com/aws/amazon-ecs-agent/ecs-agent/wsclient +github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/mock +github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn # github.com/aws/aws-sdk-go v1.36.0 ## explicit; go 1.11 github.com/aws/aws-sdk-go/aws diff --git a/ecs-agent/acs/client/acs_client.go b/ecs-agent/acs/client/acs_client.go new file mode 100644 index 00000000000..35edea2c3a7 --- /dev/null +++ b/ecs-agent/acs/client/acs_client.go @@ -0,0 +1,73 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +// Package acsclient wraps the generated aws-sdk-go client to provide marshalling +// and unmarshalling of data over a websocket connection in the format expected +// by ACS. It allows for bidirectional communication and acts as both a +// client-and-server in terms of requests, but only as a client in terms of +// connecting. +package acsclient + +import ( + "context" + "errors" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +// clientServer implements ClientServer for acs. +type clientServer struct { + wsclient.ClientServerImpl +} + +type acsClientFactory struct{} + +// NewACSClientFactory creates a new ACS client factory object. This can be +// used to create new ACS clients. +func NewACSClientFactory() wsclient.ClientFactory { + return &acsClientFactory{} +} + +// New returns a client/server to bidirectionally communicate with ACS +// The returned struct should have both 'Connect' and 'Serve' called upon it +// before being used. +func (*acsClientFactory) New(url string, credentialProvider *credentials.Credentials, rwTimeout time.Duration, cfg *wsclient.WSClientMinAgentConfig) wsclient.ClientServer { + cs := &clientServer{} + cs.URL = url + cs.CredentialProvider = credentialProvider + cs.Cfg = cfg + cs.ServiceError = &acsError{} + cs.RequestHandlers = make(map[string]wsclient.RequestHandler) + cs.TypeDecoder = NewACSDecoder() + cs.RWTimeout = rwTimeout + return cs +} + +// Serve begins serving requests using previously registered handlers (see +// AddRequestHandler). All request handlers should be added prior to making this +// call as unhandled requests will be discarded. +func (cs *clientServer) Serve(ctx context.Context) error { + logger.Debug("ACS client starting websocket poll loop") + if !cs.IsReady() { + return errors.New("acs client: websocket not ready for connections") + } + return cs.ConsumeMessages(ctx) +} + +// Close closes the underlying connection +func (cs *clientServer) Close() error { + return cs.Disconnect() +} diff --git a/agent/acs/client/acs_client_test.go b/ecs-agent/acs/client/acs_client_test.go similarity index 94% rename from agent/acs/client/acs_client_test.go rename to ecs-agent/acs/client/acs_client_test.go index 6c4ea82491c..3b4b704983b 100644 --- a/agent/acs/client/acs_client_test.go +++ b/ecs-agent/acs/client/acs_client_test.go @@ -17,6 +17,7 @@ package acsclient import ( + "context" "encoding/json" "errors" "io" @@ -25,10 +26,9 @@ import ( "testing" "time" - "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/agent/wsclient" - mock_wsconn "github.com/aws/amazon-ecs-agent/agent/wsclient/wsconn/mock" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" + mock_wsconn "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/wsconn/mock" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/golang/mock/gomock" @@ -108,11 +108,13 @@ const ( var testCreds = credentials.NewStaticCredentials("test-id", "test-secret", "test-token") -var testCfg = &config.Config{ +var testCfg = &wsclient.WSClientMinAgentConfig{ AcceptInsecureCert: true, AWSRegion: "us-east-1", } +var testACSClientFactory = NewACSClientFactory() + func TestMakeUnrecognizedRequest(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -178,7 +180,7 @@ func TestPayloadHandlerCalled(t *testing.T) { messageChannel <- payload } cs.AddRequestHandler(reqHandler) - go cs.Serve() + go cs.Serve(context.Background()) expectedMessage := &ecsacs.PayloadMessage{ Tasks: []*ecsacs.Task{{ @@ -210,7 +212,7 @@ func TestRefreshCredentialsHandlerCalled(t *testing.T) { } cs.AddRequestHandler(reqHandler) - go cs.Serve() + go cs.Serve(context.Background()) expectedMessage := &ecsacs.IAMRoleCredentialsMessage{ MessageId: aws.String("123"), @@ -243,7 +245,7 @@ func TestClosingConnection(t *testing.T) { cs := testCS(conn) defer cs.Close() - serveErr := cs.Serve() + serveErr := cs.Serve(context.Background()) assert.Error(t, serveErr) err := cs.MakeRequest(&ecsacs.AckRequest{}) @@ -261,7 +263,7 @@ func TestConnect(t *testing.T) { t.Fatal(<-serverErr) }() - cs := New(server.URL, testCfg, testCreds, rwTimeout) + cs := testACSClientFactory.New(server.URL, testCreds, rwTimeout, testCfg) // Wait for up to a second for the mock server to launch for i := 0; i < 100; i++ { err = cs.Connect() @@ -284,7 +286,7 @@ func TestConnect(t *testing.T) { }) go func() { - _ = cs.Serve() + _ = cs.Serve(context.Background()) }() go func() { @@ -332,7 +334,7 @@ func TestConnectClientError(t *testing.T) { })) defer testServer.Close() - cs := New(testServer.URL, testCfg, testCreds, rwTimeout) + cs := testACSClientFactory.New(testServer.URL, testCreds, rwTimeout, testCfg) err := cs.Connect() _, ok := err.(*wsclient.WSError) assert.True(t, ok, "Connect error expected to be a WSError type") @@ -340,7 +342,7 @@ func TestConnectClientError(t *testing.T) { } func testCS(conn *mock_wsconn.MockWebsocketConn) wsclient.ClientServer { - foo := New("localhost:443", testCfg, testCreds, rwTimeout) + foo := testACSClientFactory.New("localhost:443", testCreds, rwTimeout, testCfg) cs := foo.(*clientServer) cs.SetConnection(conn) return cs @@ -405,7 +407,7 @@ func TestAttachENIHandlerCalled(t *testing.T) { cs.AddRequestHandler(reqHandler) - go cs.Serve() + go cs.Serve(context.Background()) expectedMessage := &ecsacs.AttachTaskNetworkInterfacesMessage{ MessageId: aws.String("123"), @@ -456,7 +458,7 @@ func TestAttachInstanceENIHandlerCalled(t *testing.T) { cs.AddRequestHandler(reqHandler) - go cs.Serve() + go cs.Serve(context.Background()) expectedMessage := &ecsacs.AttachInstanceNetworkInterfacesMessage{ MessageId: aws.String("123"), diff --git a/ecs-agent/acs/client/acs_client_types.go b/ecs-agent/acs/client/acs_client_types.go new file mode 100644 index 00000000000..9f86ae16047 --- /dev/null +++ b/ecs-agent/acs/client/acs_client_types.go @@ -0,0 +1,60 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package acsclient + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" +) + +var acsRecognizedTypes []interface{} + +func init() { + // This list is currently *manually updated* and assumes that the generated + // struct type-names within the package *exactly match* the type sent by ACS/TCS + // (true so far; careful with inflections) + // TODO, this list should be autogenerated + // I couldn't figure out how to get a list of all structs in a package via + // reflection, but that would solve this. The alternative is to either parse + // the .json model or the generated struct names. + acsRecognizedTypes = []interface{}{ + ecsacs.HeartbeatMessage{}, + ecsacs.HeartbeatAckRequest{}, + ecsacs.PayloadMessage{}, + ecsacs.CloseMessage{}, + ecsacs.AckRequest{}, + ecsacs.NackRequest{}, + ecsacs.PerformUpdateMessage{}, + ecsacs.StageUpdateMessage{}, + ecsacs.IAMRoleCredentialsMessage{}, + ecsacs.IAMRoleCredentialsAckRequest{}, + ecsacs.ServerException{}, + ecsacs.BadRequestException{}, + ecsacs.InvalidClusterException{}, + ecsacs.InvalidInstanceException{}, + ecsacs.AccessDeniedException{}, + ecsacs.InactiveInstanceException{}, + ecsacs.ErrorMessage{}, + ecsacs.AttachTaskNetworkInterfacesMessage{}, + ecsacs.AttachInstanceNetworkInterfacesMessage{}, + ecsacs.ConfirmAttachmentMessage{}, + ecsacs.TaskManifestMessage{}, + ecsacs.TaskStopVerificationAck{}, + ecsacs.TaskStopVerificationMessage{}, + } +} + +func NewACSDecoder() wsclient.TypeDecoder { + return wsclient.BuildTypeDecoder(acsRecognizedTypes) +} diff --git a/ecs-agent/acs/client/acs_error.go b/ecs-agent/acs/client/acs_error.go new file mode 100644 index 00000000000..99d7072c6d9 --- /dev/null +++ b/ecs-agent/acs/client/acs_error.go @@ -0,0 +1,52 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package acsclient + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" +) + +const errType = "ACSError" + +// ACSUnretriableErrors wraps all the typed errors that ACS may return +type ACSUnretriableErrors struct{} + +// Get gets the list of unretriable error types. +func (err *ACSUnretriableErrors) Get() []interface{} { + return unretriableErrors +} + +// acsError implements wsclient.ServiceError interface. +type acsError struct{} + +// NewError returns an error corresponding to a typed error returned from +// ACS. It is expected that the passed in interface{} is really a struct which +// has a 'Message' field of type *string. In that case, the Message will be +// conveyed as part of the Error string as well as the type. It is safe to pass +// anything into this constructor and it will also work reasonably well with +// anything fulfilling the 'error' interface. +func (ae *acsError) NewError(err interface{}) *wsclient.WSError { + return &wsclient.WSError{ErrObj: err, Type: errType, WSUnretriableErrors: &ACSUnretriableErrors{}} +} + +// These errors are all fatal and there's nothing we can do about them. +// AccessDeniedException is actually potentially fixable because you can change +// credentials at runtime, but still close to unretriable. +var unretriableErrors = []interface{}{ + &ecsacs.InvalidInstanceException{}, + &ecsacs.InvalidClusterException{}, + &ecsacs.InactiveInstanceException{}, + &ecsacs.AccessDeniedException{}, +} diff --git a/agent/acs/client/acs_error_test.go b/ecs-agent/acs/client/acs_error_test.go similarity index 100% rename from agent/acs/client/acs_error_test.go rename to ecs-agent/acs/client/acs_error_test.go