@@ -24,7 +24,6 @@ import (
24
24
"sync"
25
25
"time"
26
26
27
- acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
28
27
updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
29
28
"github.com/aws/amazon-ecs-agent/agent/api"
30
29
"github.com/aws/amazon-ecs-agent/agent/config"
@@ -36,10 +35,10 @@ import (
36
35
"github.com/aws/amazon-ecs-agent/agent/eventstream"
37
36
"github.com/aws/amazon-ecs-agent/agent/utils/retry"
38
37
"github.com/aws/amazon-ecs-agent/agent/version"
39
- "github.com/aws/amazon-ecs-agent/agent/wsclient"
40
38
rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
41
39
"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
42
40
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime"
41
+ "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"
43
42
44
43
"github.com/aws/aws-sdk-go/aws/credentials"
45
44
"github.com/cihub/seelog"
@@ -105,7 +104,8 @@ type session struct {
105
104
ctx context.Context
106
105
cancel context.CancelFunc
107
106
backoff retry.Backoff
108
- resources sessionResources
107
+ clientFactory wsclient.ClientFactory
108
+ sendCredentials bool
109
109
latestSeqNumTaskManifest * int64
110
110
doctor * doctor.Doctor
111
111
_heartbeatTimeout time.Duration
@@ -115,43 +115,6 @@ type session struct {
115
115
_inactiveInstanceReconnectDelay time.Duration
116
116
}
117
117
118
- // sessionResources defines the resource creator interface for starting
119
- // a session with ACS. This interface is intended to define methods
120
- // that create resources used to establish the connection to ACS
121
- // It is confined to just the createACSClient() method for now. It can be
122
- // extended to include the acsWsURL() and newHeartbeatTimer() methods
123
- // when needed
124
- // The goal is to make it easier to test and inject dependencies
125
- type sessionResources interface {
126
- // createACSClient creates a new websocket client
127
- createACSClient (url string , cfg * config.Config ) wsclient.ClientServer
128
- sessionState
129
- }
130
-
131
- // acsSessionResources implements resource creator and session state interfaces
132
- // to create resources needed to connect to ACS and to record session state
133
- // for the same
134
- type acsSessionResources struct {
135
- credentialsProvider * credentials.Credentials
136
- // sendCredentials is used to set the 'sendCredentials' URL parameter
137
- // used to connect to ACS
138
- // It is set to 'true' for the very first successful connection on
139
- // agent start. It is set to false for all successive connections
140
- sendCredentials bool
141
- }
142
-
143
- // sessionState defines state recorder interface for the
144
- // session established with ACS. It can be used to record and
145
- // retrieve data shared across multiple connections to ACS
146
- type sessionState interface {
147
- // connectedToACS callback indicates that the client has
148
- // connected to ACS
149
- connectedToACS ()
150
- // getSendCredentialsURLParameter retrieves the value for
151
- // the 'sendCredentials' URL parameter
152
- getSendCredentialsURLParameter () string
153
- }
154
-
155
118
// NewSession creates a new Session object
156
119
func NewSession (
157
120
ctx context.Context ,
@@ -168,8 +131,8 @@ func NewSession(
168
131
taskHandler * eventhandler.TaskHandler ,
169
132
latestSeqNumTaskManifest * int64 ,
170
133
doctor * doctor.Doctor ,
134
+ clientFactory wsclient.ClientFactory ,
171
135
) Session {
172
- resources := newSessionResources (credentialsProvider )
173
136
backoff := retry .NewExponentialBackoff (connectionBackoffMin , connectionBackoffMax ,
174
137
connectionBackoffJitter , connectionBackoffMultiplier )
175
138
derivedContext , cancel := context .WithCancel (ctx )
@@ -189,9 +152,10 @@ func NewSession(
189
152
ctx : derivedContext ,
190
153
cancel : cancel ,
191
154
backoff : backoff ,
192
- resources : resources ,
193
155
latestSeqNumTaskManifest : latestSeqNumTaskManifest ,
194
156
doctor : doctor ,
157
+ clientFactory : clientFactory ,
158
+ sendCredentials : true ,
195
159
_heartbeatTimeout : heartbeatTimeout ,
196
160
_heartbeatJitter : heartbeatJitter ,
197
161
connectionTime : connectionTime ,
@@ -257,14 +221,23 @@ func (acsSession *session) Start() error {
257
221
// startSessionOnce creates a session with ACS and handles requests using the passed
258
222
// in arguments
259
223
func (acsSession * session ) startSessionOnce () error {
224
+ minAgentCfg := & wsclient.WSClientMinAgentConfig {
225
+ AcceptInsecureCert : acsSession .agentConfig .AcceptInsecureCert ,
226
+ AWSRegion : acsSession .agentConfig .AWSRegion ,
227
+ }
228
+
260
229
acsEndpoint , err := acsSession .ecsClient .DiscoverPollEndpoint (acsSession .containerInstanceARN )
261
230
if err != nil {
262
231
seelog .Errorf ("acs: unable to discover poll endpoint, err: %v" , err )
263
232
return err
264
233
}
265
234
266
- url := acsWsURL (acsEndpoint , acsSession .agentConfig .Cluster , acsSession .containerInstanceARN , acsSession .taskEngine , acsSession .resources )
267
- client := acsSession .resources .createACSClient (url , acsSession .agentConfig )
235
+ url := acsSession .acsWsURL (acsEndpoint , acsSession .agentConfig .Cluster , acsSession .containerInstanceARN )
236
+ client := acsSession .clientFactory .New (
237
+ url ,
238
+ acsSession .credentialsProvider ,
239
+ wsRWTimeout ,
240
+ minAgentCfg )
268
241
defer client .Close ()
269
242
270
243
return acsSession .startACSSession (client )
@@ -371,7 +344,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
371
344
client .SetAnyRequestHandler (anyMessageHandler (heartbeatTimer , client ))
372
345
defer heartbeatTimer .Stop ()
373
346
374
- acsSession .resources . connectedToACS ()
347
+ acsSession .sendCredentials = false
375
348
376
349
backoffResetTimer := time .AfterFunc (
377
350
retry .AddJitter (acsSession .heartbeatTimeout (), acsSession .heartbeatJitter ()), func () {
@@ -383,30 +356,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
383
356
})
384
357
defer backoffResetTimer .Stop ()
385
358
386
- serveErr := make (chan error , 1 )
387
- go func () {
388
- serveErr <- client .Serve ()
389
- }()
390
-
391
- for {
392
- select {
393
- case <- acsSession .ctx .Done ():
394
- // Stop receiving and sending messages from and to ACS when
395
- // the context received from the main function is canceled
396
- seelog .Infof ("ACS session exited cleanly." )
397
- return acsSession .ctx .Err ()
398
- case err := <- serveErr :
399
- // Stop receiving and sending messages from and to ACS when
400
- // client.Serve returns an error. This can happen when the
401
- // connection is closed by ACS or the agent
402
- if err == nil || err == io .EOF {
403
- seelog .Info ("ACS Websocket connection closed for a valid reason" )
404
- } else {
405
- seelog .Errorf ("Error: lost websocket connection with Agent Communication Service (ACS): %v" , err )
406
- }
407
- return err
408
- }
409
- }
359
+ return client .Serve (acsSession .ctx )
410
360
}
411
361
412
362
func (acsSession * session ) computeReconnectDelay (isInactiveInstance bool ) time.Duration {
@@ -438,32 +388,8 @@ func (acsSession *session) heartbeatJitter() time.Duration {
438
388
return acsSession ._heartbeatJitter
439
389
}
440
390
441
- // createACSClient creates the ACS Client using the specified URL
442
- func (acsResources * acsSessionResources ) createACSClient (url string , cfg * config.Config ) wsclient.ClientServer {
443
- return acsclient .New (url , cfg , acsResources .credentialsProvider , wsRWTimeout )
444
- }
445
-
446
- // connectedToACS records a successful connection to ACS
447
- // It sets sendCredentials to false on such an event
448
- func (acsResources * acsSessionResources ) connectedToACS () {
449
- acsResources .sendCredentials = false
450
- }
451
-
452
- // getSendCredentialsURLParameter gets the value to be set for the
453
- // 'sendCredentials' URL parameter
454
- func (acsResources * acsSessionResources ) getSendCredentialsURLParameter () string {
455
- return strconv .FormatBool (acsResources .sendCredentials )
456
- }
457
-
458
- func newSessionResources (credentialsProvider * credentials.Credentials ) sessionResources {
459
- return & acsSessionResources {
460
- credentialsProvider : credentialsProvider ,
461
- sendCredentials : true ,
462
- }
463
- }
464
-
465
391
// acsWsURL returns the websocket url for ACS given the endpoint
466
- func acsWsURL (endpoint , cluster , containerInstanceArn string , taskEngine engine. TaskEngine , acsSessionState sessionState ) string {
392
+ func ( acsSession * session ) acsWsURL (endpoint , cluster , containerInstanceArn string ) string {
467
393
acsURL := endpoint
468
394
if endpoint [len (endpoint )- 1 ] != '/' {
469
395
acsURL += "/"
@@ -476,10 +402,10 @@ func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.
476
402
query .Set ("agentVersion" , version .Version )
477
403
query .Set ("seqNum" , "1" )
478
404
query .Set ("protocolVersion" , strconv .Itoa (acsProtocolVersion ))
479
- if dockerVersion , err := taskEngine .Version (); err == nil {
405
+ if dockerVersion , err := acsSession . taskEngine .Version (); err == nil {
480
406
query .Set ("dockerVersion" , "DockerVersion: " + dockerVersion )
481
407
}
482
- query .Set (sendCredentialsURLParameterName , acsSessionState . getSendCredentialsURLParameter ( ))
408
+ query .Set (sendCredentialsURLParameterName , strconv . FormatBool ( acsSession . sendCredentials ))
483
409
return acsURL + "?" + query .Encode ()
484
410
}
485
411
0 commit comments