Skip to content

Commit

Permalink
send pending acks before closing ACS connection
Browse files Browse the repository at this point in the history
  • Loading branch information
singholt committed Feb 24, 2023
1 parent 6d54c7c commit 31b29d0
Show file tree
Hide file tree
Showing 9 changed files with 313 additions and 6 deletions.
58 changes: 54 additions & 4 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"time"

acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
Expand Down Expand Up @@ -75,6 +76,10 @@ const (
// 1: default protocol version
// 2: ACS will proactively close the connection when heartbeat acks are missing
acsProtocolVersion = 2
// numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across
// sessions. We use this to send pending acks, before agent initiates a disconnect to ACS.
// they are: refreshCredentialsHandler, taskManifestHandler, payloadHandler and heartbeatHandler
numOfHandlersSendingAcks = 4
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -369,8 +374,10 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
}

seelog.Info("Connected to ACS endpoint")
// Start a connection timer; agent will close its ACS websocket connection after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter)
// Start a connection timer; agent will send pending acks and close its ACS websocket connection
// after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter,
&refreshCredsHandler, &taskManifestHandler, &payloadHandler, &heartbeatHandler)
defer connectionTimer.Stop()

// Start a heartbeat timer for closing the connection
Expand Down Expand Up @@ -505,10 +512,53 @@ func newHeartbeatTimer(client wsclient.ClientServer, timeout time.Duration, jitt
return timer
}

