Skip to content

Commit a585259

Browse files
committed
ECS agent to acknowledge server heartbeat messages
1 parent 73bfeb9 commit a585259

File tree

7 files changed

+276
-11
lines changed

7 files changed

+276
-11
lines changed

agent/acs/client/acs_client_types.go

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func init() {
3030
// the .json model or the generated struct names.
3131
acsRecognizedTypes = []interface{}{
3232
ecsacs.HeartbeatMessage{},
33+
ecsacs.HeartbeatAckRequest{},
3334
ecsacs.PayloadMessage{},
3435
ecsacs.CloseMessage{},
3536
ecsacs.AckRequest{},

agent/acs/handler/acs_handler.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"time"
2525

2626
acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
27-
"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
2827
updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
2928
"github.com/aws/amazon-ecs-agent/agent/api"
3029
"github.com/aws/amazon-ecs-agent/agent/config"
@@ -65,6 +64,10 @@ const (
6564
// credentials for all tasks on establishing the connection
6665
sendCredentialsURLParameterName = "sendCredentials"
6766
inactiveInstanceExceptionPrefix = "InactiveInstanceException:"
67+
// ACS protocol version spec:
68+
// 1: default protocol version
69+
// 2: ACS will proactively close the connection when heartbeat acks are missing
70+
acsProtocolVersion = 2
6871
)
6972

7073
// Session defines an interface for handler's long-lived connection with ACS.
@@ -332,8 +335,13 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
332335

333336
client.AddRequestHandler(payloadHandler.handlerFunc())
334337

335-
// Ignore heartbeat messages; anyMessageHandler gets 'em
336-
client.AddRequestHandler(func(*ecsacs.HeartbeatMessage) {})
338+
// Add HeartbeatHandler to acknowledge ACS heartbeats
339+
heartbeatHandler := newHeartbeatHandler(acsSession.ctx, client)
340+
defer heartbeatHandler.clearAcks()
341+
heartbeatHandler.start()
342+
defer heartbeatHandler.stop()
343+
344+
client.AddRequestHandler(heartbeatHandler.handlerFunc())
337345

338346
updater.AddAgentUpdateHandlers(client, cfg, acsSession.state, acsSession.dataClient, acsSession.taskEngine)
339347

@@ -454,6 +462,7 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.
454462
query.Set("agentHash", version.GitHashString())
455463
query.Set("agentVersion", version.Version)
456464
query.Set("seqNum", "1")
465+
query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion))
457466
if dockerVersion, err := taskEngine.Version(); err == nil {
458467
query.Set("dockerVersion", "DockerVersion: "+dockerVersion)
459468
}

agent/acs/handler/acs_handler_test.go

+9-6
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"reflect"
2525
"runtime"
2626
"runtime/pprof"
27+
"strconv"
2728
"sync"
2829
"testing"
2930
"time"
@@ -178,6 +179,8 @@ func TestACSWSURL(t *testing.T) {
178179
assert.Equal(t, "DockerVersion: Docker version result", parsed.Query().Get("dockerVersion"), "wrong docker version")
179180
assert.Equalf(t, "true", parsed.Query().Get(sendCredentialsURLParameterName), "Wrong value set for: %s", sendCredentialsURLParameterName)
180181
assert.Equal(t, "1", parsed.Query().Get("seqNum"), "wrong seqNum")
182+
protocolVersion, _ := strconv.Atoi(parsed.Query().Get("protocolVersion"))
183+
assert.True(t, protocolVersion > 1, "ACS protocol version should be greater than 1")
181184
}
182185

