Skip to content

Commit cd8844c

Browse files
committed
Move ACS client to ecs-agent module and refactor
1 parent 6f994c6 commit cd8844c

36 files changed

+1653
-214
lines changed

agent/acs/handler/acs_handler.go

+22-96
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"sync"
2525
"time"
2626

27-
acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
2827
updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
2928
"github.com/aws/amazon-ecs-agent/agent/api"
3029
"github.com/aws/amazon-ecs-agent/agent/config"
@@ -36,10 +35,10 @@ import (
3635
"github.com/aws/amazon-ecs-agent/agent/eventstream"
3736
"github.com/aws/amazon-ecs-agent/agent/utils/retry"
3837
"github.com/aws/amazon-ecs-agent/agent/version"
39-
"github.com/aws/amazon-ecs-agent/agent/wsclient"
4038
rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
4139
"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
4240
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime"
41+
"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"
4342

4443
"github.com/aws/aws-sdk-go/aws/credentials"
4544
"github.com/cihub/seelog"
@@ -105,7 +104,8 @@ type session struct {
105104
ctx context.Context
106105
cancel context.CancelFunc
107106
backoff retry.Backoff
108-
resources sessionResources
107+
clientFactory wsclient.ClientFactory
108+
sendCredentials bool
109109
latestSeqNumTaskManifest *int64
110110
doctor *doctor.Doctor
111111
_heartbeatTimeout time.Duration
@@ -115,43 +115,6 @@ type session struct {
115115
_inactiveInstanceReconnectDelay time.Duration
116116
}
117117

118-
// sessionResources defines the resource creator interface for starting
119-
// a session with ACS. This interface is intended to define methods
120-
// that create resources used to establish the connection to ACS
121-
// It is confined to just the createACSClient() method for now. It can be
122-
// extended to include the acsWsURL() and newHeartbeatTimer() methods
123-
// when needed
124-
// The goal is to make it easier to test and inject dependencies
125-
type sessionResources interface {
126-
// createACSClient creates a new websocket client
127-
createACSClient(url string, cfg *config.Config) wsclient.ClientServer
128-
sessionState
129-
}
130-
131-
// acsSessionResources implements resource creator and session state interfaces
132-
// to create resources needed to connect to ACS and to record session state
133-
// for the same
134-
type acsSessionResources struct {
135-
credentialsProvider *credentials.Credentials
136-
// sendCredentials is used to set the 'sendCredentials' URL parameter
137-
// used to connect to ACS
138-
// It is set to 'true' for the very first successful connection on
139-
// agent start. It is set to false for all successive connections
140-
sendCredentials bool
141-
}
142-
143-
// sessionState defines state recorder interface for the
144-
// session established with ACS. It can be used to record and
145-
// retrieve data shared across multiple connections to ACS
146-
type sessionState interface {
147-
// connectedToACS callback indicates that the client has
148-
// connected to ACS
149-
connectedToACS()
150-
// getSendCredentialsURLParameter retrieves the value for
151-
// the 'sendCredentials' URL parameter
152-
getSendCredentialsURLParameter() string
153-
}
154-
155118
// NewSession creates a new Session object
156119
func NewSession(
157120
ctx context.Context,
@@ -168,8 +131,8 @@ func NewSession(
168131
taskHandler *eventhandler.TaskHandler,
169132
latestSeqNumTaskManifest *int64,
170133
doctor *doctor.Doctor,
134+
clientFactory wsclient.ClientFactory,
171135
) Session {
172-
resources := newSessionResources(credentialsProvider)
173136
backoff := retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax,
174137
connectionBackoffJitter, connectionBackoffMultiplier)
175138
derivedContext, cancel := context.WithCancel(ctx)
@@ -189,9 +152,10 @@ func NewSession(
189152
ctx: derivedContext,
190153
cancel: cancel,
191154
backoff: backoff,
192-
resources: resources,
193155
latestSeqNumTaskManifest: latestSeqNumTaskManifest,
194156
doctor: doctor,
157+
clientFactory: clientFactory,
158+
sendCredentials: true,
195159
_heartbeatTimeout: heartbeatTimeout,
196160
_heartbeatJitter: heartbeatJitter,
197161
connectionTime: connectionTime,
@@ -257,14 +221,23 @@ func (acsSession *session) Start() error {
257221
// startSessionOnce creates a session with ACS and handles requests using the passed
258222
// in arguments
259223
func (acsSession *session) startSessionOnce() error {
224+
minAgentCfg := &wsclient.WSClientMinAgentConfig{
225+
AcceptInsecureCert: acsSession.agentConfig.AcceptInsecureCert,
226+
AWSRegion: acsSession.agentConfig.AWSRegion,
227+
}
228+
260229
acsEndpoint, err := acsSession.ecsClient.DiscoverPollEndpoint(acsSession.containerInstanceARN)
261230
if err != nil {
262231
seelog.Errorf("acs: unable to discover poll endpoint, err: %v", err)
263232
return err
264233
}
265234

266-
url := acsWsURL(acsEndpoint, acsSession.agentConfig.Cluster, acsSession.containerInstanceARN, acsSession.taskEngine, acsSession.resources)
267-
client := acsSession.resources.createACSClient(url, acsSession.agentConfig)
235+
url := acsSession.acsWsURL(acsEndpoint, acsSession.agentConfig.Cluster, acsSession.containerInstanceARN)
236+
client := acsSession.clientFactory.New(
237+
url,
238+
acsSession.credentialsProvider,
239+
wsRWTimeout,
240+
minAgentCfg)
268241
defer client.Close()
269242

270243
return acsSession.startACSSession(client)
@@ -371,7 +344,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
371344
client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client))
372345
defer heartbeatTimer.Stop()
373346

374-
acsSession.resources.connectedToACS()
347+
acsSession.sendCredentials = false
375348

376349
backoffResetTimer := time.AfterFunc(
377350
retry.AddJitter(acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()), func() {
@@ -383,30 +356,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
383356
})
384357
defer backoffResetTimer.Stop()
385358

386-
serveErr := make(chan error, 1)
387-
go func() {
388-
serveErr <- client.Serve()
389-
}()
390-
391-
for {
392-
select {
393-
case <-acsSession.ctx.Done():
394-
// Stop receiving and sending messages from and to ACS when
395-
// the context received from the main function is canceled
396-
seelog.Infof("ACS session exited cleanly.")
397-
return acsSession.ctx.Err()
398-
case err := <-serveErr:
399-
// Stop receiving and sending messages from and to ACS when
400-
// client.Serve returns an error. This can happen when the
401-
// connection is closed by ACS or the agent
402-
if err == nil || err == io.EOF {
403-
seelog.Info("ACS Websocket connection closed for a valid reason")
404-
} else {
405-
seelog.Errorf("Error: lost websocket connection with Agent Communication Service (ACS): %v", err)
406-
}
407-
return err
408-
}
409-
}
359+
return client.Serve(acsSession.ctx)
410360
}
411361

412362
func (acsSession *session) computeReconnectDelay(isInactiveInstance bool) time.Duration {
@@ -438,32 +388,8 @@ func (acsSession *session) heartbeatJitter() time.Duration {
438388
return acsSession._heartbeatJitter
439389
}
440390

441-
// createACSClient creates the ACS Client using the specified URL
442-
func (acsResources *acsSessionResources) createACSClient(url string, cfg *config.Config) wsclient.ClientServer {
443-
return acsclient.New(url, cfg, acsResources.credentialsProvider, wsRWTimeout)
444-
}
445-
446-
// connectedToACS records a successful connection to ACS
447-
// It sets sendCredentials to false on such an event
448-
func (acsResources *acsSessionResources) connectedToACS() {
449-
acsResources.sendCredentials = false
450-
}
451-
452-
// getSendCredentialsURLParameter gets the value to be set for the
453-
// 'sendCredentials' URL parameter
454-
func (acsResources *acsSessionResources) getSendCredentialsURLParameter() string {
455-
return strconv.FormatBool(acsResources.sendCredentials)
456-
}
457-
458-
func newSessionResources(credentialsProvider *credentials.Credentials) sessionResources {
459-
return &acsSessionResources{
460-
credentialsProvider: credentialsProvider,
461-
sendCredentials: true,
462-
}
463-
}
464-
465391
// acsWsURL returns the websocket url for ACS given the endpoint
466-
func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.TaskEngine, acsSessionState sessionState) string {
392+
func (acsSession *session) acsWsURL(endpoint, cluster, containerInstanceArn string) string {
467393
acsURL := endpoint
468394
if endpoint[len(endpoint)-1] != '/' {
469395
acsURL += "/"
@@ -476,10 +402,10 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.
476402
query.Set("agentVersion", version.Version)
477403
query.Set("seqNum", "1")
478404
query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion))
479-
if dockerVersion, err := taskEngine.Version(); err == nil {
405+
if dockerVersion, err := acsSession.taskEngine.Version(); err == nil {
480406
query.Set("dockerVersion", "DockerVersion: "+dockerVersion)
481407
}
482-
query.Set(sendCredentialsURLParameterName, acsSessionState.getSendCredentialsURLParameter())
408+
query.Set(sendCredentialsURLParameterName, strconv.FormatBool(acsSession.sendCredentials))
483409
return acsURL + "?" + query.Encode()
484410
}
485411

0 commit comments

Comments
 (0)