// newConnectionTimer creates a new timer, after which agent closes its ACS websocket connection
func newConnectionTimer(client wsclient.ClientServer, connectionTime time.Duration, connectionJitter time.Duration) ttime.Timer {
// newConnectionTimer creates a new timer, after which agent sends any pending acks to ACS and closes
// its websocket connection
func newConnectionTimer(
client wsclient.ClientServer,
connectionTime time.Duration,
connectionJitter time.Duration,
refreshCredsHandler *refreshCredentialsHandler,
taskManifestHandler *taskManifestHandler,
payloadHandler *payloadRequestHandler,
heartbeatHandler *heartbeatHandler,
) ttime.Timer {
expiresAt := retry.AddJitter(connectionTime, connectionJitter)
timer := time.AfterFunc(expiresAt, func() {
seelog.Debugf("Sending pending acks to ACS before closing the connection")

wg := sync.WaitGroup{}
wg.Add(numOfHandlersSendingAcks)

// send pending creds refresh acks to ACS
go func() {
refreshCredsHandler.sendPendingAcks()
wg.Done()
}()

// send pending task manifest acks and task stop verification acks to ACS
go func() {
taskManifestHandler.sendPendingTaskManifestMessageAck()
taskManifestHandler.handlePendingTaskStopVerificationAck()
wg.Done()
}()

// send pending payload acks to ACS
go func() {
payloadHandler.sendPendingAcks()
wg.Done()
}()

// send pending heartbeat acks to ACS
go func() {
heartbeatHandler.sendPendingHeartbeatAck()
wg.Done()
}()

// wait for acks from all the handlers above to be sent to ACS before closing the websocket connection.
// the methods used to read pending acks are non-blocking, so it is safe to wait here.
wg.Wait()

seelog.Infof("Closing ACS websocket connection after %v minutes", expiresAt.Minutes())
// WriteCloseMessage() writes a close message using websocket control messages
// Ref: https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages
Expand Down
12 changes: 12 additions & 0 deletions agent/acs/handler/heartbeat_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() {
}
}

// sendPendingHeartbeatAck drains heartbeatAckMessageBuffer without blocking: every ack
// that can be received immediately is handed to sendSingleHeartbeatAck, and the method
// returns as soon as no ack is ready. It is intended to be called before the ACS
// connection is closed.
func (heartbeatHandler *heartbeatHandler) sendPendingHeartbeatAck() {
	for drained := false; !drained; {
		select {
		case pendingAck := <-heartbeatHandler.heartbeatAckMessageBuffer:
			heartbeatHandler.sendSingleHeartbeatAck(pendingAck)
		default:
			// Nothing is ready to be received; stop instead of waiting for more acks.
			drained = true
		}
	}
}

func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) {
err := heartbeatHandler.acsClient.MakeRequest(ack)
if err != nil {
Expand Down
42 changes: 42 additions & 0 deletions agent/acs/handler/heartbeat_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ package handler

import (
"context"
"sync"
"testing"
"time"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
mock_dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi/mocks"
"github.com/aws/amazon-ecs-agent/agent/doctor"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"

"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -112,3 +116,41 @@ func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessa

require.Equal(t, heartbeatAckExpected, heartbeatAckSent)
}

// TestHeartbeatHandler verifies that sendPendingHeartbeatAck forwards an ack that is
// waiting on heartbeatAckMessageBuffer to ACS exactly once and leaves the buffer empty.
func TestHeartbeatHandler(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	ctx := context.TODO()
	emptyHealthCheckList := []doctor.Healthcheck{}
	emptyDoctor, _ := doctor.NewDoctor(emptyHealthCheckList, "testCluster",
		"this:is:an:instance:arn")
	mockWSClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1)
	handler := newHeartbeatHandler(ctx, mockWSClient, emptyDoctor)

	wg := sync.WaitGroup{}
	wg.Add(1)
	// ackWritten is closed once the dummy ack has been handed over to the buffer.
	ackWritten := make(chan struct{})

	// write a dummy ack into the heartbeatAckMessageBuffer
	go func() {
		defer wg.Done()
		handler.heartbeatAckMessageBuffer <- &ecsacs.HeartbeatAckRequest{}
		close(ackWritten)
	}()

	// sendPendingHeartbeatAck() is non-blocking, so a single call may run before the writer above is ready.
	// Instead of sleeping for a fixed duration and hoping the writer got scheduled, keep draining until the
	// writer reports that its ack was handed over. The deadline only guards against hanging forever.
	deadline := time.After(5 * time.Second)
	for written := false; !written; {
		handler.sendPendingHeartbeatAck()
		select {
		case <-ackWritten:
			written = true
		case <-deadline:
			t.Fatal("timed out waiting for the pending heartbeat ack to be consumed")
		default:
			time.Sleep(time.Millisecond)
		}
	}

	// drain once more in case the ack was buffered after the last drain attempt above; this is a no-op
	// when the ack has already been consumed.
	handler.sendPendingHeartbeatAck()

	// wait for the writer go routine to finish before we verify that ack channel is empty and exit the test.
	wg.Wait()

	// verify that the heartbeatAckMessageBuffer channel is empty
	assert.Equal(t, 0, len(handler.heartbeatAckMessageBuffer))
}
12 changes: 12 additions & 0 deletions agent/acs/handler/payload_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ func (payloadHandler *payloadRequestHandler) sendAcks() {
}
}

// sendPendingAcks drains the ackRequest channel without blocking and acks each message
// id that is immediately available via ackMessageId. It returns once no message id is
// ready, and is intended to be called before the ACS connection is closed.
func (payloadHandler *payloadRequestHandler) sendPendingAcks() {
	for {
		var messageID string
		select {
		case messageID = <-payloadHandler.ackRequest:
		default:
			// No message id is ready to be received; do not wait for one.
			return
		}
		payloadHandler.ackMessageId(messageID)
	}
}

// ackMessageId sends an AckRequest for a message id
func (payloadHandler *payloadRequestHandler) ackMessageId(messageID string) {
seelog.Debugf("Acking payload message id: %s", messageID)
Expand Down
33 changes: 33 additions & 0 deletions agent/acs/handler/payload_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"reflect"
"sync"
"testing"
"time"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/agent/api"
Expand Down Expand Up @@ -1035,3 +1036,35 @@ func TestPayloadHandlerAddedFirelensData(t *testing.T) {
assert.NotNil(t, actual.Options)
assert.Equal(t, aws.StringValue(expected.Options["enable-ecs-log-metadata"]), actual.Options["enable-ecs-log-metadata"])
}

// TestPayloadHandlerSendPendingAcks verifies that sendPendingAcks forwards a message id
// that is waiting on the ackRequest channel to ACS exactly once and leaves the channel empty.
func TestPayloadHandlerSendPendingAcks(t *testing.T) {
	tester := setup(t)
	defer tester.ctrl.Finish()

	tester.mockWsClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1)

	wg := sync.WaitGroup{}
	wg.Add(1)
	// ackWritten is closed once the dummy message id has been handed over to ackRequest.
	ackWritten := make(chan struct{})

	// write a dummy ack into the ackRequest
	go func() {
		defer wg.Done()
		tester.payloadHandler.ackRequest <- "testMessageID"
		close(ackWritten)
	}()

	// sendPendingAcks() is non-blocking, so a single call may run before the writer above is ready.
	// Instead of sleeping for a fixed duration and hoping the writer got scheduled, keep draining until the
	// writer reports that its message id was handed over. The deadline only guards against hanging forever.
	deadline := time.After(5 * time.Second)
	for written := false; !written; {
		tester.payloadHandler.sendPendingAcks()
		select {
		case <-ackWritten:
			written = true
		case <-deadline:
			t.Fatal("timed out waiting for the pending payload ack to be consumed")
		default:
			time.Sleep(time.Millisecond)
		}
	}

	// drain once more in case the message id was buffered after the last drain attempt above; this is a
	// no-op when it has already been consumed.
	tester.payloadHandler.sendPendingAcks()

	// wait for the writer go routine to finish before we verify that ack channel is empty and exit the test.
	wg.Wait()

	// verify that the ackRequest channel is empty
	assert.Equal(t, 0, len(tester.payloadHandler.ackRequest))
}
12 changes: 12 additions & 0 deletions agent/acs/handler/refresh_credentials_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ func (refreshHandler *refreshCredentialsHandler) sendAcks() {
}
}

// sendPendingAcks drains the ackRequest channel without blocking, passing each ack that
// is immediately available to ackMessage. It returns once no ack is ready, and is
// intended to be called before the ACS connection is closed.
func (refreshHandler *refreshCredentialsHandler) sendPendingAcks() {
drainAcks:
	for {
		select {
		case pendingAck := <-refreshHandler.ackRequest:
			refreshHandler.ackMessage(pendingAck)
		default:
			// No ack is ready to be received; leave the loop rather than wait.
			break drainAcks
		}
	}
}

// ackMessage sends an IAMRoleCredentialsAckRequest to the backend
func (refreshHandler *refreshCredentialsHandler) ackMessage(ack *ecsacs.IAMRoleCredentialsAckRequest) {
err := refreshHandler.acsClient.MakeRequest(ack)
Expand Down
46 changes: 44 additions & 2 deletions agent/acs/handler/refresh_credentials_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@
package handler

import (
"context"
"reflect"
"sync"
"testing"

"context"
"time"

"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
apitask "github.com/aws/amazon-ecs-agent/agent/api/task"
"github.com/aws/amazon-ecs-agent/agent/credentials"
mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks"
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"

"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
)

const (
Expand Down Expand Up @@ -277,6 +280,45 @@ func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) {
}
}

// TestRefreshCredentialsHandlerSendPendingAcks verifies that sendPendingAcks forwards an
// ack that is waiting on the ackRequest channel to ACS exactly once and leaves the channel empty.
func TestRefreshCredentialsHandlerSendPendingAcks(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	ctx := context.TODO()
	credentialsManager := credentials.NewManager()
	taskEngine := mock_engine.NewMockTaskEngine(ctrl)
	mockWSClient := mock_wsclient.NewMockClientServer(ctrl)
	mockWSClient.EXPECT().MakeRequest(gomock.Any()).Return(nil).Times(1)

	handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWSClient,
		credentialsManager, taskEngine)

	wg := sync.WaitGroup{}
	wg.Add(1)
	// ackWritten is closed once the dummy ack has been handed over to ackRequest.
	ackWritten := make(chan struct{})

	// write a dummy ack into the ackRequest
	go func() {
		defer wg.Done()
		handler.ackRequest <- expectedAck
		close(ackWritten)
	}()

	// sendPendingAcks() is non-blocking, so a single call may run before the writer above is ready.
	// Instead of sleeping for a fixed duration and hoping the writer got scheduled, keep draining until the
	// writer reports that its ack was handed over. The deadline only guards against hanging forever.
	deadline := time.After(5 * time.Second)
	for written := false; !written; {
		handler.sendPendingAcks()
		select {
		case <-ackWritten:
			written = true
		case <-deadline:
			t.Fatal("timed out waiting for the pending credentials refresh ack to be consumed")
		default:
			time.Sleep(time.Millisecond)
		}
	}

	// drain once more in case the ack was buffered after the last drain attempt above; this is a no-op
	// when the ack has already been consumed.
	handler.sendPendingAcks()

	// wait for the writer go routine to finish before we verify that ack channel is empty and exit the test.
	wg.Wait()

	// verify that the ackRequest channel is empty
	assert.Equal(t, 0, len(handler.ackRequest))
}