183186
// TestHandlerReconnectsOnConnectErrors tests if handler reconnects retries
@@ -844,12 +847,12 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {
844847
ended <- true
845848
}()
846849
// Warm it up
847-
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
850+
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}`
848851
serverIn <- samplePayloadMessage
849852

850853
beforeGoroutines := runtime.NumGoroutine()
851-
for i := 0; i < 100; i++ {
852-
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true}}`
854+
for i := 0; i < 40; i++ {
855+
serverIn <- `{"type":"HeartbeatMessage","message":{"healthy":true,"messageId":"123"}}`
853856
serverIn <- samplePayloadMessage
854857
closeWS <- true
855858
}
@@ -859,15 +862,15 @@ func TestHandlerDoesntLeakGoroutines(t *testing.T) {
859862

860863
// The number of goroutines finishing in the MockACSServer will affect
861864
// the result unless we wait here.
862-
time.Sleep(10 * time.Millisecond)
865+
time.Sleep(1 * time.Second)
863866
afterGoroutines := runtime.NumGoroutine()
864867

865868
t.Logf("Goroutines after 1 and after %v acs messages: %v and %v", timesConnected, beforeGoroutines, afterGoroutines)
866869

867-
if timesConnected < 50 {
870+
if timesConnected < 20 {
868871
t.Fatal("Expected times connected to be a large number, was ", timesConnected)
869872
}
870-
if afterGoroutines > beforeGoroutines+5 {
873+
if afterGoroutines > beforeGoroutines+2 {
871874
t.Error("Goroutine leak, oh no!")
872875
pprof.Lookup("goroutine").WriteTo(os.Stdout, 1)
873876
}
+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"). You may
4+
// not use this file except in compliance with the License. A copy of the
5+
// License is located at
6+
//
7+
// http://aws.amazon.com/apache2.0/
8+
//
9+
// or in the "license" file accompanying this file. This file is distributed
10+
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
// express or implied. See the License for the specific language governing
12+
// permissions and limitations under the License.
13+
14+
package handler
15+
16+
import (
17+
"context"
18+
19+
"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
20+
"github.com/aws/amazon-ecs-agent/agent/wsclient"
21+
"github.com/aws/aws-sdk-go/aws"
22+
"github.com/cihub/seelog"
23+
)
24+
25+
// heartbeatHandler handles heartbeat messages from ACS
26+
type heartbeatHandler struct {
27+
heartbeatMessageBuffer chan *ecsacs.HeartbeatMessage
28+
heartbeatAckMessageBuffer chan *ecsacs.HeartbeatAckRequest
29+
ctx context.Context
30+
cancel context.CancelFunc
31+
acsClient wsclient.ClientServer
32+
}
33+
34+
// newHeartbeatHandler returns an instance of the heartbeatHandler struct
35+
func newHeartbeatHandler(ctx context.Context,
36+
acsClient wsclient.ClientServer) heartbeatHandler {
37+
38+
// Create a cancelable context from the parent context
39+
derivedContext, cancel := context.WithCancel(ctx)
40+
return heartbeatHandler{
41+
heartbeatMessageBuffer: make(chan *ecsacs.HeartbeatMessage),
42+
heartbeatAckMessageBuffer: make(chan *ecsacs.HeartbeatAckRequest),
43+
ctx: derivedContext,
44+
cancel: cancel,
45+
acsClient: acsClient,
46+
}
47+
}
48+
49+
// handlerFunc returns a function to enqueue requests onto the buffer
50+
func (heartbeatHandler *heartbeatHandler) handlerFunc() func(message *ecsacs.HeartbeatMessage) {
51+
return func(message *ecsacs.HeartbeatMessage) {
52+
heartbeatHandler.heartbeatMessageBuffer <- message
53+
}
54+
}
55+
56+
// start() invokes go routines to handle receive and respond to heartbeats
57+
func (heartbeatHandler *heartbeatHandler) start() {
58+
go heartbeatHandler.handleHeartbeatMessage()
59+
go heartbeatHandler.sendHeartbeatAck()
60+
}
61+
62+
func (heartbeatHandler *heartbeatHandler) handleHeartbeatMessage() {
63+
for {
64+
select {
65+
case message := <-heartbeatHandler.heartbeatMessageBuffer:
66+
if err := heartbeatHandler.handleSingleHeartbeatMessage(message); err != nil {
67+
seelog.Warnf("Unable to handle heartbeat message [%s]: %s", message.String(), err)
68+
}
69+
case <-heartbeatHandler.ctx.Done():
70+
return
71+
}
72+
}
73+
}
74+
75+
func (heartbeatHandler *heartbeatHandler) handleSingleHeartbeatMessage(message *ecsacs.HeartbeatMessage) error {
76+
// Agent currently has no other action hooked to heartbeat messages, except simple ack
77+
go func() {
78+
response := &ecsacs.HeartbeatAckRequest{
79+
MessageId: message.MessageId,
80+
}
81+
heartbeatHandler.heartbeatAckMessageBuffer <- response
82+
}()
83+
return nil
84+
}
85+
86+
func (heartbeatHandler *heartbeatHandler) sendHeartbeatAck() {
87+
for {
88+
select {
89+
case ack := <-heartbeatHandler.heartbeatAckMessageBuffer:
90+
heartbeatHandler.sendSingleHeartbeatAck(ack)
91+
case <-heartbeatHandler.ctx.Done():
92+
return
93+
}
94+
}
95+
}
96+
97+
func (heartbeatHandler *heartbeatHandler) sendSingleHeartbeatAck(ack *ecsacs.HeartbeatAckRequest) {
98+
err := heartbeatHandler.acsClient.MakeRequest(ack)
99+
if err != nil {
100+
seelog.Warnf("Error acknowledging server heartbeat, message id: %s, error: %s", aws.StringValue(ack.MessageId), err)
101+
}
102+
}
103+
104+
// stop() cancels the context being used by this handler, which stops the go routines started by 'start()'
105+
func (heartbeatHandler *heartbeatHandler) stop() {
106+
heartbeatHandler.cancel()
107+
}
108+
109+
// clearAcks drains the ack request channel
110+
func (heartbeatHandler *heartbeatHandler) clearAcks() {
111+
for {
112+
select {
113+
case <-heartbeatHandler.heartbeatAckMessageBuffer:
114+
default:
115+
return
116+
}
117+
}
118+
}
+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// +build unit
2+
3+
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License"). You may
6+
// not use this file except in compliance with the License. A copy of the
7+
// License is located at
8+
//
9+
// http://aws.amazon.com/apache2.0/
10+
//
11+
// or in the "license" file accompanying this file. This file is distributed
12+
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
13+
// express or implied. See the License for the specific language governing
14+
// permissions and limitations under the License.
15+
16+
package handler
17+
18+
import (
19+
"context"
20+
"testing"
21+
22+
"github.com/aws/amazon-ecs-agent/agent/acs/model/ecsacs"
23+
mock_wsclient "github.com/aws/amazon-ecs-agent/agent/wsclient/mock"
24+
"github.com/aws/aws-sdk-go/aws"
25+
"github.com/golang/mock/gomock"
26+
"github.com/stretchr/testify/require"
27+
)
28+
29+
// heartbeatMessageId is the message id placed on test heartbeats and expected
// back on their acks.
const heartbeatMessageId = "heartbeatMessageId"
32+
33+
func TestAckHeartbeatMessage(t *testing.T) {
34+
heartbeatReceived := &ecsacs.HeartbeatMessage{
35+
MessageId: aws.String(heartbeatMessageId),
36+
Healthy: aws.Bool(true),
37+
}
38+
39+
heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
40+
MessageId: aws.String(heartbeatMessageId),
41+
}
42+
43+
validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
44+
}
45+
46+
func TestAckHeartbeatMessageNotHealthy(t *testing.T) {
47+
heartbeatReceived := &ecsacs.HeartbeatMessage{
48+
MessageId: aws.String(heartbeatMessageId),
49+
// ECS Agent currently ignores this field so we expect no behavior change
50+
Healthy: aws.Bool(false),
51+
}
52+
53+
heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
54+
MessageId: aws.String(heartbeatMessageId),
55+
}
56+
57+
validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
58+
}
59+
60+
func TestAckHeartbeatMessageWithoutMessageId(t *testing.T) {
61+
heartbeatReceived := &ecsacs.HeartbeatMessage{
62+
Healthy: aws.Bool(true),
63+
}
64+
65+
heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
66+
MessageId: nil,
67+
}
68+
69+
validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
70+
}
71+
72+
func TestAckHeartbeatMessageEmpty(t *testing.T) {
73+
heartbeatReceived := &ecsacs.HeartbeatMessage{}
74+
75+
heartbeatAckExpected := &ecsacs.HeartbeatAckRequest{
76+
MessageId: nil,
77+
}
78+
79+
validateHeartbeatAck(t, heartbeatReceived, heartbeatAckExpected)
80+
}
81+
82+
func validateHeartbeatAck(t *testing.T, heartbeatReceived *ecsacs.HeartbeatMessage, heartbeatAckExpected *ecsacs.HeartbeatAckRequest) {
83+
ctrl := gomock.NewController(t)
84+
defer ctrl.Finish()
85+
86+
ctx, cancel := context.WithCancel(context.Background())
87+
var heartbeatAckSent *ecsacs.HeartbeatAckRequest
88+
89+
mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
90+
mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(message *ecsacs.HeartbeatAckRequest) {
91+
heartbeatAckSent = message
92+
cancel()
93+
}).Times(1)
94+
95+
handler := newHeartbeatHandler(ctx, mockWsClient)
96+
go handler.sendHeartbeatAck()
97+
98+
handler.handleSingleHeartbeatMessage(heartbeatReceived)
99+
100+
// wait till we get an ack from heartbeatAckMessageBuffer
101+
<-ctx.Done()
102+
103+
require.Equal(t, heartbeatAckExpected, heartbeatAckSent)
104+
}

agent/acs/model/api/api-2.json

+10-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747
"requestUri":"/"
4848
},
4949
"input":{"shape":"HeartbeatMessage"},
50-
"documentation":"Heartbeat is a periodic message that informs the agent all is well."
50+
"output":{"shape":"HeartbeatAckRequest"},
51+
"documentation":"Heartbeat is a periodic message between the Agent and ECS backend to keep the connection alive."
5152
},
5253
"Payload":{
5354
"name":"Payload",
@@ -417,7 +418,14 @@
417418
"HeartbeatMessage":{
418419
"type":"structure",
419420
"members":{
420-
"healthy":{"shape":"Boolean"}
421+
"healthy":{"shape":"Boolean"},
422+
"messageId":{"shape":"String"}
423+
}
424+
},
425+
"HeartbeatAckRequest":{
426+
"type":"structure",
427+
"members":{
428+
"messageId":{"shape":"String"}
421429
}
422430
},
423431
"HostVolumeProperties":{

agent/acs/model/ecsacs/api.go

+22
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)