diff --git a/agent/acs/client/acs_client.go b/agent/acs/client/acs_client.go index c46962928b4..e2bfb81a21e 100644 --- a/agent/acs/client/acs_client.go +++ b/agent/acs/client/acs_client.go @@ -20,15 +20,14 @@ package acsclient import ( "errors" + "time" "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/agent/logger" "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/cihub/seelog" ) -var log = logger.ForModule("acs client") - // clientServer implements ClientServer for acs. type clientServer struct { wsclient.ClientServerImpl @@ -37,7 +36,7 @@ type clientServer struct { // New returns a client/server to bidirectionally communicate with ACS // The returned struct should have both 'Connect' and 'Serve' called upon it // before being used. -func New(url string, cfg *config.Config, credentialProvider *credentials.Credentials) wsclient.ClientServer { +func New(url string, cfg *config.Config, credentialProvider *credentials.Credentials, rwTimeout time.Duration) wsclient.ClientServer { cs := &clientServer{} cs.URL = url cs.CredentialProvider = credentialProvider @@ -45,6 +44,7 @@ func New(url string, cfg *config.Config, credentialProvider *credentials.Credent cs.ServiceError = &acsError{} cs.RequestHandlers = make(map[string]wsclient.RequestHandler) cs.TypeDecoder = NewACSDecoder() + cs.RWTimeout = rwTimeout return cs } @@ -52,9 +52,9 @@ func New(url string, cfg *config.Config, credentialProvider *credentials.Credent // AddRequestHandler). All request handlers should be added prior to making this // call as unhandled requests will be discarded. func (cs *clientServer) Serve() error { - log.Debug("Starting websocket poll loop") + seelog.Debug("ACS client starting websocket poll loop") if !cs.IsReady() { - return errors.New("Websocket not ready for connections") + return errors.New("acs client: websocket not ready for connections") } return cs.ConsumeMessages() } diff --git a/agent/acs/client/acs_client_test.go b/agent/acs/client/acs_client_test.go index 76dbba1800e..f5c6ba62218 100644 --- a/agent/acs/client/acs_client_test.go +++ b/agent/acs/client/acs_client_test.go @@ -79,6 +79,7 @@ const ( const ( TestClusterArn = "arn:aws:ec2:123:container/cluster:123456" TestInstanceArn = "arn:aws:ec2:123:container/containerInstance/12345678" + rwTimeout = time.Second ) var testCfg = &config.Config{ @@ -91,6 +92,7 @@ func TestMakeUnrecognizedRequest(t *testing.T) { defer ctrl.Finish() conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) conn.EXPECT().Close() cs := testCS(conn) @@ -107,6 +109,7 @@ func TestWriteAckRequest(t *testing.T) { defer ctrl.Finish() conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(2) conn.EXPECT().Close() cs := testCS(conn) defer cs.Close() @@ -133,7 +136,13 @@ func TestPayloadHandlerCalled(t *testing.T) { defer ctrl.Finish() conn := mock_wsclient.NewMockWebsocketConn(ctrl) - conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`), nil) + // Messages should be read from the connection at least once + conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1) + conn.EXPECT().ReadMessage().Return(websocket.TextMessage, + []byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`), + nil).MinTimes(1) + // Invoked when closing the connection + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) conn.EXPECT().Close() cs := testCS(conn) defer cs.Close() @@ -159,7 +168,12 @@ func TestRefreshCredentialsHandlerCalled(t *testing.T) { defer ctrl.Finish() conn := mock_wsclient.NewMockWebsocketConn(ctrl) - conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(sampleCredentialsMessage), nil) + // Messages should be read from the connection at least once + conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1) + conn.EXPECT().ReadMessage().Return(websocket.TextMessage, + []byte(sampleCredentialsMessage), nil).MinTimes(1) + // Invoked when closing the connection + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) conn.EXPECT().Close() cs := testCS(conn) defer cs.Close() @@ -193,9 +207,15 @@ func TestClosingConnection(t *testing.T) { // Returning EOF tells the ClientServer that the connection is closed conn := mock_wsclient.NewMockWebsocketConn(ctrl) + conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil) conn.EXPECT().ReadMessage().Return(0, nil, io.EOF) + // SetWriteDeadline will be invoked once for WriteMessage() and + // once for Close() + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(2) conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(io.EOF) + conn.EXPECT().Close() cs := testCS(conn) + defer cs.Close() serveErr := cs.Serve() assert.Error(t, serveErr) @@ -215,7 +235,7 @@ func TestConnect(t *testing.T) { t.Fatal(<-serverErr) }() - cs := New(server.URL, testCfg, credentials.AnonymousCredentials) + cs := New(server.URL, testCfg, credentials.AnonymousCredentials, rwTimeout) // Wait for up to a second for the mock server to launch for i := 0; i < 100; i++ { err = cs.Connect() @@ -286,7 +306,7 @@ func TestConnectClientError(t *testing.T) { })) defer testServer.Close() - cs := New(testServer.URL, testCfg, credentials.AnonymousCredentials) + cs := New(testServer.URL, testCfg, credentials.AnonymousCredentials, rwTimeout) err := cs.Connect() _, ok := err.(*wsclient.WSError) assert.True(t, ok) @@ -295,7 +315,8 @@ func TestConnectClientError(t *testing.T) { func testCS(conn *mock_wsclient.MockWebsocketConn) wsclient.ClientServer { testCreds := credentials.AnonymousCredentials - cs := New("localhost:443", testCfg, testCreds).(*clientServer) + foo := New("localhost:443", testCfg, testCreds, rwTimeout) + cs := foo.(*clientServer) cs.SetConnection(conn) return cs } @@ -344,7 +365,12 @@ func TestAttachENIHandlerCalled(t *testing.T) { cs := testCS(conn) defer cs.Close() - conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(sampleAttachENIMessage), nil) + // Messages should be read from the connection at least once + conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1) + conn.EXPECT().ReadMessage().Return(websocket.TextMessage, + []byte(sampleAttachENIMessage), nil).MinTimes(1) + // Invoked when closing the connection + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) conn.EXPECT().Close() messageChannel := make(chan *ecsacs.AttachTaskNetworkInterfacesMessage) diff --git a/agent/acs/handler/acs_handler.go b/agent/acs/handler/acs_handler.go index 69a300f1db9..d8894007709 100644 --- a/agent/acs/handler/acs_handler.go +++ b/agent/acs/handler/acs_handler.go @@ -246,19 +246,13 @@ func (acsSession *session) startSessionOnce() error { client := acsSession.resources.createACSClient(url, acsSession.agentConfig) defer client.Close() - // Start inactivity timer for closing the connection - timer := newDisconnectionTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()) - defer timer.Stop() - - return acsSession.startACSSession(client, timer) + return acsSession.startACSSession(client) } // startACSSession starts a session with ACS. It adds request handlers for various // kinds of messages expected from ACS. It returns on server disconnection or when // the context is cancelled -func (acsSession *session) startACSSession(client wsclient.ClientServer, timer ttime.Timer) error { - // Any message from the server resets the disconnect timeout - client.SetAnyRequestHandler(anyMessageHandler(timer)) +func (acsSession *session) startACSSession(client wsclient.ClientServer) error { cfg := acsSession.agentConfig refreshCredsHandler := newRefreshCredentialsHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN, @@ -313,6 +307,11 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer, timer t return err } seelog.Info("Connected to ACS endpoint") + // Start inactivity timer for closing the connection + timer := newDisconnectionTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()) + // Any message from the server resets the disconnect timeout + client.SetAnyRequestHandler(anyMessageHandler(timer, client)) + defer timer.Stop() acsSession.resources.connectedToACS() @@ -377,7 +376,7 @@ func (acsSession *session) heartbeatJitter() time.Duration { // createACSClient creates the ACS Client using the specified URL func (acsResources *acsSessionResources) createACSClient(url string, cfg *config.Config) wsclient.ClientServer { - return acsclient.New(url, cfg, acsResources.credentialsProvider) + return acsclient.New(url, cfg, acsResources.credentialsProvider, heartbeatTimeout+heartbeatJitter) } // connectedToACS records a successful connection to ACS @@ -424,9 +423,8 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine. func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer { timer := time.AfterFunc(utils.AddJitter(timeout, jitter), func() { seelog.Warn("ACS Connection hasn't had any activity for too long; closing connection") - closeErr := client.Close() - if closeErr != nil { - seelog.Warnf("Error disconnecting: %v", closeErr) + if err := client.Close(); err != nil { + seelog.Warnf("Error disconnecting: %v", err) } }) @@ -435,9 +433,15 @@ func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration, // anyMessageHandler handles any server message. Any server message means the // connection is active and thus the heartbeat disconnect should not occur -func anyMessageHandler(timer ttime.Timer) func(interface{}) { +func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) { return func(interface{}) { - seelog.Debug("ACS activity occured") + seelog.Debug("ACS activity occurred") + // Reset read deadline as there's activity on the channel + if err := client.SetReadDeadline(time.Now().Add(heartbeatTimeout + heartbeatJitter)); err != nil { + seelog.Warn("Unable to extend read deadline for ACS connection: %v", err) + } + + // Reset heearbeat timer timer.Reset(utils.AddJitter(heartbeatTimeout, heartbeatJitter)) } } diff --git a/agent/acs/handler/acs_handler_test.go b/agent/acs/handler/acs_handler_test.go index 0cf785713ba..b3706e725bd 100644 --- a/agent/acs/handler/acs_handler_test.go +++ b/agent/acs/handler/acs_handler_test.go @@ -804,9 +804,7 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { _heartbeatJitter: 10 * time.Millisecond, } go func() { - timer := newDisconnectionTimer(mockWsClient, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()) - defer timer.Stop() - acsSession.startACSSession(mockWsClient, timer) + acsSession.startACSSession(mockWsClient) }() // Wait for connection to be closed. If the connection is not closed @@ -1060,11 +1058,9 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T) _heartbeatTimeout: 20 * time.Millisecond, _heartbeatJitter: 10 * time.Millisecond, } - timer := newDisconnectionTimer(mockWsClient, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()) - defer timer.Stop() go func() { for i := 0; i < 10; i++ { - acsSession.startACSSession(mockWsClient, timer) + acsSession.startACSSession(mockWsClient) } cancel() }() diff --git a/agent/tcs/client/client.go b/agent/tcs/client/client.go index b0092bc3ce1..da704ce7021 100644 --- a/agent/tcs/client/client.go +++ b/agent/tcs/client/client.go @@ -46,7 +46,12 @@ type clientServer struct { // New returns a client/server to bidirectionally communicate with the backend. // The returned struct should have both 'Connect' and 'Serve' called upon it // before being used. -func New(url string, cfg *config.Config, credentialProvider *credentials.Credentials, statsEngine stats.Engine, publishMetricsInterval time.Duration) wsclient.ClientServer { +func New(url string, + cfg *config.Config, + credentialProvider *credentials.Credentials, + statsEngine stats.Engine, + publishMetricsInterval time.Duration, + rwTimeout time.Duration) wsclient.ClientServer { cs := &clientServer{ statsEngine: statsEngine, publishTicker: nil, @@ -58,6 +63,7 @@ func New(url string, cfg *config.Config, credentialProvider *credentials.Credent cs.ServiceError = &tcsError{} cs.RequestHandlers = make(map[string]wsclient.RequestHandler) cs.TypeDecoder = NewTCSDecoder() + cs.RWTimeout = rwTimeout return cs } @@ -67,11 +73,11 @@ func New(url string, cfg *config.Config, credentialProvider *credentials.Credent func (cs *clientServer) Serve() error { seelog.Debug("TCS client starting websocket poll loop") if !cs.IsReady() { - return fmt.Errorf("Websocket not ready for connections") + return fmt.Errorf("tcs client: websocket not ready for connections") } if cs.statsEngine == nil { - return fmt.Errorf("uninitialized stats engine") + return fmt.Errorf("tcs client: uninitialized stats engine") } // Start the timer function to publish metrics to the backend. diff --git a/agent/tcs/client/client_test.go b/agent/tcs/client/client_test.go index 23b59899074..08214021776 100644 --- a/agent/tcs/client/client_test.go +++ b/agent/tcs/client/client_test.go @@ -40,6 +40,7 @@ const ( testMessageId = "testMessageId" testCluster = "default" testContainerInstance = "containerInstance" + rwTimeout = time.Second ) type mockStatsEngine struct{} @@ -97,7 +98,12 @@ func TestPayloadHandlerCalled(t *testing.T) { conn := mock_wsclient.NewMockWebsocketConn(ctrl) cs := testCS(conn) - conn.EXPECT().ReadMessage().AnyTimes().Return(1, []byte(`{"type":"AckPublishMetric","message":{}}`), nil) + // Messages should be read from the connection at least once + conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1) + conn.EXPECT().ReadMessage().Return(1, + []byte(`{"type":"AckPublishMetric","message":{}}`), nil).MinTimes(1) + // Invoked when closing the connection + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil) conn.EXPECT().Close() handledPayload := make(chan *ecstcs.AckPublishMetric) @@ -119,6 +125,8 @@ func TestPublishMetricsRequest(t *testing.T) { defer ctrl.Finish() conn := mock_wsclient.NewMockWebsocketConn(ctrl) + // Invoked when closing the connection + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(2) conn.EXPECT().Close() // TODO: should use explicit values conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()) @@ -201,7 +209,8 @@ func testCS(conn *mock_wsclient.MockWebsocketConn) wsclient.ClientServer { AWSRegion: "us-east-1", AcceptInsecureCert: true, } - cs := New("localhost:443", cfg, testCreds, &mockStatsEngine{}, testPublishMetricsInterval).(*clientServer) + cs := New("localhost:443", cfg, testCreds, &mockStatsEngine{}, + testPublishMetricsInterval, rwTimeout).(*clientServer) cs.SetConnection(conn) return cs } @@ -216,7 +225,9 @@ func TestCloseClientServer(t *testing.T) { cs := testCS(conn) gomock.InOrder( + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil), conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()), + conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil), conn.EXPECT().Close(), ) diff --git a/agent/tcs/handler/handler.go b/agent/tcs/handler/handler.go index 740997f3a11..86132f66565 100644 --- a/agent/tcs/handler/handler.go +++ b/agent/tcs/handler/handler.go @@ -25,6 +25,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/tcs/client" "github.com/aws/amazon-ecs-agent/agent/tcs/model/ecstcs" "github.com/aws/amazon-ecs-agent/agent/utils" + "github.com/aws/amazon-ecs-agent/agent/wsclient" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/cihub/seelog" ) @@ -35,8 +36,8 @@ const ( defaultPublishMetricsInterval = 20 * time.Second // The maximum time to wait between heartbeats without disconnecting - defaultHeartbeatTimeout = 5 * time.Minute - defaultHeartbeatJitter = 3 * time.Minute + defaultHeartbeatTimeout = 1 * time.Minute + defaultHeartbeatJitter = 1 * time.Minute deregisterContainerInstanceHandler = "TCSDeregisterContainerInstanceHandler" ) @@ -94,10 +95,15 @@ func startTelemetrySession(params TelemetrySessionParams, statsEngine stats.Engi return startSession(url, params.Cfg, params.CredentialProvider, statsEngine, defaultHeartbeatTimeout, defaultHeartbeatJitter, defaultPublishMetricsInterval, params.DeregisterInstanceEventStream) } -func startSession(url string, cfg *config.Config, credentialProvider *credentials.Credentials, - statsEngine stats.Engine, heartbeatTimeout, heartbeatJitter, publishMetricsInterval time.Duration, +func startSession(url string, + cfg *config.Config, + credentialProvider *credentials.Credentials, + statsEngine stats.Engine, + heartbeatTimeout, heartbeatJitter, + publishMetricsInterval time.Duration, deregisterInstanceEventStream *eventstream.EventStream) error { - client := tcsclient.New(url, cfg, credentialProvider, statsEngine, publishMetricsInterval) + client := tcsclient.New(url, cfg, credentialProvider, statsEngine, + publishMetricsInterval, defaultHeartbeatTimeout+defaultHeartbeatJitter) defer client.Close() err := deregisterInstanceEventStream.Subscribe(deregisterContainerInstanceHandler, client.Disconnect) @@ -112,7 +118,7 @@ func startSession(url string, cfg *config.Config, credentialProvider *credential timer := time.AfterFunc(utils.AddJitter(heartbeatTimeout, heartbeatJitter), func() { // Close the connection if there haven't been any messages received from backend // for a long time. - seelog.Debug("TCS Connection hasn't had a heartbeat or an ack message in too long of a timeout; disconnecting") + seelog.Info("TCS Connection hasn't had a heartbeat or an ack message in too long of a timeout; disconnecting") client.Disconnect() }) defer timer.Stop() @@ -124,6 +130,7 @@ func startSession(url string, cfg *config.Config, credentialProvider *credential return err } seelog.Info("Connected to TCS endpoint") + client.SetAnyRequestHandler(anyMessageHandler(client)) return client.Serve() } @@ -144,6 +151,18 @@ func ackPublishMetricHandler(timer *time.Timer) func(*ecstcs.AckPublishMetric) { } } +// anyMessageHandler handles any server message. Any server message means the +// connection is active +func anyMessageHandler(client wsclient.ClientServer) func(interface{}) { + return func(interface{}) { + seelog.Trace("TCS activity occured") + // Reset read deadline as there's activity on the channel + if err := client.SetReadDeadline(time.Now().Add(defaultHeartbeatTimeout + defaultHeartbeatJitter)); err != nil { + seelog.Warnf("Unable to extend read deadline for TCS connection: %v", err) + } + } +} + // formatURL returns formatted url for tcs endpoint. func formatURL(endpoint string, cluster string, containerInstance string) string { tcsURL := endpoint diff --git a/agent/tcs/handler/handler_test.go b/agent/tcs/handler/handler_test.go index 45db31fd084..2ea36e58942 100644 --- a/agent/tcs/handler/handler_test.go +++ b/agent/tcs/handler/handler_test.go @@ -82,7 +82,7 @@ func TestFormatURL(t *testing.T) { func TestStartSession(t *testing.T) { // Start test server. closeWS := make(chan []byte) - server, serverChan, requestChan, serverErr, err := wsmock.GetMockServer(t, closeWS) + server, serverChan, requestChan, serverErr, err := wsmock.GetMockServer(closeWS) server.StartTLS() defer server.Close() if err != nil { @@ -143,7 +143,7 @@ func TestStartSession(t *testing.T) { func TestSessionConnectionClosedByRemote(t *testing.T) { // Start test server. closeWS := make(chan []byte) - server, serverChan, _, serverErr, err := wsmock.GetMockServer(t, closeWS) + server, serverChan, _, serverErr, err := wsmock.GetMockServer(closeWS) server.StartTLS() defer server.Close() if err != nil { @@ -183,7 +183,7 @@ func TestSessionConnectionClosedByRemote(t *testing.T) { func TestConnectionInactiveTimeout(t *testing.T) { // Start test server. closeWS := make(chan []byte) - server, _, requestChan, serverErr, err := wsmock.GetMockServer(t, closeWS) + server, _, requestChan, serverErr, err := wsmock.GetMockServer(closeWS) server.StartTLS() defer server.Close() if err != nil { diff --git a/agent/wsclient/client.go b/agent/wsclient/client.go index 86b693066b1..18ccfbd8f8b 100644 --- a/agent/wsclient/client.go +++ b/agent/wsclient/client.go @@ -38,6 +38,7 @@ import ( "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" "github.com/cihub/seelog" "github.com/gorilla/websocket" + "github.com/pkg/errors" ) const ( @@ -48,6 +49,9 @@ const ( // wsConnectTimeout specifies the default connection timeout to the backend. wsConnectTimeout = 30 * time.Second + // wsHandshakeTimeout specifies the default handshake timeout for the websocket client + wsHandshakeTimeout = wsConnectTimeout + // readBufSize is the size of the read buffer for the ws connection. readBufSize = 4096 @@ -77,6 +81,8 @@ type WebsocketConn interface { WriteMessage(messageType int, data []byte) error ReadMessage() (messageType int, data []byte, err error) Close() error + SetWriteDeadline(t time.Time) error + SetReadDeadline(t time.Time) error } // RequestHandler would be func(*ecsacs.T for T in ecsacs.*) to be more proper, but it needs @@ -98,6 +104,7 @@ type ClientServer interface { SetConnection(conn WebsocketConn) Disconnect(...interface{}) error Serve() error + SetReadDeadline(t time.Time) error io.Closer } @@ -121,8 +128,11 @@ type ClientServerImpl struct { AnyRequestHandler RequestHandler // URL is the full url to the backend, including path, querystring, and so on. URL string + // RWTimeout is the duration used for setting read and write deadlines + // for the websocket connection + RWTimeout time.Duration // writeLock needed to ensure that only one routine is writing to the socket - writeLock sync.Mutex + writeLock sync.RWMutex ClientServer ServiceError TypeDecoder @@ -133,8 +143,6 @@ type ClientServerImpl struct { // receivable until 'Serve' is also called. func (cs *ClientServerImpl) Connect() error { seelog.Debugf("Establishing a Websocket connection to %s", cs.URL) - cs.writeLock.Lock() - defer cs.writeLock.Unlock() parsedURL, err := url.Parse(cs.URL) if err != nil { return err @@ -170,11 +178,12 @@ func (cs *ClientServerImpl) Connect() error { } dialer := websocket.Dialer{ - ReadBufferSize: readBufSize, - WriteBufferSize: writeBufSize, - TLSClientConfig: tlsConfig, - Proxy: http.ProxyFromEnvironment, - NetDial: timeoutDialer.Dial, + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + TLSClientConfig: tlsConfig, + Proxy: http.ProxyFromEnvironment, + NetDial: timeoutDialer.Dial, + HandshakeTimeout: wsHandshakeTimeout, } websocketConn, httpResponse, err := dialer.Dial(parsedURL.String(), request.Header) @@ -198,35 +207,56 @@ func (cs *ClientServerImpl) Connect() error { } } seelog.Warnf("Error creating a websocket client: %v", err) - return fmt.Errorf(string(resp) + ", " + err.Error()) + return errors.Wrapf(err, "websocket client: unable to dial %s response: %s", + parsedURL.Host, string(resp)) } + + cs.writeLock.Lock() + defer cs.writeLock.Unlock() + cs.conn = websocketConn + seelog.Debugf("Established a Websocket connection to %s", cs.URL) return nil } // IsReady gives a boolean response that informs the caller if the websocket // connection is fully established. func (cs *ClientServerImpl) IsReady() bool { - cs.writeLock.Lock() - defer cs.writeLock.Unlock() + cs.writeLock.RLock() + defer cs.writeLock.RUnlock() + return cs.conn != nil } -// SetConnection passes a websocket connection object into the client. +// SetConnection passes a websocket connection object into the client. This is used only in +// testing and should be avoided in non-test code. func (cs *ClientServerImpl) SetConnection(conn WebsocketConn) { cs.conn = conn } +// SetReadDeadline sets the read deadline for the websocket connection +// A read timeout results in an io error if there are any outstanding reads +// that exceed the deadline +func (cs *ClientServerImpl) SetReadDeadline(t time.Time) error { + return cs.conn.SetReadDeadline(t) +} + // Disconnect disconnects the connection func (cs *ClientServerImpl) Disconnect(...interface{}) error { cs.writeLock.Lock() defer cs.writeLock.Unlock() - if cs.conn != nil { - return cs.conn.Close() + if cs.conn == nil { + return fmt.Errorf("websocker client: no connection to close") } - return fmt.Errorf("No Connection to close") + // Close() in turn results in a an internal flushFrame() call in gorilla + // as the close frame needs to be sent to the server. Set the deadline + // for that as well. + if err := cs.conn.SetWriteDeadline(time.Now().Add(cs.RWTimeout)); err != nil { + seelog.Warnf("Unable to set write deadline for websocket connection: %v for %s", err, cs.URL) + } + return cs.conn.Close() } // AddRequestHandler adds a request handler to this client. @@ -272,6 +302,14 @@ func (cs *ClientServerImpl) MakeRequest(input interface{}) error { func (cs *ClientServerImpl) WriteMessage(send []byte) error { cs.writeLock.Lock() defer cs.writeLock.Unlock() + + // This is just future proofing. Ignore the error as the gorilla websocket + // library returns 'nil' anyway for SetWriteDeadline + // https://github.com/gorilla/websocket/blob/master/conn.go#L761 + if err := cs.conn.SetWriteDeadline(time.Now().Add(cs.RWTimeout)); err != nil { + seelog.Warnf("Unable to set write deadline for websocket connection: %v for %s", err, cs.URL) + } + return cs.conn.WriteMessage(websocket.TextMessage, send) } @@ -279,16 +317,19 @@ func (cs *ClientServerImpl) WriteMessage(send []byte) error { // messages from an active connection. func (cs *ClientServerImpl) ConsumeMessages() error { for { + // Ignore errors when setting the read deadline as any connection + // related errors would be caught by ReadMessage as well + if err := cs.SetReadDeadline(time.Now().Add(cs.RWTimeout)); err != nil { + seelog.Warnf("Unable to set read deadline for websocket connection: %v for %s", err, cs.URL) + } messageType, message, err := cs.conn.ReadMessage() switch { - case err == nil: if messageType != websocket.TextMessage { // maybe not fatal though, we'll try to process it anyways seelog.Errorf("Unexpected messageType: %v", messageType) } - seelog.Debug("Got a message from websocket") cs.handleMessage(message) case permissibleCloseCode(err): @@ -296,7 +337,7 @@ func (cs *ClientServerImpl) ConsumeMessages() error { return io.EOF default: - //Unexpected error occurred + // Unexpected error occurred seelog.Errorf("Error getting message from ws backend: error: [%v], message: [%s], messageType: [%v] ", err, message, messageType) return err } diff --git a/agent/wsclient/client_test.go b/agent/wsclient/client_test.go index 10c32706859..fcf2343337a 100644 --- a/agent/wsclient/client_test.go +++ b/agent/wsclient/client_test.go @@ -17,7 +17,9 @@ import ( "io" "net/url" "os" + "sync" "testing" + "time" "github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs" "github.com/aws/amazon-ecs-agent/agent/config" @@ -29,6 +31,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const dockerEndpoint = "/var/run/docker.sock" @@ -40,30 +43,37 @@ func TestConcurrentWritesDontPanic(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, requests, _, _ := utils.GetMockServer(t, closeWS) + mockServer, _, requests, _, _ := utils.GetMockServer(closeWS) mockServer.StartTLS() defer mockServer.Close() + var waitForRequests sync.WaitGroup + waitForRequests.Add(1) + + go func() { + for i := 0; i < 20; i++ { + <-requests + } + waitForRequests.Done() + }() req := ecsacs.AckRequest{Cluster: aws.String("test"), ContainerInstance: aws.String("test"), MessageId: aws.String("test")} cs := getClientServer(mockServer.URL) - cs.Connect() + require.NoError(t, cs.Connect()) executeTenRequests := func() { for i := 0; i < 10; i++ { - cs.MakeRequest(&req) + assert.NoError(t, cs.MakeRequest(&req)) } } // Make requests from two separate routines to try and force a // concurrent write go executeTenRequests() - executeTenRequests() + go executeTenRequests() t.Log("Waiting for all 20 requests to succeed") - for i := 0; i < 20; i++ { - <-requests - } + waitForRequests.Wait() } // TestProxyVariableCustomValue ensures that a user is able to override the @@ -72,13 +82,13 @@ func TestProxyVariableCustomValue(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.GetMockServer(closeWS) mockServer.StartTLS() defer mockServer.Close() testString := "Custom no proxy string" os.Setenv("NO_PROXY", testString) - getClientServer(mockServer.URL).Connect() + require.NoError(t, getClientServer(mockServer.URL).Connect()) assert.Equal(t, os.Getenv("NO_PROXY"), testString, "NO_PROXY should match user-supplied variable") } @@ -89,7 +99,7 @@ func TestProxyVariableDefaultValue(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.GetMockServer(closeWS) mockServer.StartTLS() defer mockServer.Close() @@ -108,10 +118,10 @@ func TestHandleMessagePermissibleCloseCode(t *testing.T) { defer close(closeWS) messageError := make(chan error) - mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.GetMockServer(closeWS) mockServer.StartTLS() cs := getClientServer(mockServer.URL) - cs.Connect() + require.NoError(t, cs.Connect()) go func() { messageError <- cs.ConsumeMessages() @@ -128,10 +138,10 @@ func TestHandleMessageUnexpectedCloseCode(t *testing.T) { defer close(closeWS) messageError := make(chan error) - mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.GetMockServer(closeWS) mockServer.StartTLS() cs := getClientServer(mockServer.URL) - cs.Connect() + require.NoError(t, cs.Connect()) go func() { messageError <- cs.ConsumeMessages() @@ -147,12 +157,12 @@ func TestHandleNonHTTPSEndpoint(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, requests, _, _ := utils.GetMockServer(t, closeWS) + mockServer, _, requests, _, _ := utils.GetMockServer(closeWS) mockServer.Start() defer mockServer.Close() cs := getClientServer(mockServer.URL) - cs.Connect() + require.NoError(t, cs.Connect()) req := ecsacs.AckRequest{Cluster: aws.String("test"), ContainerInstance: aws.String("test"), MessageId: aws.String("test")} cs.MakeRequest(&req) @@ -167,7 +177,7 @@ func TestHandleIncorrectURLScheme(t *testing.T) { closeWS := make(chan []byte) defer close(closeWS) - mockServer, _, _, _, _ := utils.GetMockServer(t, closeWS) + mockServer, _, _, _, _ := utils.GetMockServer(closeWS) mockServer.StartTLS() defer mockServer.Close() @@ -213,5 +223,6 @@ func getClientServer(url string) *ClientServerImpl { }, CredentialProvider: credentials.AnonymousCredentials, TypeDecoder: BuildTypeDecoder(types), + RWTimeout: time.Second, } } diff --git a/agent/wsclient/mock/client.go b/agent/wsclient/mock/client.go index a3bc2f55b96..e49ea719ca4 100644 --- a/agent/wsclient/mock/client.go +++ b/agent/wsclient/mock/client.go @@ -17,6 +17,8 @@ package mock_wsclient import ( + time "time" + wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient" gomock "github.com/golang/mock/gomock" ) @@ -130,6 +132,16 @@ func (_mr *_MockClientServerRecorder) SetConnection(arg0 interface{}) *gomock.Ca return _mr.mock.ctrl.RecordCall(_mr.mock, "SetConnection", arg0) } +func (_m *MockClientServer) SetReadDeadline(_param0 time.Time) error { + ret := _m.ctrl.Call(_m, "SetReadDeadline", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockClientServerRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "SetReadDeadline", arg0) +} + func (_m *MockClientServer) WriteMessage(_param0 []byte) error { ret := _m.ctrl.Call(_m, "WriteMessage", _param0) ret0, _ := ret[0].(error) @@ -183,6 +195,26 @@ func (_mr *_MockWebsocketConnRecorder) ReadMessage() *gomock.Call { return _mr.mock.ctrl.RecordCall(_mr.mock, "ReadMessage") } +func (_m *MockWebsocketConn) SetReadDeadline(_param0 time.Time) error { + ret := _m.ctrl.Call(_m, "SetReadDeadline", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockWebsocketConnRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "SetReadDeadline", arg0) +} + +func (_m *MockWebsocketConn) SetWriteDeadline(_param0 time.Time) error { + ret := _m.ctrl.Call(_m, "SetWriteDeadline", _param0) + ret0, _ := ret[0].(error) + return ret0 +} + +func (_mr *_MockWebsocketConnRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + return _mr.mock.ctrl.RecordCall(_mr.mock, "SetWriteDeadline", arg0) +} + func (_m *MockWebsocketConn) WriteMessage(_param0 int, _param1 []byte) error { ret := _m.ctrl.Call(_m, "WriteMessage", _param0, _param1) ret0, _ := ret[0].(error) diff --git a/agent/wsclient/mock/utils/utils.go b/agent/wsclient/mock/utils/utils.go index add4b3d2102..3dc2e62b17d 100644 --- a/agent/wsclient/mock/utils/utils.go +++ b/agent/wsclient/mock/utils/utils.go @@ -16,7 +16,6 @@ package utils import ( "net/http" "net/http/httptest" - "testing" "time" "github.com/gorilla/websocket" @@ -24,7 +23,7 @@ import ( // GetMockServer retuns a mock websocket server that can be started up as TLS or not. // TODO replace with gomock -func GetMockServer(t *testing.T, closeWS <-chan []byte) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { +func GetMockServer(closeWS <-chan []byte) (*httptest.Server, chan<- string, <-chan string, <-chan error, error) { serverChan := make(chan string) requestsChan := make(chan string) errChan := make(chan error)