Skip to content

Commit

Permalink
wsclient: add read and write deadlines
Browse files Browse the repository at this point in the history
This commit aims to make the websocket connection management
better by implementing the following improvements:

1. Set read and write deadlines for websocket ReadMessage and
WriteMessage operations. This is to ensure that these methods
do not hang, and instead fail with an I/O timeout, if there are
issues with the connection
2. Reduce the scope of the lock in the Connect() method. The
lock was being held for the length of the Connect() method, which
meant that it wouldn't be relinquished if there was any delay
in establishing the connection. The scope of the lock has now
been reduced to just accessing the cs.conn variable
3. Start ACS heartbeat timer after the connection has been
established. The timer was being started before a call to
Connect, which meant that the connection could be prematurely
terminated for being idle if there was a delay in establishing
the connection

These changes should improve the disconnection behavior of the
websocket connection, which should help with scenarios where the
Agent never reconnects to ACS because it's stuck forever in the
Disconnect() method, waiting to acquire the lock (aws#985)
  • Loading branch information
aaithal committed Sep 27, 2017
1 parent 9598854 commit 8703f78
Show file tree
Hide file tree
Showing 12 changed files with 228 additions and 83 deletions.
12 changes: 6 additions & 6 deletions agent/acs/client/acs_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ package acsclient

import (
"errors"
"time"

"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/logger"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/cihub/seelog"
)

var log = logger.ForModule("acs client")

// clientServer implements ClientServer for acs.
type clientServer struct {
wsclient.ClientServerImpl
Expand All @@ -37,24 +36,25 @@ type clientServer struct {
// New returns a client/server to bidirectionally communicate with ACS
// The returned struct should have both 'Connect' and 'Serve' called upon it
// before being used.
//
// The url, credentialProvider, cfg and rwTimeout arguments are stored on the
// returned client as URL, CredentialProvider, AgentConfig and RWTimeout. The
// client starts with an empty request-handler map, an acsError as its
// ServiceError and the decoder returned by NewACSDecoder.
func New(url string, cfg *config.Config, credentialProvider *credentials.Credentials) wsclient.ClientServer {
func New(url string, cfg *config.Config, credentialProvider *credentials.Credentials, rwTimeout time.Duration) wsclient.ClientServer {
cs := &clientServer{}
cs.URL = url
cs.CredentialProvider = credentialProvider
cs.AgentConfig = cfg
cs.ServiceError = &acsError{}
cs.RequestHandlers = make(map[string]wsclient.RequestHandler)
cs.TypeDecoder = NewACSDecoder()
cs.RWTimeout = rwTimeout
return cs
}

// Serve begins serving requests using previously registered handlers (see
// AddRequestHandler). All request handlers should be added prior to making this
// call as unhandled requests will be discarded.
//
// It returns an error without consuming any message when IsReady reports
// false; otherwise it returns whatever ConsumeMessages returns.
func (cs *clientServer) Serve() error {
log.Debug("Starting websocket poll loop")
seelog.Debug("ACS client starting websocket poll loop")
if !cs.IsReady() {
return errors.New("Websocket not ready for connections")
return errors.New("acs client: websocket not ready for connections")
}
return cs.ConsumeMessages()
}
Expand Down
38 changes: 32 additions & 6 deletions agent/acs/client/acs_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ const (
const (
TestClusterArn = "arn:aws:ec2:123:container/cluster:123456"
TestInstanceArn = "arn:aws:ec2:123:container/containerInstance/12345678"
rwTimeout = time.Second
)

var testCfg = &config.Config{
Expand All @@ -91,6 +92,7 @@ func TestMakeUnrecognizedRequest(t *testing.T) {
defer ctrl.Finish()

conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
conn.EXPECT().Close()

cs := testCS(conn)
Expand All @@ -107,6 +109,7 @@ func TestWriteAckRequest(t *testing.T) {
defer ctrl.Finish()

conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(2)
conn.EXPECT().Close()
cs := testCS(conn)
defer cs.Close()
Expand All @@ -133,7 +136,13 @@ func TestPayloadHandlerCalled(t *testing.T) {
defer ctrl.Finish()

conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`), nil)
// Messages should be read from the connection at least once
conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1)
conn.EXPECT().ReadMessage().Return(websocket.TextMessage,
[]byte(`{"type":"PayloadMessage","message":{"tasks":[{"arn":"arn"}]}}`),
nil).MinTimes(1)
// Invoked when closing the connection
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
conn.EXPECT().Close()
cs := testCS(conn)
defer cs.Close()
Expand All @@ -159,7 +168,12 @@ func TestRefreshCredentialsHandlerCalled(t *testing.T) {
defer ctrl.Finish()

conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(sampleCredentialsMessage), nil)
// Messages should be read from the connection at least once
conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1)
conn.EXPECT().ReadMessage().Return(websocket.TextMessage,
[]byte(sampleCredentialsMessage), nil).MinTimes(1)
// Invoked when closing the connection
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
conn.EXPECT().Close()
cs := testCS(conn)
defer cs.Close()
Expand Down Expand Up @@ -193,9 +207,15 @@ func TestClosingConnection(t *testing.T) {

// Returning EOF tells the ClientServer that the connection is closed
conn := mock_wsclient.NewMockWebsocketConn(ctrl)
conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil)
conn.EXPECT().ReadMessage().Return(0, nil, io.EOF)
// SetWriteDeadline will be invoked once for WriteMessage() and
// once for Close()
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(2)
conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()).Return(io.EOF)
conn.EXPECT().Close()
cs := testCS(conn)
defer cs.Close()

serveErr := cs.Serve()
assert.Error(t, serveErr)
Expand All @@ -215,7 +235,7 @@ func TestConnect(t *testing.T) {
t.Fatal(<-serverErr)
}()

cs := New(server.URL, testCfg, credentials.AnonymousCredentials)
cs := New(server.URL, testCfg, credentials.AnonymousCredentials, rwTimeout)
// Wait for up to a second for the mock server to launch
for i := 0; i < 100; i++ {
err = cs.Connect()
Expand Down Expand Up @@ -286,7 +306,7 @@ func TestConnectClientError(t *testing.T) {
}))
defer testServer.Close()

cs := New(testServer.URL, testCfg, credentials.AnonymousCredentials)
cs := New(testServer.URL, testCfg, credentials.AnonymousCredentials, rwTimeout)
err := cs.Connect()
_, ok := err.(*wsclient.WSError)
assert.True(t, ok)
Expand All @@ -295,7 +315,8 @@ func TestConnectClientError(t *testing.T) {

// testCS builds an ACS client for tests through New with anonymous
// credentials, asserts it to the concrete *clientServer type and installs
// the supplied mock connection on it via SetConnection.
func testCS(conn *mock_wsclient.MockWebsocketConn) wsclient.ClientServer {
testCreds := credentials.AnonymousCredentials
cs := New("localhost:443", testCfg, testCreds).(*clientServer)
foo := New("localhost:443", testCfg, testCreds, rwTimeout)
cs := foo.(*clientServer)
cs.SetConnection(conn)
return cs
}
Expand Down Expand Up @@ -344,7 +365,12 @@ func TestAttachENIHandlerCalled(t *testing.T) {
cs := testCS(conn)
defer cs.Close()

conn.EXPECT().ReadMessage().AnyTimes().Return(websocket.TextMessage, []byte(sampleAttachENIMessage), nil)
// Messages should be read from the connection at least once
conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1)
conn.EXPECT().ReadMessage().Return(websocket.TextMessage,
[]byte(sampleAttachENIMessage), nil).MinTimes(1)
// Invoked when closing the connection
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
conn.EXPECT().Close()

messageChannel := make(chan *ecsacs.AttachTaskNetworkInterfacesMessage)
Expand Down
32 changes: 18 additions & 14 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,19 +246,13 @@ func (acsSession *session) startSessionOnce() error {
client := acsSession.resources.createACSClient(url, acsSession.agentConfig)
defer client.Close()

// Start inactivity timer for closing the connection
timer := newDisconnectionTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter())
defer timer.Stop()

return acsSession.startACSSession(client, timer)
return acsSession.startACSSession(client)
}

// startACSSession starts a session with ACS. It adds request handlers for various
// kinds of messages expected from ACS. It returns on server disconnection or when
// the context is cancelled
func (acsSession *session) startACSSession(client wsclient.ClientServer, timer ttime.Timer) error {
// Any message from the server resets the disconnect timeout
client.SetAnyRequestHandler(anyMessageHandler(timer))
func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
cfg := acsSession.agentConfig

refreshCredsHandler := newRefreshCredentialsHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN,
Expand Down Expand Up @@ -313,6 +307,11 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer, timer t
return err
}
seelog.Info("Connected to ACS endpoint")
// Start inactivity timer for closing the connection
timer := newDisconnectionTimer(client, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter())
// Any message from the server resets the disconnect timeout
client.SetAnyRequestHandler(anyMessageHandler(timer, client))
defer timer.Stop()

acsSession.resources.connectedToACS()

Expand Down Expand Up @@ -377,7 +376,7 @@ func (acsSession *session) heartbeatJitter() time.Duration {

// createACSClient creates the ACS Client using the specified URL
// The trailing argument added to acsclient.New is the client's read/write
// timeout, passed here as heartbeatTimeout + heartbeatJitter.
func (acsResources *acsSessionResources) createACSClient(url string, cfg *config.Config) wsclient.ClientServer {
return acsclient.New(url, cfg, acsResources.credentialsProvider)
return acsclient.New(url, cfg, acsResources.credentialsProvider, heartbeatTimeout+heartbeatJitter)
}

// connectedToACS records a successful connection to ACS
Expand Down Expand Up @@ -424,9 +423,8 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.
func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration, jitter time.Duration) ttime.Timer {
timer := time.AfterFunc(utils.AddJitter(timeout, jitter), func() {
seelog.Warn("ACS Connection hasn't had any activity for too long; closing connection")
closeErr := client.Close()
if closeErr != nil {
seelog.Warnf("Error disconnecting: %v", closeErr)
if err := client.Close(); err != nil {
seelog.Warnf("Error disconnecting: %v", err)
}
})

Expand All @@ -435,9 +433,15 @@ func newDisconnectionTimer(client wsclient.ClientServer, timeout time.Duration,

// anyMessageHandler handles any server message. Any server message means the
// connection is active and thus the heartbeat disconnect should not occur.
// The returned handler ignores the message itself: it pushes the client's
// read deadline out by heartbeatTimeout + heartbeatJitter and resets timer
// to a jittered heartbeat timeout. A failure to extend the read deadline is
// logged as a warning and does not prevent the timer reset.
func anyMessageHandler(timer ttime.Timer) func(interface{}) {
func anyMessageHandler(timer ttime.Timer, client wsclient.ClientServer) func(interface{}) {
return func(interface{}) {
seelog.Debug("ACS activity occured")
seelog.Debug("ACS activity occurred")
// Reset read deadline as there's activity on the channel
if err := client.SetReadDeadline(time.Now().Add(heartbeatTimeout + heartbeatJitter)); err != nil {
// Warnf rather than Warn: the message contains a %v verb that has to
// be formatted with err instead of being printed literally.
seelog.Warnf("Unable to extend read deadline for ACS connection: %v", err)
}

// Reset heartbeat timer
timer.Reset(utils.AddJitter(heartbeatTimeout, heartbeatJitter))
}
}
Expand Down
8 changes: 2 additions & 6 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,9 +804,7 @@ func TestConnectionIsClosedOnIdle(t *testing.T) {
_heartbeatJitter: 10 * time.Millisecond,
}
go func() {
timer := newDisconnectionTimer(mockWsClient, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter())
defer timer.Stop()
acsSession.startACSSession(mockWsClient, timer)
acsSession.startACSSession(mockWsClient)
}()

// Wait for connection to be closed. If the connection is not closed
Expand Down Expand Up @@ -1060,11 +1058,9 @@ func TestHandlerReconnectsCorrectlySetsSendCredentialsURLParameter(t *testing.T)
_heartbeatTimeout: 20 * time.Millisecond,
_heartbeatJitter: 10 * time.Millisecond,
}
timer := newDisconnectionTimer(mockWsClient, acsSession.heartbeatTimeout(), acsSession.heartbeatJitter())
defer timer.Stop()
go func() {
for i := 0; i < 10; i++ {
acsSession.startACSSession(mockWsClient, timer)
acsSession.startACSSession(mockWsClient)
}
cancel()
}()
Expand Down
12 changes: 9 additions & 3 deletions agent/tcs/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ type clientServer struct {
// New returns a client/server to bidirectionally communicate with the backend.
// The returned struct should have both 'Connect' and 'Serve' called upon it
// before being used.
func New(url string, cfg *config.Config, credentialProvider *credentials.Credentials, statsEngine stats.Engine, publishMetricsInterval time.Duration) wsclient.ClientServer {
func New(url string,
cfg *config.Config,
credentialProvider *credentials.Credentials,
statsEngine stats.Engine,
publishMetricsInterval time.Duration,
rwTimeout time.Duration) wsclient.ClientServer {
cs := &clientServer{
statsEngine: statsEngine,
publishTicker: nil,
Expand All @@ -58,6 +63,7 @@ func New(url string, cfg *config.Config, credentialProvider *credentials.Credent
cs.ServiceError = &tcsError{}
cs.RequestHandlers = make(map[string]wsclient.RequestHandler)
cs.TypeDecoder = NewTCSDecoder()
cs.RWTimeout = rwTimeout
return cs
}

Expand All @@ -67,11 +73,11 @@ func New(url string, cfg *config.Config, credentialProvider *credentials.Credent
func (cs *clientServer) Serve() error {
seelog.Debug("TCS client starting websocket poll loop")
if !cs.IsReady() {
return fmt.Errorf("Websocket not ready for connections")
return fmt.Errorf("tcs client: websocket not ready for connections")
}

if cs.statsEngine == nil {
return fmt.Errorf("uninitialized stats engine")
return fmt.Errorf("tcs client: uninitialized stats engine")
}

// Start the timer function to publish metrics to the backend.
Expand Down
15 changes: 13 additions & 2 deletions agent/tcs/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const (
testMessageId = "testMessageId"
testCluster = "default"
testContainerInstance = "containerInstance"
rwTimeout = time.Second
)

type mockStatsEngine struct{}
Expand Down Expand Up @@ -97,7 +98,12 @@ func TestPayloadHandlerCalled(t *testing.T) {
conn := mock_wsclient.NewMockWebsocketConn(ctrl)
cs := testCS(conn)

conn.EXPECT().ReadMessage().AnyTimes().Return(1, []byte(`{"type":"AckPublishMetric","message":{}}`), nil)
// Messages should be read from the connection at least once
conn.EXPECT().SetReadDeadline(gomock.Any()).Return(nil).MinTimes(1)
conn.EXPECT().ReadMessage().Return(1,
[]byte(`{"type":"AckPublishMetric","message":{}}`), nil).MinTimes(1)
// Invoked when closing the connection
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil)
conn.EXPECT().Close()

handledPayload := make(chan *ecstcs.AckPublishMetric)
Expand All @@ -119,6 +125,8 @@ func TestPublishMetricsRequest(t *testing.T) {
defer ctrl.Finish()

conn := mock_wsclient.NewMockWebsocketConn(ctrl)
// Invoked when closing the connection
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil).Times(2)
conn.EXPECT().Close()
// TODO: should use explicit values
conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any())
Expand Down Expand Up @@ -201,7 +209,8 @@ func testCS(conn *mock_wsclient.MockWebsocketConn) wsclient.ClientServer {
AWSRegion: "us-east-1",
AcceptInsecureCert: true,
}
cs := New("localhost:443", cfg, testCreds, &mockStatsEngine{}, testPublishMetricsInterval).(*clientServer)
cs := New("localhost:443", cfg, testCreds, &mockStatsEngine{},
testPublishMetricsInterval, rwTimeout).(*clientServer)
cs.SetConnection(conn)
return cs
}
Expand All @@ -216,7 +225,9 @@ func TestCloseClientServer(t *testing.T) {
cs := testCS(conn)

gomock.InOrder(
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil),
conn.EXPECT().WriteMessage(gomock.Any(), gomock.Any()),
conn.EXPECT().SetWriteDeadline(gomock.Any()).Return(nil),
conn.EXPECT().Close(),
)

Expand Down
Loading

0 comments on commit 8703f78

Please sign in to comment.