Skip to content

Commit

Permalink
Move ACS client to ecs-agent module and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
danehlim committed May 25, 2023
1 parent c311434 commit b8ea04c
Show file tree
Hide file tree
Showing 36 changed files with 1,694 additions and 234 deletions.
126 changes: 27 additions & 99 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"sync"
"time"

acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
"github.com/aws/amazon-ecs-agent/agent/api"
"github.com/aws/amazon-ecs-agent/agent/config"
Expand All @@ -36,10 +35,10 @@ import (
"github.com/aws/amazon-ecs-agent/agent/eventstream"
"github.com/aws/amazon-ecs-agent/agent/utils/retry"
"github.com/aws/amazon-ecs-agent/agent/version"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime"
"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/cihub/seelog"
Expand Down Expand Up @@ -105,7 +104,8 @@ type session struct {
ctx context.Context
cancel context.CancelFunc
backoff retry.Backoff
resources sessionResources
clientFactory wsclient.ClientFactory
sendCredentials bool
latestSeqNumTaskManifest *int64
doctor *doctor.Doctor
_heartbeatTimeout time.Duration
Expand All @@ -115,43 +115,6 @@ type session struct {
_inactiveInstanceReconnectDelay time.Duration
}

// sessionResources defines the resource creator interface for starting
// a session with ACS. This interface is intended to define methods
// that create resources used to establish the connection to ACS
// It is confined to just the createACSClient() method for now. It can be
// extended to include the acsWsURL() and newHeartbeatTimer() methods
// when needed
// The goal is to make it easier to test and inject dependencies
type sessionResources interface {
// createACSClient creates a new websocket client
createACSClient(url string, cfg *config.Config) wsclient.ClientServer
sessionState
}

// acsSessionResources implements resource creator and session state interfaces
// to create resources needed to connect to ACS and to record session state
// for the same
type acsSessionResources struct {
credentialsProvider *credentials.Credentials
// sendCredentials is used to set the 'sendCredentials' URL parameter
// used to connect to ACS
// It is set to 'true' for the very first successful connection on
// agent start. It is set to false for all successive connections
sendCredentials bool
}

// sessionState defines state recorder interface for the
// session established with ACS. It can be used to record and
// retrieve data shared across multiple connections to ACS
type sessionState interface {
// connectedToACS callback indicates that the client has
// connected to ACS
connectedToACS()
// getSendCredentialsURLParameter retrieves the value for
// the 'sendCredentials' URL parameter
getSendCredentialsURLParameter() string
}

// NewSession creates a new Session object
func NewSession(
ctx context.Context,
Expand All @@ -168,8 +131,8 @@ func NewSession(
taskHandler *eventhandler.TaskHandler,
latestSeqNumTaskManifest *int64,
doctor *doctor.Doctor,
clientFactory wsclient.ClientFactory,
) Session {
resources := newSessionResources(credentialsProvider)
backoff := retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax,
connectionBackoffJitter, connectionBackoffMultiplier)
derivedContext, cancel := context.WithCancel(ctx)
Expand All @@ -189,9 +152,10 @@ func NewSession(
ctx: derivedContext,
cancel: cancel,
backoff: backoff,
resources: resources,
latestSeqNumTaskManifest: latestSeqNumTaskManifest,
doctor: doctor,
clientFactory: clientFactory,
sendCredentials: true,
_heartbeatTimeout: heartbeatTimeout,
_heartbeatJitter: heartbeatJitter,
connectionTime: connectionTime,
Expand Down Expand Up @@ -257,14 +221,23 @@ func (acsSession *session) Start() error {
// startSessionOnce creates a session with ACS and handles requests using the passed
// in arguments
func (acsSession *session) startSessionOnce() error {
minAgentCfg := &wsclient.WSClientMinAgentConfig{
AcceptInsecureCert: acsSession.agentConfig.AcceptInsecureCert,
AWSRegion: acsSession.agentConfig.AWSRegion,
}

acsEndpoint, err := acsSession.ecsClient.DiscoverPollEndpoint(acsSession.containerInstanceARN)
if err != nil {
seelog.Errorf("acs: unable to discover poll endpoint, err: %v", err)
return err
}

url := acsWsURL(acsEndpoint, acsSession.agentConfig.Cluster, acsSession.containerInstanceARN, acsSession.taskEngine, acsSession.resources)
client := acsSession.resources.createACSClient(url, acsSession.agentConfig)
url := acsSession.acsURL(acsEndpoint)
client := acsSession.clientFactory.New(
url,
acsSession.credentialsProvider,
wsRWTimeout,
minAgentCfg)
defer client.Close()

return acsSession.startACSSession(client)
Expand Down Expand Up @@ -371,7 +344,9 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client))
defer heartbeatTimer.Stop()

acsSession.resources.connectedToACS()
// Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence
// and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS.
acsSession.sendCredentials = false

backoffResetTimer := time.AfterFunc(
retry.AddJitter(acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()), func() {
Expand All @@ -383,30 +358,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
})
defer backoffResetTimer.Stop()

serveErr := make(chan error, 1)
go func() {
serveErr <- client.Serve()
}()

for {
select {
case <-acsSession.ctx.Done():
// Stop receiving and sending messages from and to ACS when
// the context received from the main function is canceled
seelog.Infof("ACS session exited cleanly.")
return acsSession.ctx.Err()
case err := <-serveErr:
// Stop receiving and sending messages from and to ACS when
// client.Serve returns an error. This can happen when the
// connection is closed by ACS or the agent
if err == nil || err == io.EOF {
seelog.Info("ACS Websocket connection closed for a valid reason")
} else {
seelog.Errorf("Error: lost websocket connection with Agent Communication Service (ACS): %v", err)
}
return err
}
}
return client.Serve(acsSession.ctx)
}

func (acsSession *session) computeReconnectDelay(isInactiveInstance bool) time.Duration {
Expand Down Expand Up @@ -438,48 +390,24 @@ func (acsSession *session) heartbeatJitter() time.Duration {
return acsSession._heartbeatJitter
}

// createACSClient creates the ACS Client using the specified URL
func (acsResources *acsSessionResources) createACSClient(url string, cfg *config.Config) wsclient.ClientServer {
return acsclient.New(url, cfg, acsResources.credentialsProvider, wsRWTimeout)
}

// connectedToACS records a successful connection to ACS
// It sets sendCredentials to false on such an event
func (acsResources *acsSessionResources) connectedToACS() {
acsResources.sendCredentials = false
}

// getSendCredentialsURLParameter gets the value to be set for the
// 'sendCredentials' URL parameter
func (acsResources *acsSessionResources) getSendCredentialsURLParameter() string {
return strconv.FormatBool(acsResources.sendCredentials)
}

func newSessionResources(credentialsProvider *credentials.Credentials) sessionResources {
return &acsSessionResources{
credentialsProvider: credentialsProvider,
sendCredentials: true,
}
}

// acsWsURL returns the websocket url for ACS given the endpoint
func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.TaskEngine, acsSessionState sessionState) string {
// acsURL returns the websocket url for ACS given the endpoint
func (acsSession *session) acsURL(endpoint string) string {
acsURL := endpoint
if endpoint[len(endpoint)-1] != '/' {
acsURL += "/"
}
acsURL += "ws"
query := url.Values{}
query.Set("clusterArn", cluster)
query.Set("containerInstanceArn", containerInstanceArn)
query.Set("clusterArn", acsSession.agentConfig.Cluster)
query.Set("containerInstanceArn", acsSession.containerInstanceARN)
query.Set("agentHash", version.GitHashString())
query.Set("agentVersion", version.Version)
query.Set("seqNum", "1")
query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion))
if dockerVersion, err := taskEngine.Version(); err == nil {
if dockerVersion, err := acsSession.taskEngine.Version(); err == nil {
query.Set("dockerVersion", "DockerVersion: "+dockerVersion)
}
query.Set(sendCredentialsURLParameterName, acsSessionState.getSendCredentialsURLParameter())
query.Set(sendCredentialsURLParameterName, strconv.FormatBool(acsSession.sendCredentials))
return acsURL + "?" + query.Encode()
}

Expand Down
Loading

0 comments on commit b8ea04c

Please sign in to comment.