Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move ACS client to ecs-agent module and refactor #3710

Merged
merged 1 commit into from
May 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 27 additions & 99 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"sync"
"time"

acsclient "github.com/aws/amazon-ecs-agent/agent/acs/client"
updater "github.com/aws/amazon-ecs-agent/agent/acs/update_handler"
"github.com/aws/amazon-ecs-agent/agent/api"
"github.com/aws/amazon-ecs-agent/agent/config"
Expand All @@ -35,11 +34,11 @@ import (
"github.com/aws/amazon-ecs-agent/agent/eventhandler"
"github.com/aws/amazon-ecs-agent/agent/eventstream"
"github.com/aws/amazon-ecs-agent/agent/version"
"github.com/aws/amazon-ecs-agent/agent/wsclient"
rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime"
"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/cihub/seelog"
Expand Down Expand Up @@ -105,7 +104,8 @@ type session struct {
ctx context.Context
cancel context.CancelFunc
backoff retry.Backoff
resources sessionResources
clientFactory wsclient.ClientFactory
sendCredentials bool
latestSeqNumTaskManifest *int64
doctor *doctor.Doctor
_heartbeatTimeout time.Duration
Expand All @@ -115,43 +115,6 @@ type session struct {
_inactiveInstanceReconnectDelay time.Duration
}

// sessionResources defines the resource creator interface for starting
// a session with ACS. This interface is intended to define methods
// that create resources used to establish the connection to ACS
// It is confined to just the createACSClient() method for now. It can be
// extended to include the acsWsURL() and newHeartbeatTimer() methods
// when needed
// The goal is to make it easier to test and inject dependencies
type sessionResources interface {
// createACSClient creates a new websocket client
createACSClient(url string, cfg *config.Config) wsclient.ClientServer
sessionState
}

// acsSessionResources implements resource creator and session state interfaces
// to create resources needed to connect to ACS and to record session state
// for the same
type acsSessionResources struct {
credentialsProvider *credentials.Credentials
// sendCredentials is used to set the 'sendCredentials' URL parameter
// used to connect to ACS
// It is set to 'true' for the very first successful connection on
// agent start. It is set to false for all successive connections
sendCredentials bool
}

// sessionState defines state recorder interface for the
// session established with ACS. It can be used to record and
// retrieve data shared across multiple connections to ACS
type sessionState interface {
// connectedToACS callback indicates that the client has
// connected to ACS
connectedToACS()
// getSendCredentialsURLParameter retrieves the value for
// the 'sendCredentials' URL parameter
getSendCredentialsURLParameter() string
}

// NewSession creates a new Session object
func NewSession(
ctx context.Context,
Expand All @@ -168,8 +131,8 @@ func NewSession(
taskHandler *eventhandler.TaskHandler,
latestSeqNumTaskManifest *int64,
doctor *doctor.Doctor,
clientFactory wsclient.ClientFactory,
) Session {
resources := newSessionResources(credentialsProvider)
backoff := retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax,
connectionBackoffJitter, connectionBackoffMultiplier)
derivedContext, cancel := context.WithCancel(ctx)
Expand All @@ -189,9 +152,10 @@ func NewSession(
ctx: derivedContext,
cancel: cancel,
backoff: backoff,
resources: resources,
latestSeqNumTaskManifest: latestSeqNumTaskManifest,
doctor: doctor,
clientFactory: clientFactory,
sendCredentials: true,
_heartbeatTimeout: heartbeatTimeout,
_heartbeatJitter: heartbeatJitter,
connectionTime: connectionTime,
Expand Down Expand Up @@ -257,14 +221,23 @@ func (acsSession *session) Start() error {
// startSessionOnce creates a session with ACS and handles requests using the passed
// in arguments
func (acsSession *session) startSessionOnce() error {
minAgentCfg := &wsclient.WSClientMinAgentConfig{
AcceptInsecureCert: acsSession.agentConfig.AcceptInsecureCert,
AWSRegion: acsSession.agentConfig.AWSRegion,
}

acsEndpoint, err := acsSession.ecsClient.DiscoverPollEndpoint(acsSession.containerInstanceARN)
if err != nil {
seelog.Errorf("acs: unable to discover poll endpoint, err: %v", err)
return err
}

url := acsWsURL(acsEndpoint, acsSession.agentConfig.Cluster, acsSession.containerInstanceARN, acsSession.taskEngine, acsSession.resources)
client := acsSession.resources.createACSClient(url, acsSession.agentConfig)
url := acsSession.acsURL(acsEndpoint)
client := acsSession.clientFactory.New(
url,
acsSession.credentialsProvider,
wsRWTimeout,
minAgentCfg)
defer client.Close()

return acsSession.startACSSession(client)
Expand Down Expand Up @@ -371,7 +344,9 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
client.SetAnyRequestHandler(anyMessageHandler(heartbeatTimer, client))
defer heartbeatTimer.Stop()

acsSession.resources.connectedToACS()
// Connection to ACS was successful. Moving forward, rely on ACS to send credentials to Agent at its own cadence
// and make sure Agent does not force ACS to send credentials for any subsequent reconnects to ACS.
acsSession.sendCredentials = false
ohsoo marked this conversation as resolved.
Show resolved Hide resolved

backoffResetTimer := time.AfterFunc(
retry.AddJitter(acsSession.heartbeatTimeout(), acsSession.heartbeatJitter()), func() {
Expand All @@ -383,30 +358,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
})
defer backoffResetTimer.Stop()

serveErr := make(chan error, 1)
go func() {
serveErr <- client.Serve()
}()

for {
select {
case <-acsSession.ctx.Done():
// Stop receiving and sending messages from and to ACS when
// the context received from the main function is canceled
seelog.Infof("ACS session exited cleanly.")
return acsSession.ctx.Err()
case err := <-serveErr:
// Stop receiving and sending messages from and to ACS when
// client.Serve returns an error. This can happen when the
// connection is closed by ACS or the agent
if err == nil || err == io.EOF {
seelog.Info("ACS Websocket connection closed for a valid reason")
} else {
seelog.Errorf("Error: lost websocket connection with Agent Communication Service (ACS): %v", err)
}
return err
}
}
return client.Serve(acsSession.ctx)
}

func (acsSession *session) computeReconnectDelay(isInactiveInstance bool) time.Duration {
Expand Down Expand Up @@ -438,48 +390,24 @@ func (acsSession *session) heartbeatJitter() time.Duration {
return acsSession._heartbeatJitter
}

// createACSClient creates the ACS Client using the specified URL
func (acsResources *acsSessionResources) createACSClient(url string, cfg *config.Config) wsclient.ClientServer {
return acsclient.New(url, cfg, acsResources.credentialsProvider, wsRWTimeout)
}

// connectedToACS records a successful connection to ACS
// It sets sendCredentials to false on such an event
func (acsResources *acsSessionResources) connectedToACS() {
acsResources.sendCredentials = false
}

// getSendCredentialsURLParameter gets the value to be set for the
// 'sendCredentials' URL parameter
func (acsResources *acsSessionResources) getSendCredentialsURLParameter() string {
return strconv.FormatBool(acsResources.sendCredentials)
}

func newSessionResources(credentialsProvider *credentials.Credentials) sessionResources {
return &acsSessionResources{
credentialsProvider: credentialsProvider,
sendCredentials: true,
}
}

// acsWsURL returns the websocket url for ACS given the endpoint
func acsWsURL(endpoint, cluster, containerInstanceArn string, taskEngine engine.TaskEngine, acsSessionState sessionState) string {
// acsURL returns the websocket url for ACS given the endpoint
func (acsSession *session) acsURL(endpoint string) string {
acsURL := endpoint
if endpoint[len(endpoint)-1] != '/' {
acsURL += "/"
}
acsURL += "ws"
query := url.Values{}
query.Set("clusterArn", cluster)
query.Set("containerInstanceArn", containerInstanceArn)
query.Set("clusterArn", acsSession.agentConfig.Cluster)
query.Set("containerInstanceArn", acsSession.containerInstanceARN)
query.Set("agentHash", version.GitHashString())
query.Set("agentVersion", version.Version)
query.Set("seqNum", "1")
query.Set("protocolVersion", strconv.Itoa(acsProtocolVersion))
if dockerVersion, err := taskEngine.Version(); err == nil {
if dockerVersion, err := acsSession.taskEngine.Version(); err == nil {
query.Set("dockerVersion", "DockerVersion: "+dockerVersion)
}
query.Set(sendCredentialsURLParameterName, acsSessionState.getSendCredentialsURLParameter())
query.Set(sendCredentialsURLParameterName, strconv.FormatBool(acsSession.sendCredentials))
return acsURL + "?" + query.Encode()
}

Expand Down
Loading