diff --git a/api/types/installers/agentless-installer.sh.tmpl b/api/types/installers/agentless-installer.sh.tmpl index ea7228e5ef924..f30e7bb465905 100644 --- a/api/types/installers/agentless-installer.sh.tmpl +++ b/api/types/installers/agentless-installer.sh.tmpl @@ -4,15 +4,84 @@ set -o errexit set -o pipefail set -o nounset -( - flock -n 9 || exit 1 - if grep -q "Section created by 'teleport join openssh'" "$SSHD_CONFIG"; then - exit 0 +run_teleport() { + TOKEN="$1" + PRINCIPALS="$2" + LABELS="$3" + ADDRESS="$4" + + sudo /usr/local/bin/teleport join openssh \ + --openssh-config="${SSHD_CONFIG}" \ + --join-method=iam \ + --token="$TOKEN" \ + --proxy-server="{{ .PublicProxyAddr }}" \ + --additional-principals="$PRINCIPALS" \ + --labels="$LABELS" \ + --address="$ADDRESS":22 \ + --restart-sshd +} + +get_metadata_item() { + IMDS_TOKEN="$1" + ENDPOINT="$2" + + curl -m5 -sS -H "X-aws-ec2-metadata-token: ${IMDS_TOKEN}" "http://169.254.169.254/latest/meta-data/$ENDPOINT" +} + +get_principals() { + IMDS_TOKEN="$1" + + LOCAL_IP="$(get_metadata_item "$IMDS_TOKEN" local-ipv4)" + PUBLIC_IP="$(get_metadata_item "$IMDS_TOKEN" public-ipv4 || echo "")" + + PRINCIPALS="" + if [ ! "$LOCAL_IP" = "" ]; then + PRINCIPALS="$LOCAL_IP,$PRINCIPALS" + fi + if [ ! "$PUBLIC_IP" = "" ]; then + PRINCIPALS="$PUBLIC_IP,$PRINCIPALS" fi + echo "$PRINCIPALS" +} + +get_address() { + IMDS_TOKEN="$1" + + PUBLIC_IP=$(get_metadata_item "$IMDS_TOKEN" public-ipv4 || echo "") + if [ ! "$PUBLIC_IP" = "" ]; then + echo "$PUBLIC_IP" + return 0 + fi + + LOCAL_IP="$(get_metadata_item "$IMDS_TOKEN" local-ipv4)" + if [ ! 
"$LOCAL_IP" = "" ]; then + echo "$LOCAL_IP" + return 0 + fi + + echo "Failed to retreive an IP address to connect to, which is a required parameter" + return 1 +} + +get_labels() { + IMDS_TOKEN="$1" + + INSTANCE_INFO=$(curl -m5 -sS -H "X-aws-ec2-metadata-token: ${IMDS_TOKEN}" http://169.254.169.254/latest/dynamic/instance-identity/document) + + ACCOUNT_ID="$(echo "$INSTANCE_INFO" | jq -r .accountId)" + INSTANCE_ID="$(echo "$INSTANCE_INFO" | jq -r .instanceId)" + REGION="$(echo "$INSTANCE_INFO" | jq -r .region)" + + LABELS="teleport.dev/instance-id=${INSTANCE_ID},teleport.dev/account-id=${ACCOUNT_ID},teleport.dev/aws-region=${REGION}" + + echo "$LABELS" +} + +install_teleport() { . /etc/os-release - PACKAGE_LIST="{{ .TeleportPackage }}" + PACKAGE_LIST="jq {{ .TeleportPackage }}" if [[ "{{ .AutomaticUpgrades }}" == "true" ]]; then PACKAGE_LIST="${PACKAGE_LIST} {{ .TeleportPackage }}-updater" fi @@ -42,25 +111,21 @@ set -o nounset echo "Unsupported distro: $ID" exit 1 fi +} - IMDS_TOKEN=$(curl -m5 -sS -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 300") - LOCAL_IP=$(curl -m5 -sS -H "X-aws-ec2-metadata-token: ${IMDS_TOKEN}" http://169.254.169.254/latest/meta-data/local-ipv4) - PUBLIC_IP=$(curl -m5 -sS -H "X-aws-ec2-metadata-token: ${IMDS_TOKEN}" http://169.254.169.254/latest/meta-data/public-ipv4 || echo "") +( + flock -n 9 || exit 1 - PRINCIPALS="" - if [ ! "$LOCAL_IP" = "" ]; then - PRINCIPALS="$LOCAL_IP,$PRINCIPALS" - fi - if [ ! "$PUBLIC_IP" = "" ]; then - PRINCIPALS="$PUBLIC_IP,$PRINCIPALS" + TOKEN="$1" + + if ! 
test -f /usr/local/bin/teleport; then + install_teleport fi - sudo /usr/bin/teleport join openssh \ - --openssh-config="${SSHD_CONFIG}" \ - --join-method=iam \ - --token="$1" \ - --proxy-server="{{ .PublicProxyAddr }}" \ - --additional-principals="$PRINCIPALS" \ - --restart-sshd + IMDS_TOKEN=$(curl -m5 -sS -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 300") + PRINCIPALS="$(get_principals "$IMDS_TOKEN")" + LABELS="$(get_labels "$IMDS_TOKEN")" + ADDRESS="$(get_address "$IMDS_TOKEN")" + run_teleport "$TOKEN" "$PRINCIPALS" "$LABELS" "$ADDRESS" ) 9>/var/lock/teleport_install.lock diff --git a/lib/service/discovery.go b/lib/service/discovery.go index ccce5d93c6232..b95ccc6d015e6 100644 --- a/lib/service/discovery.go +++ b/lib/service/discovery.go @@ -63,6 +63,7 @@ func (process *TeleportProcess) initDiscoveryService() error { Emitter: asyncEmitter, AccessPoint: accessPoint, Log: process.log, + ClusterName: conn.ClientIdentity.ClusterName, }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index bf58b64788e54..1610ff95ab93b 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -26,7 +26,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ssm" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -72,6 +71,8 @@ type Config struct { // for all discovery services. If different agents are used to discover different // sets of cloud resources, this field must be different for each set of agents. DiscoveryGroup string + // ClusterName is the name of the Teleport cluster. 
+ ClusterName string } func (c *Config) CheckAndSetDefaults() error { @@ -124,6 +125,8 @@ type Server struct { kubeFetchers []common.Fetcher // databaseFetchers holds all database fetchers. databaseFetchers []common.Fetcher + // caRotationCh receives nodes that need to have their CAs rotated. + caRotationCh chan []types.Server } // New initializes a discovery Server @@ -169,10 +172,12 @@ func (s *Server) initAWSWatchers(matchers []services.AWSMatcher) error { // start ec2 watchers var err error if len(ec2Matchers) > 0 { - s.ec2Watcher, err = server.NewEC2Watcher(s.ctx, ec2Matchers, s.Clients) + s.caRotationCh = make(chan []types.Server) + s.ec2Watcher, err = server.NewEC2Watcher(s.ctx, ec2Matchers, s.Clients, s.caRotationCh) if err != nil { return trace.Wrap(err) } + s.ec2Installer = server.NewSSMInstaller(server.SSMInstallerConfig{ Emitter: s.Emitter, }) @@ -330,13 +335,13 @@ func (s *Server) filterExistingEC2Nodes(instances *server.EC2Instances) { return accountOK && instanceOK }) - var filtered []*ec2.Instance + var filtered []server.EC2Instance outer: for _, inst := range instances.Instances { for _, node := range nodes { match := types.MatchLabels(node, map[string]string{ types.AWSAccountIDLabel: instances.AccountID, - types.AWSInstanceIDLabel: aws.StringValue(inst.InstanceId), + types.AWSInstanceIDLabel: inst.InstanceID, }) if match { continue outer @@ -347,9 +352,9 @@ outer: instances.Instances = filtered } -func genEC2InstancesLogStr(instances []*ec2.Instance) string { - return genInstancesLogStr(instances, func(i *ec2.Instance) string { - return aws.StringValue(i.InstanceId) +func genEC2InstancesLogStr(instances []server.EC2Instance) string { + return genInstancesLogStr(instances, func(i server.EC2Instance) string { + return i.InstanceID }) } @@ -376,15 +381,17 @@ func genInstancesLogStr[T any](instances []T, getID func(T) string) string { } func (s *Server) handleEC2Instances(instances *server.EC2Instances) error { - // TODO(amk): once agentless node 
inventory management is - // implemented, create nodes after a successful SSM run - // TODO(gavin): support assume_role_arn for ec2. ec2Client, err := s.Clients.GetAWSSSMClient(s.ctx, instances.Region) if err != nil { return trace.Wrap(err) } - s.filterExistingEC2Nodes(instances) + // instances.Rotation is true whenever the instances received need + // to be rotated; in that case we don't filter out existing OpenSSH nodes, as + // they all need to have the command run on them + if !instances.Rotation { + s.filterExistingEC2Nodes(instances) + } if len(instances.Instances) == 0 { return trace.NotFound("all fetched nodes already enrolled") } @@ -403,6 +410,87 @@ func (s *Server) handleEC2Instances(instances *server.EC2Instances) error { return trace.Wrap(s.ec2Installer.Run(s.ctx, req)) } +func (s *Server) logHandleInstancesErr(err error) { + var aErr awserr.Error + if errors.As(err, &aErr) && aErr.Code() == ssm.ErrCodeInvalidInstanceId { + s.Log.WithError(err).Error("SSM SendCommand failed with ErrCodeInvalidInstanceId. Make sure that the instances have AmazonSSMManagedInstanceCore policy assigned. Also check that SSM agent is running and registered with the SSM endpoint on that instance and try restarting or reinstalling it in case of issues. 
See https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_SendCommand.html#API_SendCommand_Errors for more details.") + } else if trace.IsNotFound(err) { + s.Log.Debug("All discovered EC2 instances are already part of the cluster.") + } else { + s.Log.WithError(err).Error("Failed to enroll discovered EC2 instances.") + } +} + +func (s *Server) watchCARotation(ctx context.Context) { + ticker := time.NewTicker(time.Minute * 10) + defer ticker.Stop() + for { + select { + case <-ticker.C: + nodes, err := s.findUnrotatedEC2Nodes(ctx) + if err != nil { + if trace.IsNotFound(err) { + s.Log.Debug("No OpenSSH nodes require CA rotation") + continue + } + s.Log.Errorf("Error finding OpenSSH nodes requiring CA rotation: %s", err) + continue + } + s.Log.Debugf("Found %d nodes requiring rotation", len(nodes)) + s.caRotationCh <- nodes + case <-s.ctx.Done(): + return + } + } +} + +func (s *Server) getMostRecentRotationForCAs(ctx context.Context, caTypes ...types.CertAuthType) (time.Time, error) { + var mostRecentUpdate time.Time + for _, caType := range caTypes { + ca, err := s.AccessPoint.GetCertAuthority(ctx, types.CertAuthID{ + Type: caType, + DomainName: s.ClusterName, + }, false) + if err != nil { + return time.Time{}, trace.Wrap(err) + } + caRot := ca.GetRotation() + if caRot.State == types.RotationStateInProgress && caRot.Started.After(mostRecentUpdate) { + mostRecentUpdate = caRot.Started + } + + if caRot.LastRotated.After(mostRecentUpdate) { + mostRecentUpdate = caRot.LastRotated + } + } + return mostRecentUpdate, nil +} + +func (s *Server) findUnrotatedEC2Nodes(ctx context.Context) ([]types.Server, error) { + mostRecentCertRotation, err := s.getMostRecentRotationForCAs(ctx, types.OpenSSHCA, types.HostCA) + if err != nil { + return nil, trace.Wrap(err) + } + found := s.nodeWatcher.GetNodes(ctx, func(n services.Node) bool { + if n.GetSubKind() != types.SubKindOpenSSHNode { + return false + } + if _, ok := n.GetLabel(types.AWSAccountIDLabel); !ok { + return 
false + } + if _, ok := n.GetLabel(types.AWSInstanceIDLabel); !ok { + return false + } + + return mostRecentCertRotation.After(n.GetRotation().LastRotated) + }) + + if len(found) == 0 { + return nil, trace.NotFound("no unrotated nodes found") + } + return found, nil +} + func (s *Server) handleEC2Discovery() { if err := s.nodeWatcher.WaitInitialization(); err != nil { s.Log.WithError(err).Error("Failed to initialize nodeWatcher.") @@ -410,6 +498,8 @@ func (s *Server) handleEC2Discovery() { } go s.ec2Watcher.Run() + go s.watchCARotation(s.ctx) + for { select { case instances := <-s.ec2Watcher.InstancesC: @@ -418,14 +508,7 @@ func (s *Server) handleEC2Discovery() { instances.AccountID, genEC2InstancesLogStr(ec2Instances.Instances)) if err := s.handleEC2Instances(ec2Instances); err != nil { - var aErr awserr.Error - if errors.As(err, &aErr) && aErr.Code() == ssm.ErrCodeInvalidInstanceId { - s.Log.WithError(err).Error("SSM SendCommand failed with ErrCodeInvalidInstanceId. Make sure that the instances have AmazonSSMManagedInstanceCore policy assigned. Also check that SSM agent is running and registered with the SSM endpoint on that instance and try restarting or reinstalling it in case of issues. 
See https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_SendCommand.html#API_SendCommand_Errors for more details.") - } else if trace.IsNotFound(err) { - s.Log.Debug("All discovered EC2 instances are already part of the cluster.") - } else { - s.Log.WithError(err).Error("Failed to enroll discovered EC2 instances.") - } + s.logHandleInstancesErr(err) } case <-s.ctx.Done(): s.ec2Watcher.Stop() diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 5b422f31b9d26..4dd60c41f87b2 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -61,6 +61,7 @@ import ( "github.com/gravitational/teleport/lib/cloud/mocks" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/server" ) type mockSSMClient struct { @@ -303,7 +304,7 @@ func TestDiscoveryServer(t *testing.T) { logHandler: func(t *testing.T, logs io.Reader, done chan struct{}) { scanner := bufio.NewScanner(logs) instances := genEC2Instances(58) - findAll := []string{genEC2InstancesLogStr(instances[:50]), genEC2InstancesLogStr(instances[50:])} + findAll := []string{genEC2InstancesLogStr(server.ToEC2Instances(instances[:50])), genEC2InstancesLogStr(server.ToEC2Instances(instances[50:]))} index := 0 for scanner.Scan() { if index == len(findAll) { @@ -368,6 +369,9 @@ func TestDiscoveryServer(t *testing.T) { Regions: []string{"eu-central-1"}, Tags: map[string]utils.Strings{"teleport": {"yes"}}, SSM: &services.AWSSSM{DocumentName: "document"}, + Params: services.InstallerParams{ + InstallTeleport: true, + }, }}, Emitter: tc.emitter, Log: logger, diff --git a/lib/srv/server/azure_watcher.go b/lib/srv/server/azure_watcher.go index 272b2efc461f4..4671ab818c65b 100644 --- a/lib/srv/server/azure_watcher.go +++ b/lib/srv/server/azure_watcher.go @@ -111,8 +111,12 @@ func newAzureInstanceFetcher(cfg azureFetcherConfig) *azureInstanceFetcher { } } 
+func (*azureInstanceFetcher) GetMatchingInstances(_ []types.Server, _ bool) ([]Instances, error) { + return nil, trace.NotImplemented("not implemented for azure fetchers") +} + // GetInstances fetches all Azure virtual machines matching configured filters. -func (f *azureInstanceFetcher) GetInstances(ctx context.Context) ([]Instances, error) { +func (f *azureInstanceFetcher) GetInstances(ctx context.Context, _ bool) ([]Instances, error) { instancesByRegion := make(map[string][]*armcompute.VirtualMachine) for _, region := range f.Regions { instancesByRegion[region] = []*armcompute.VirtualMachine{} diff --git a/lib/srv/server/ec2_watcher.go b/lib/srv/server/ec2_watcher.go index c7ea92bdde473..ce299a564a1d0 100644 --- a/lib/srv/server/ec2_watcher.go +++ b/lib/srv/server/ec2_watcher.go @@ -18,6 +18,7 @@ package server import ( "context" + "sync" "time" "github.com/aws/aws-sdk-go/aws" @@ -52,18 +53,45 @@ type EC2Instances struct { // AccountID is the AWS account the instances belong to. AccountID string // Instances is a list of discovered EC2 instances - Instances []*ec2.Instance + Instances []EC2Instance + // Rotation is set so instances don't get filtered out for already + // existing in the Teleport cluster + Rotation bool +} + +// EC2Instance represents an AWS EC2 instance that has been +// discovered. +type EC2Instance struct { + InstanceID string +} + +func toEC2Instance(inst *ec2.Instance) EC2Instance { + return EC2Instance{ + InstanceID: aws.StringValue(inst.InstanceId), + } +} + +// ToEC2Instances converts aws []*ec2.Instance to []EC2Instance +func ToEC2Instances(insts []*ec2.Instance) []EC2Instance { + var ec2Insts []EC2Instance + + for _, inst := range insts { + ec2Insts = append(ec2Insts, toEC2Instance(inst)) + } + return ec2Insts + } // NewEC2Watcher creates a new EC2 watcher instance. 
-func NewEC2Watcher(ctx context.Context, matchers []services.AWSMatcher, clients cloud.Clients) (*Watcher, error) { +func NewEC2Watcher(ctx context.Context, matchers []services.AWSMatcher, clients cloud.Clients, missedRotation <-chan []types.Server) (*Watcher, error) { cancelCtx, cancelFn := context.WithCancel(ctx) watcher := Watcher{ - fetchers: []Fetcher{}, - ctx: cancelCtx, - cancel: cancelFn, - fetchInterval: time.Minute, - InstancesC: make(chan Instances), + fetchers: []Fetcher{}, + ctx: cancelCtx, + cancel: cancelFn, + fetchInterval: time.Minute, + InstancesC: make(chan Instances), + missedRotation: missedRotation, } for _, matcher := range matchers { @@ -104,8 +132,54 @@ type ec2InstanceFetcher struct { Region string DocumentName string Parameters map[string]string + + // cachedInstances keeps all of the ec2 instances that were matched + // in the last run of GetInstances for use as a cache with + // GetMatchingInstances + cachedInstances *instancesCache +} + +type instancesCache struct { + sync.Mutex + instances map[cachedInstanceKey]struct{} +} + +func (ic *instancesCache) add(accountID, instanceID string) { + ic.Lock() + defer ic.Unlock() + ic.instances[cachedInstanceKey{accountID: accountID, instanceID: instanceID}] = struct{}{} +} + +func (ic *instancesCache) clear() { + ic.Lock() + defer ic.Unlock() + ic.instances = make(map[cachedInstanceKey]struct{}) +} + +func (ic *instancesCache) exists(accountID, instanceID string) bool { + ic.Lock() + defer ic.Unlock() + _, ok := ic.instances[cachedInstanceKey{accountID: accountID, instanceID: instanceID}] + return ok +} + +type cachedInstanceKey struct { + accountID string + instanceID string } +const ( + // ParamToken is the name of the invite token parameter sent in the SSM Document + ParamToken = "token" + // ParamScriptName is the name of the Teleport install script sent in the SSM Document + ParamScriptName = "scriptName" + // ParamSSHDConfigPath is the path to the OpenSSH config file sent in the SSM 
Document + ParamSSHDConfigPath = "sshdConfigPath" +) + +// awsEC2APIChunkSize is the max number of instances SSM will send commands to at a time +const awsEC2APIChunkSize = 50 + func newEC2InstanceFetcher(cfg ec2FetcherConfig) *ec2InstanceFetcher { tagFilters := []*ec2.Filter{{ Name: aws.String(AWSInstanceStateName), @@ -125,14 +199,14 @@ func newEC2InstanceFetcher(cfg ec2FetcherConfig) *ec2InstanceFetcher { var parameters map[string]string if cfg.Matcher.Params.InstallTeleport { parameters = map[string]string{ - "token": cfg.Matcher.Params.JoinToken, - "scriptName": cfg.Matcher.Params.ScriptName, + ParamToken: cfg.Matcher.Params.JoinToken, + ParamScriptName: cfg.Matcher.Params.ScriptName, } } else { parameters = map[string]string{ - "token": cfg.Matcher.Params.JoinToken, - "scriptName": cfg.Matcher.Params.ScriptName, - "sshdConfigPath": cfg.Matcher.Params.SSHDConfig, + ParamToken: cfg.Matcher.Params.JoinToken, + ParamScriptName: cfg.Matcher.Params.ScriptName, + ParamSSHDConfigPath: cfg.Matcher.Params.SSHDConfig, } } @@ -142,30 +216,102 @@ func newEC2InstanceFetcher(cfg ec2FetcherConfig) *ec2InstanceFetcher { Region: cfg.Region, DocumentName: cfg.Document, Parameters: parameters, + cachedInstances: &instancesCache{ + instances: map[cachedInstanceKey]struct{}{}, + }, } return &fetcherConfig } +// GetMatchingInstances returns a list of EC2 instances from a list of matching Teleport nodes +func (f *ec2InstanceFetcher) GetMatchingInstances(nodes []types.Server, rotation bool) ([]Instances, error) { + insts := EC2Instances{ + Region: f.Region, + DocumentName: f.DocumentName, + Parameters: f.Parameters, + Rotation: rotation, + } + for _, node := range nodes { + if node.GetSubKind() != types.SubKindOpenSSHNode { + continue + } + region, ok := node.GetLabel(types.AWSInstanceRegion) + if !ok || region != f.Region { + continue + } + instID, ok := node.GetLabel(types.AWSInstanceIDLabel) + if !ok { + continue + } + accountID, ok := node.GetLabel(types.AWSAccountIDLabel) + if 
!ok { + continue + } + + if !f.cachedInstances.exists(accountID, instID) { + continue + } + if insts.AccountID == "" { + insts.AccountID = accountID + } + + insts.Instances = append(insts.Instances, EC2Instance{ + InstanceID: instID, + }) + } + + if len(insts.Instances) == 0 { + return nil, trace.NotFound("no ec2 instances found") + } + + return chunkInstances(insts), nil +} + +func chunkInstances(insts EC2Instances) []Instances { + var instColl []Instances + for i := 0; i < len(insts.Instances); i += awsEC2APIChunkSize { + end := i + awsEC2APIChunkSize + if end > len(insts.Instances) { + end = len(insts.Instances) + } + inst := EC2Instances{ + AccountID: insts.AccountID, + Region: insts.Region, + DocumentName: insts.DocumentName, + Parameters: insts.Parameters, + Instances: insts.Instances[i:end], + Rotation: insts.Rotation, + } + instColl = append(instColl, Instances{EC2Instances: &inst}) + } + return instColl +} + // GetInstances fetches all EC2 instances matching configured filters. 
-func (f *ec2InstanceFetcher) GetInstances(ctx context.Context) ([]Instances, error) { +func (f *ec2InstanceFetcher) GetInstances(ctx context.Context, rotation bool) ([]Instances, error) { var instances []Instances + f.cachedInstances.clear() err := f.EC2.DescribeInstancesPagesWithContext(ctx, &ec2.DescribeInstancesInput{ Filters: f.Filters, }, func(dio *ec2.DescribeInstancesOutput, b bool) bool { - const chunkSize = 50 // max number of instances SSM will send commands to at a time for _, res := range dio.Reservations { - for i := 0; i < len(res.Instances); i += chunkSize { - end := i + chunkSize + for i := 0; i < len(res.Instances); i += awsEC2APIChunkSize { + end := i + awsEC2APIChunkSize if end > len(res.Instances) { end = len(res.Instances) } + ownerID := aws.StringValue(res.OwnerId) inst := EC2Instances{ - AccountID: aws.StringValue(res.OwnerId), + AccountID: ownerID, Region: f.Region, DocumentName: f.DocumentName, - Instances: res.Instances[i:end], + Instances: ToEC2Instances(res.Instances[i:end]), Parameters: f.Parameters, + Rotation: rotation, + } + for _, ec2inst := range res.Instances[i:end] { + f.cachedInstances.add(ownerID, aws.StringValue(ec2inst.InstanceId)) } instances = append(instances, Instances{EC2Instances: &inst}) } diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go index 22b075ca6ff8e..83a61c51efcc2 100644 --- a/lib/srv/server/ec2_watcher_test.go +++ b/lib/srv/server/ec2_watcher_test.go @@ -216,7 +216,7 @@ func TestEC2Watcher(t *testing.T) { }}, } clients.ec2Client.output = &output - watcher, err := NewEC2Watcher(ctx, matchers, &clients) + watcher, err := NewEC2Watcher(ctx, matchers, &clients, make(<-chan []types.Server)) require.NoError(t, err) go watcher.Run() @@ -224,13 +224,13 @@ func TestEC2Watcher(t *testing.T) { result := <-watcher.InstancesC require.Equal(t, EC2Instances{ Region: "us-west-2", - Instances: []*ec2.Instance{&present}, + Instances: []EC2Instance{toEC2Instance(&present)}, Parameters: 
map[string]string{"token": "", "scriptName": ""}, }, *result.EC2Instances) result = <-watcher.InstancesC require.Equal(t, EC2Instances{ Region: "us-west-2", - Instances: []*ec2.Instance{&presentOther}, + Instances: []EC2Instance{toEC2Instance(&presentOther)}, Parameters: map[string]string{"token": "", "scriptName": ""}, }, *result.EC2Instances) } diff --git a/lib/srv/server/ssm_install.go b/lib/srv/server/ssm_install.go index f817059bce537..9028d8d4685dd 100644 --- a/lib/srv/server/ssm_install.go +++ b/lib/srv/server/ssm_install.go @@ -23,7 +23,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/gravitational/trace" @@ -53,7 +52,7 @@ type SSMRunRequest struct { SSM ssmiface.SSMAPI // Instances is the list of instances that will have the SSM // document executed on them. - Instances []*ec2.Instance + Instances []EC2Instance // Params is a list of parameters to include when executing the // SSM document. 
Params map[string]string @@ -75,7 +74,7 @@ func NewSSMInstaller(cfg SSMInstallerConfig) *SSMInstaller { func (si *SSMInstaller) Run(ctx context.Context, req SSMRunRequest) error { ids := make([]string, 0, len(req.Instances)) for _, inst := range req.Instances { - ids = append(ids, aws.StringValue(inst.InstanceId)) + ids = append(ids, inst.InstanceID) } params := make(map[string][]*string) diff --git a/lib/srv/server/ssm_install_test.go b/lib/srv/server/ssm_install_test.go index b1c274aaba910..75ed30d0cb299 100644 --- a/lib/srv/server/ssm_install_test.go +++ b/lib/srv/server/ssm_install_test.go @@ -23,7 +23,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/stretchr/testify/require" @@ -74,8 +73,8 @@ func TestSSMInstaller(t *testing.T) { { name: "ssm run was successful", req: SSMRunRequest{ - Instances: []*ec2.Instance{ - {InstanceId: aws.String("instance-id-1")}, + Instances: []EC2Instance{ + {InstanceID: "instance-id-1"}, }, DocumentName: document, Params: map[string]string{"token": "abcdefg"}, @@ -115,8 +114,8 @@ func TestSSMInstaller(t *testing.T) { name: "ssm run failed", req: SSMRunRequest{ DocumentName: document, - Instances: []*ec2.Instance{ - {InstanceId: aws.String("instance-id-1")}, + Instances: []EC2Instance{ + {InstanceID: "instance-id-1"}, }, Params: map[string]string{"token": "abcdefg"}, SSM: &mockSSMClient{ diff --git a/lib/srv/server/watcher.go b/lib/srv/server/watcher.go index 1630d61f247be..819e2a5faaf96 100644 --- a/lib/srv/server/watcher.go +++ b/lib/srv/server/watcher.go @@ -22,6 +22,8 @@ import ( "github.com/gravitational/trace" log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/types" ) // Instances contains information about discovered cloud instances from any provider. 
@@ -33,13 +35,17 @@ type Instances struct { // Fetcher fetches instances from a particular cloud provider. type Fetcher interface { // GetInstances gets a list of cloud instances. - GetInstances(context.Context) ([]Instances, error) + GetInstances(ctx context.Context, rotation bool) ([]Instances, error) + // GetMatchingInstances finds Instances from the list of nodes + // that the fetcher matches. + GetMatchingInstances(nodes []types.Server, rotation bool) ([]Instances, error) } // Watcher allows callers to discover cloud instances matching specified filters. type Watcher struct { // InstancesC can be used to consume newly discovered instances. - InstancesC chan Instances + InstancesC chan Instances + missedRotation <-chan []types.Server fetchers []Fetcher fetchInterval time.Duration @@ -47,30 +53,44 @@ type Watcher struct { cancel context.CancelFunc } +func (w *Watcher) sendInstancesOrLogError(instancesColl []Instances, err error) { + if err != nil { + if trace.IsNotFound(err) { + return + } + log.WithError(err).Error("Failed to fetch instances") + return + } + for _, inst := range instancesColl { + select { + case w.InstancesC <- inst: + case <-w.ctx.Done(): + } + } +} + // Run starts the watcher's main watch loop. 
func (w *Watcher) Run() { + if len(w.fetchers) == 0 { + return + } ticker := time.NewTicker(w.fetchInterval) defer ticker.Stop() + + for _, fetcher := range w.fetchers { + w.sendInstancesOrLogError(fetcher.GetInstances(w.ctx, false)) + } + for { - for _, fetcher := range w.fetchers { - instancesColl, err := fetcher.GetInstances(w.ctx) - if err != nil { - if trace.IsNotFound(err) { - continue - } - log.WithError(err).Error("Failed to fetch instances") - continue - } - for _, inst := range instancesColl { - select { - case w.InstancesC <- inst: - case <-w.ctx.Done(): - } - } - } select { + case insts := <-w.missedRotation: + for _, fetcher := range w.fetchers { + w.sendInstancesOrLogError(fetcher.GetMatchingInstances(insts, true)) + } case <-ticker.C: - continue + for _, fetcher := range w.fetchers { + w.sendInstancesOrLogError(fetcher.GetInstances(w.ctx, false)) + } case <-w.ctx.Done(): return } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 968a93b200a7c..2c25771a26703 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -2496,7 +2496,7 @@ echo AutomaticUpgrades: {{ .AutomaticUpgrades }} require.Contains(t, responseString, "stable/cloud") require.NotContains(t, responseString, "stable/v") require.Contains(t, responseString, ""+ - " PACKAGE_LIST=\"teleport-ent\"\n"+ + " PACKAGE_LIST=\"jq teleport-ent\"\n"+ " if [[ \"true\" == \"true\" ]]; then\n"+ " PACKAGE_LIST=\"${PACKAGE_LIST} teleport-ent-updater\"\n"+ " fi\n",