// TestRefreshCredentialsHandler tests if a credential message is acked when
// the message is sent to the messageBuffer channel
func TestRefreshCredentialsHandler(t *testing.T) {
Expand Down
27 changes: 27 additions & 0 deletions agent/acs/handler/task_manifest_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ func (taskManifestHandler *taskManifestHandler) sendTaskManifestMessageAck() {
}
}

// sendPendingTaskManifestMessageAck drains messageBufferTaskManifestAck without blocking,
// passing each task manifest ack that is immediately available to ackTaskManifestMessage.
// It returns once no ack is ready, and is intended to be called before the ACS connection
// is closed.
func (taskManifestHandler *taskManifestHandler) sendPendingTaskManifestMessageAck() {
	for more := true; more; {
		select {
		case pendingAck := <-taskManifestHandler.messageBufferTaskManifestAck:
			taskManifestHandler.ackTaskManifestMessage(pendingAck)
		default:
			// Nothing is ready to be received; stop instead of waiting for more acks.
			more = false
		}
	}
}

func (taskManifestHandler *taskManifestHandler) handleTaskStopVerificationAck() {
for {
select {
Expand All @@ -130,6 +142,21 @@ func (taskManifestHandler *taskManifestHandler) handleTaskStopVerificationAck()
}
}

// handlePendingTaskStopVerificationAck drains messageBufferTaskStopVerificationAck without
// blocking, passing each ack that is immediately available to
// handleSingleMessageVerificationAck. A failure to handle an ack is logged as a warning
// and does not stop the drain. It returns once no ack is ready, and is intended to be
// called before the ACS connection is closed.
func (taskManifestHandler *taskManifestHandler) handlePendingTaskStopVerificationAck() {
	for {
		select {
		case verificationAck := <-taskManifestHandler.messageBufferTaskStopVerificationAck:
			err := taskManifestHandler.handleSingleMessageVerificationAck(verificationAck)
			if err == nil {
				continue
			}
			// Best effort: report the failure and move on to the next pending ack.
			seelog.Warnf("Error handling Verification ack with messageID: %s, error: %v",
				verificationAck.MessageId, err)
		default:
			// No ack is ready to be received; do not wait for one.
			return
		}
	}
}

func (taskManifestHandler *taskManifestHandler) clearAcks() {
for {
select {
Expand Down
Loading

0 comments on commit 31b29d0

Please sign in to comment.