From ef8820957bd05ea35df4e90ca9462712607a848f Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Wed, 6 Sep 2023 21:05:13 -0700 Subject: [PATCH 01/20] Add support for ecs agent auto updates --- lib/auth/integration/integrationv1/awsoidc.go | 79 ++++- lib/automaticupgrades/config.go | 8 + lib/automaticupgrades/version.go | 70 ++++- lib/automaticupgrades/version_test.go | 121 +++++++- lib/cloud/aws/policy_statements.go | 3 +- lib/integrations/awsoidc/deployservice.go | 23 +- .../awsoidc/deployservice_update.go | 284 ++++++++++++++++++ .../awsoidc/deployservice_update_test.go | 123 ++++++++ lib/service/awsoidc.go | 245 +++++++++++++++ lib/service/awsoidc_test.go | 67 +++++ lib/service/service.go | 5 + .../upgradewindow/upgradewindow.go | 3 + lib/web/join_tokens.go | 2 +- 13 files changed, 1006 insertions(+), 27 deletions(-) create mode 100644 lib/integrations/awsoidc/deployservice_update.go create mode 100644 lib/integrations/awsoidc/deployservice_update_test.go create mode 100644 lib/service/awsoidc.go create mode 100644 lib/service/awsoidc_test.go diff --git a/lib/auth/integration/integrationv1/awsoidc.go b/lib/auth/integration/integrationv1/awsoidc.go index 28c6a0827b8a1..56309df2f393a 100644 --- a/lib/auth/integration/integrationv1/awsoidc.go +++ b/lib/auth/integration/integrationv1/awsoidc.go @@ -21,6 +21,7 @@ import ( "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" integrationpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1" "github.com/gravitational/teleport/api/types" @@ -29,6 +30,9 @@ import ( "github.com/gravitational/teleport/lib/services" ) +// defaultTokenTTL is the default TTL for AWS OIDC tokens +const defaultTokenTTL = time.Minute + // GenerateAWSOIDCToken generates a token to be used when executing an AWS OIDC Integration action. func (s *Service) GenerateAWSOIDCToken(ctx context.Context, req *integrationpb.GenerateAWSOIDCTokenRequest) (*integrationpb.GenerateAWSOIDCTokenResponse, error) { _, err := authz.AuthorizeWithVerbs(ctx, s.logger, s.authorizer, true, types.KindIntegration, types.VerbUse) @@ -41,42 +45,89 @@ func (s *Service) GenerateAWSOIDCToken(ctx context.Context, req *integrationpb.G return nil, trace.Wrap(err) } - clusterName, err := s.caGetter.GetDomainName() + token, err := GenerateAWSOIDCToken(ctx, AWSOIDCTokenConfig{ + CAGetter: s.caGetter, + Clock: s.clock, + Issuer: req.Issuer, + Username: username, + }) if err != nil { return nil, trace.Wrap(err) } - ca, err := s.caGetter.GetCertAuthority(ctx, types.CertAuthID{ + return &integrationpb.GenerateAWSOIDCTokenResponse{ + Token: token, + }, nil +} + +// AWSOIDCTokenConfig contains configuration to be used when generating an AWS OIDC token. +type AWSOIDCTokenConfig struct { + CAGetter CAGetter + Clock clockwork.Clock + // TTL is the time to live for the token + TTL time.Duration + // Issuer is the issuer of the token. + Issuer string + // Username is the Teleport identity. + Username string +} + +func (c *AWSOIDCTokenConfig) checkAndSetDefaults() error { + if c.CAGetter == nil { + return trace.BadParameter("ca getter is required") + } + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + if c.TTL == 0 { + c.TTL = defaultTokenTTL + } + if c.Issuer == "" { + return trace.BadParameter("issuer is required") + } + return nil +} + +// GenerateToken generates a token to be used when executing an AWS OIDC Integration action. +func GenerateAWSOIDCToken(ctx context.Context, config AWSOIDCTokenConfig) (string, error) { + if err := config.checkAndSetDefaults(); err != nil { + return "", trace.Wrap(err) + } + + clusterName, err := config.CAGetter.GetDomainName() + if err != nil { + return "", trace.Wrap(err) + } + + ca, err := config.CAGetter.GetCertAuthority(ctx, types.CertAuthID{ Type: types.OIDCIdPCA, DomainName: clusterName, }, true) if err != nil { - return nil, trace.Wrap(err) + return "", trace.Wrap(err) } // Extract the JWT signing key and sign the claims. - signer, err := s.caGetter.GetKeyStore().GetJWTSigner(ctx, ca) + signer, err := config.CAGetter.GetKeyStore().GetJWTSigner(ctx, ca) if err != nil { - return nil, trace.Wrap(err) + return "", trace.Wrap(err) } - privateKey, err := services.GetJWTSigner(signer, ca.GetClusterName(), s.clock) + privateKey, err := services.GetJWTSigner(signer, ca.GetClusterName(), config.Clock) if err != nil { - return nil, trace.Wrap(err) + return "", trace.Wrap(err) } token, err := privateKey.SignAWSOIDC(jwt.SignParams{ - Username: username, + Username: config.Username, Audience: types.IntegrationAWSOIDCAudience, Subject: types.IntegrationAWSOIDCSubject, - Issuer: req.Issuer, - Expires: s.clock.Now().Add(time.Minute), + Issuer: config.Issuer, + Expires: config.Clock.Now().Add(config.TTL), }) if err != nil { - return nil, trace.Wrap(err) + return "", trace.Wrap(err) } - return &integrationpb.GenerateAWSOIDCTokenResponse{ - Token: token, - }, nil + return token, nil } diff --git a/lib/automaticupgrades/config.go b/lib/automaticupgrades/config.go index 1333df6935283..21460562ea4a5 100644 --- a/lib/automaticupgrades/config.go +++ b/lib/automaticupgrades/config.go @@ -26,6 +26,9 @@ import ( const ( // automaticUpgradesEnvar defines the env var to lookup when deciding whether to enable AutomaticUpgrades feature. automaticUpgradesEnvar = "TELEPORT_AUTOMATIC_UPGRADES" + + // automaticUpgradesChannelEnvar defines a customer automatic upgrades version release channel. + automaticUpgradesChannelEnvar = "TELEPORT_AUTOMATIC_UPGRADES_CHANNEL" ) // IsEnabled reads the TELEPORT_AUTOMATIC_UPGRADES and returns whether Automatic Upgrades are enabled or disabled. @@ -46,3 +49,8 @@ func IsEnabled() bool { return automaticUpgrades } + +// GetChannel returns the TELEPORT_AUTOMATIC_UPGRADES_CHANNEL value. +func GetChannel() string { + return os.Getenv(automaticUpgradesChannelEnvar) +} diff --git a/lib/automaticupgrades/version.go b/lib/automaticupgrades/version.go index 607d32ef9bbd9..5949f3e7e65d1 100644 --- a/lib/automaticupgrades/version.go +++ b/lib/automaticupgrades/version.go @@ -34,22 +34,21 @@ const ( // stableCloudVersionPath is the URL path that returns the current stable/cloud version. stableCloudVersionPath = "/v1/stable/cloud/version" + + // stableCloudCriticalPath is the URL path that returns the stable/cloud critical flag. + stableCloudCriticalPath = "/v1/stable/cloud/critical" ) // Version returns the version that should be used for installing Teleport Services // This is used when installing agents using scripts. // Even when Teleport Auth/Proxy is using vX, the agents must always respect this version. -func Version(ctx context.Context, baseURL string) (string, error) { - if baseURL == "" { - baseURL = stableCloudVersionBaseURL - } - - fullURL, err := url.JoinPath(baseURL, stableCloudVersionPath) +func Version(ctx context.Context, versionURL string) (string, error) { + versionURL, err := getVersionURL(versionURL) if err != nil { return "", trace.Wrap(err) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL, nil) if err != nil { return "", trace.Wrap(err) } @@ -73,3 +72,60 @@ func Version(ctx context.Context, baseURL string) (string, error) { return versionString, trace.Wrap(err) } + +// getVersionURL returns the versionURL or the default stable/cloud version url. +func getVersionURL(versionURL string) (string, error) { + if versionURL != "" { + return versionURL, nil + } + cloudStableVersionURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudVersionPath) + if err != nil { + return "", trace.Wrap(err) + } + return cloudStableVersionURL, nil +} + +// Critical returns true if a critical upgrade is available. +func Critical(ctx context.Context, criticalURL string) (bool, error) { + criticalURL, err := getCriticalURL(criticalURL) + if err != nil { + return false, trace.Wrap(err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, criticalURL, nil) + if err != nil { + return false, trace.Wrap(err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false, trace.Wrap(err) + } + defer resp.Body.Close() + + body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) + if err != nil { + return false, trace.Wrap(err) + } + + if resp.StatusCode != http.StatusOK { + return false, trace.BadParameter("invalid status code %d, body: %s", resp.StatusCode, string(body)) + } + + critical := strings.TrimSpace(string(body)) + + // The critical endpoint returns either the string "yes" or "no" + return critical == "yes", nil +} + +// getCriticalURL returns the criticalURL or the default stable/cloud critical url. +func getCriticalURL(criticalURL string) (string, error) { + if criticalURL != "" { + return criticalURL, nil + } + cloudStableCriticalURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudCriticalPath) + if err != nil { + return "", trace.Wrap(err) + } + return cloudStableCriticalURL, nil +} diff --git a/lib/automaticupgrades/version_test.go b/lib/automaticupgrades/version_test.go index e540dbc1cb34b..4ba8d299201e3 100644 --- a/lib/automaticupgrades/version_test.go +++ b/lib/automaticupgrades/version_test.go @@ -20,6 +20,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/gravitational/trace" @@ -68,19 +69,133 @@ func TestVersion(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.URL.Path, "/v1/stable/cloud/version") + assert.Equal(t, "/v1/stable/cloud/version", r.URL.Path) w.WriteHeader(tt.mockStatusCode) w.Write([]byte(tt.mockResponseString)) })) defer httpTestServer.Close() - v, err := Version(ctx, httpTestServer.URL) + versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version") + require.NoError(t, err) + + v, err := Version(ctx, versionURL) tt.errCheck(t, err) if err != nil { return } - require.Equal(t, v, tt.expectedVersion) + require.Equal(t, tt.expectedVersion, v) + }) + } +} + +func TestCritical(t *testing.T) { + ctx := context.Background() + + isBadParameterErr := func(tt require.TestingT, err error, i ...any) { + require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) + } + + for _, tt := range []struct { + name string + mockStatusCode int + mockResponseString string + errCheck require.ErrorAssertionFunc + expectedCritical bool + }{ + { + name: "critical available", + mockStatusCode: http.StatusOK, + mockResponseString: "yes\n", + errCheck: require.NoError, + expectedCritical: true, + }, + { + name: "critical is not available", + mockStatusCode: http.StatusOK, + mockResponseString: "no\n", + errCheck: require.NoError, + expectedCritical: false, + }, + { + name: "invalid status code (500)", + mockStatusCode: http.StatusInternalServerError, + errCheck: isBadParameterErr, + }, + { + name: "invalid status code (403)", + mockStatusCode: http.StatusForbidden, + errCheck: isBadParameterErr, + }, + } { + t.Run(tt.name, func(t *testing.T) { + httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/stable/cloud/critical", r.URL.Path) + w.WriteHeader(tt.mockStatusCode) + w.Write([]byte(tt.mockResponseString)) + })) + defer httpTestServer.Close() + + criticalURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/critical") + require.NoError(t, err) + + v, err := Critical(ctx, criticalURL) + tt.errCheck(t, err) + if err != nil { + return + } + + require.Equal(t, tt.expectedCritical, v) + }) + } +} + +func TestGetVersionURL(t *testing.T) { + for _, tt := range []struct { + name string + versionURL string + expectedURL string + }{ + { + name: "default stable/cloud version url", + versionURL: "", + expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/version", + }, + { + name: "custom version url", + versionURL: "https://custom.dev/version", + expectedURL: "https://custom.dev/version", + }, + } { + t.Run(tt.name, func(t *testing.T) { + v, err := getVersionURL(tt.versionURL) + require.NoError(t, err) + require.Equal(t, tt.expectedURL, v) + }) + } +} + +func TestGetCriticalURL(t *testing.T) { + for _, tt := range []struct { + name string + criticalURL string + expectedURL string + }{ + { + name: "default stable/cloud critical url", + criticalURL: "", + expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/critical", + }, + { + name: "custom critical url", + criticalURL: "https://custom.dev/critical", + expectedURL: "https://custom.dev/critical", + }, + } { + t.Run(tt.name, func(t *testing.T) { + v, err := getCriticalURL(tt.criticalURL) + require.NoError(t, err) + require.Equal(t, tt.expectedURL, v) }) } } diff --git a/lib/cloud/aws/policy_statements.go b/lib/cloud/aws/policy_statements.go index 74587bdc036b9..80becc659e487 100644 --- a/lib/cloud/aws/policy_statements.go +++ b/lib/cloud/aws/policy_statements.go @@ -48,7 +48,8 @@ func StatementForECSManageService() *Statement { Actions: []string{ "ecs:DescribeClusters", "ecs:CreateCluster", "ecs:PutClusterCapacityProviders", "ecs:DescribeServices", "ecs:CreateService", "ecs:UpdateService", - "ecs:RegisterTaskDefinition", + "ecs:RegisterTaskDefinition", "ecs:ListClusters", "ecs:ListServices", + "ecs:DescribeTaskDefinition", "ecs:DeregisterTaskDefinition", // EC2 DescribeSecurityGroups is required so that the user can list the SG and then pick which ones they want to apply to the ECS Service. "ec2:DescribeSecurityGroups", diff --git a/lib/integrations/awsoidc/deployservice.go b/lib/integrations/awsoidc/deployservice.go index 4d8e26a1894e3..47ace675266a1 100644 --- a/lib/integrations/awsoidc/deployservice.go +++ b/lib/integrations/awsoidc/deployservice.go @@ -267,6 +267,10 @@ type DeployServiceResponse struct { // DeployServiceClient describes the required methods to Deploy a Teleport Service. type DeployServiceClient interface { + // ListClusters lists ECS Clusters + // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.ListClusters + ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) + // DescribeClusters lists ECS Clusters. // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DescribeClusters DescribeClusters(ctx context.Context, params *ecs.DescribeClustersInput, optFns ...func(*ecs.Options)) (*ecs.DescribeClustersOutput, error) @@ -279,6 +283,10 @@ type DeployServiceClient interface { // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.PutClusterCapacityProviders PutClusterCapacityProviders(ctx context.Context, params *ecs.PutClusterCapacityProvidersInput, optFns ...func(*ecs.Options)) (*ecs.PutClusterCapacityProvidersOutput, error) + // ListServices returns a list of services + // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.ListServices + ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) + // DescribeServices lists the matching Services of a given Cluster. // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DescribeServices DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) @@ -291,10 +299,18 @@ type DeployServiceClient interface { // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.CreateService CreateService(ctx context.Context, params *ecs.CreateServiceInput, optFns ...func(*ecs.Options)) (*ecs.CreateServiceOutput, error) + // DescribeTaskDefinition describes the task definition. + // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DescribeTaskDefinition + DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) + // RegisterTaskDefinition registers a new task definition from the supplied family and containerDefinitions. // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.RegisterTaskDefinition RegisterTaskDefinition(ctx context.Context, params *ecs.RegisterTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.RegisterTaskDefinitionOutput, error) + // DeregisterTaskDefinition deregisters the task definition. + // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DeregisterTaskDefinition + DeregisterTaskDefinition(ctx context.Context, params *ecs.DeregisterTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DeregisterTaskDefinitionOutput, error) + // TokenService are the required methods to manage the IAM Join Token. // When the deployed service connects to the cluster, it will use the IAM Join method. // Before deploying the service, it must ensure that the token exists and has the appropriate token rul. @@ -456,11 +472,16 @@ func upsertTask(ctx context.Context, clt DeployServiceClient, req DeployServiceR Value: aws.String("true"), }}, Command: []string{ + // --rewrite 15:3 rewrites SIGTERM -> SIGQUIT. This enables graceful shutdown of teleport + "--rewrite", + "15:3", + "--", + "teleport", "start", "--config-string", configB64, }, - EntryPoint: []string{"teleport"}, + EntryPoint: []string{"/usr/bin/dumb-init"}, Image: &taskAgentContainerImage, Name: &taskAgentContainerName, LogConfiguration: &ecsTypes.LogConfiguration{ diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go new file mode 100644 index 0000000000000..960b030685de3 --- /dev/null +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -0,0 +1,284 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package awsoidc + +import ( + "context" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/trace" +) + +// waitDuration specifies the amount of time to wait for a service to become healthy after an update. +const waitDuration = time.Minute * 5 + +// AWSRegionsList is the list of available AWS regions +// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html +var AWSRegionsList = []string{ + "us-east-2", + "us-east-1", + "us-west-1", + "us-west-2", + "af-south-1", + "ap-east-1", + "ap-south-2", + "ap-southeast-3", + "ap-southeast-4", + "ap-south-1", + "ap-northeast-3", + "ap-northeast-2", + "ap-southeast-1", + "ap-southeast-2", + "ap-northeast-1", + "ca-central-1", + "eu-central-1", + "eu-west-1", + "eu-west-2", + "eu-south-1", + "eu-west-3", + "eu-south-2", + "eu-north-1", + "eu-central-2", + "me-south-1", + "me-central-1", + "sa-east-1", +} + +func listManagedClusters(ctx context.Context, clt DeployServiceClient, ownershipTags AWSTags) (clusterARNs []string, err error) { + listClustersOut, err := clt.ListClusters(ctx, &ecs.ListClustersInput{}) + if err != nil { + return nil, trace.Wrap(err) + } + + describeClustersOut, err := clt.DescribeClusters(ctx, &ecs.DescribeClustersInput{ + Clusters: listClustersOut.ClusterArns, + Include: []ecsTypes.ClusterField{ + ecsTypes.ClusterFieldTags, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + for _, cluster := range describeClustersOut.Clusters { + if ownershipTags.MatchesECSTags(cluster.Tags) { + clusterARNs = append(clusterARNs, *cluster.ClusterArn) + } + } + return clusterARNs, nil +} + +func getManagedService(ctx context.Context, clt DeployServiceClient, clusterARN string, ownershipTags AWSTags) (*ecsTypes.Service, error) { + listServicesOut, err := clt.ListServices(ctx, &ecs.ListServicesInput{ + Cluster: aws.String(clusterARN), + LaunchType: ecsTypes.LaunchTypeFargate, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + describeServicesOut, err := clt.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Cluster: aws.String(clusterARN), + Services: listServicesOut.ServiceArns, + Include: []ecsTypes.ServiceField{ecsTypes.ServiceFieldTags}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + if len(describeServicesOut.Services) != 1 { + return nil, trace.BadParameter("expected 1 service, but got %d", len(describeServicesOut.Services)) + } + service := describeServicesOut.Services[0] + + if !ownershipTags.MatchesECSTags(service.Tags) { + return nil, trace.Errorf("ECS Service %q already exists but is not managed by Teleport. "+ + "Add the following tags to allow Teleport to manage this service: %s", aws.ToString(service.ServiceName), ownershipTags) + } + // If the LaunchType is the required one, than we can update the current Service. + // Otherwise we have to delete it. + if service.LaunchType != ecsTypes.LaunchTypeFargate { + return nil, trace.Errorf("ECS Service %q already exists but has an invalid LaunchType %q. Delete the Service and try again.", aws.ToString(service.ServiceName), service.LaunchType) + } + + return &service, nil +} + +func getManagedTaskDefinition(ctx context.Context, clt DeployServiceClient, taskDefinitionName string, ownershipTags AWSTags) (*ecsTypes.TaskDefinition, error) { + describeTaskDefinitionOut, err := clt.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{ + TaskDefinition: aws.String(taskDefinitionName), + Include: []ecsTypes.TaskDefinitionField{ecsTypes.TaskDefinitionFieldTags}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + if !ownershipTags.MatchesECSTags(describeTaskDefinitionOut.Tags) { + return nil, trace.Errorf("ECS Task Definition %q already exists but is not managed by Teleport. "+ + "Add the following tags to allow Teleport to manage this task definition: %s", taskDefinitionName, ownershipTags) + } + return describeTaskDefinitionOut.TaskDefinition, nil +} + +func getTaskDefinitionTeleportImage(taskDefinition *ecsTypes.TaskDefinition) (string, error) { + if len(taskDefinition.ContainerDefinitions) != 1 { + return "", trace.BadParameter("expected 1 task container definition, but got %d", len(taskDefinition.ContainerDefinitions)) + } + return aws.ToString(taskDefinition.ContainerDefinitions[0].Image), nil +} + +// updateServiceOrRollback attempts to update the service with the specified task definition. +// The service will be rolled back if the service fails to become healthy. +func updateServiceOrRollback(ctx context.Context, clt DeployServiceClient, service *ecsTypes.Service, taskDefinition *ecsTypes.TaskDefinition) (*ecsTypes.Service, error) { + // Update service with new task definition + updateServiceOut, err := clt.UpdateService(ctx, generateServiceWithTaskDefinition(service, aws.ToString(taskDefinition.TaskDefinitionArn))) + if err != nil { + return nil, trace.Wrap(err) + } + + serviceStableWaiter := ecs.NewServicesStableWaiter(clt) + err = serviceStableWaiter.Wait(ctx, &ecs.DescribeServicesInput{ + Services: []string{aws.ToString(updateServiceOut.Service.ServiceName)}, + Cluster: updateServiceOut.Service.ClusterArn, + }, waitDuration) + if err == nil { + return updateServiceOut.Service, nil + } + + // If the service fails to reach a stable state within the allowed wait time, + // then rollback service with previous task definition + rollbackServiceOut, rollbackErr := clt.UpdateService(ctx, generateServiceWithTaskDefinition(service, aws.ToString(service.TaskDefinition))) + if rollbackErr != nil { + return nil, trace.Wrap(err, "failed to rollback service: %v", err) + } + + rollbackErr = serviceStableWaiter.Wait(ctx, &ecs.DescribeServicesInput{ + Services: []string{aws.ToString(rollbackServiceOut.Service.ServiceName)}, + Cluster: updateServiceOut.Service.ClusterArn, + }, waitDuration) + if rollbackErr != nil { + return nil, trace.Wrap(err, "failed to rollback service: %v", err) + } + + return nil, trace.Wrap(err) +} + +// generateTaskDefinitionWithImage returns new register task definition input with the desired teleport image +func generateTaskDefinitionWithImage(taskDefinition *ecsTypes.TaskDefinition, teleportImage string, tags []ecsTypes.Tag) (*ecs.RegisterTaskDefinitionInput, error) { + if len(taskDefinition.ContainerDefinitions) != 1 { + return nil, trace.BadParameter("expected 1 task container definition, but got %d", len(taskDefinition.ContainerDefinitions)) + } + + // Copy container definition and replace the teleport image with desired version + newContainerDefinition := new(ecsTypes.ContainerDefinition) + awsutil.Copy(newContainerDefinition, &taskDefinition.ContainerDefinitions[0]) + newContainerDefinition.Image = aws.String(teleportImage) + + // Copy task definition and replace container definitions + registerTaskDefinitionIn := new(ecs.RegisterTaskDefinitionInput) + awsutil.Copy(registerTaskDefinitionIn, taskDefinition) + registerTaskDefinitionIn.ContainerDefinitions = []ecsTypes.ContainerDefinition{*newContainerDefinition} + registerTaskDefinitionIn.Tags = tags + + return registerTaskDefinitionIn, nil +} + +// generateServiceWithTaskDefinition returns new update service input with the desired task definition +func generateServiceWithTaskDefinition(service *ecsTypes.Service, taskDefinitionName string) *ecs.UpdateServiceInput { + updateServiceIn := new(ecs.UpdateServiceInput) + awsutil.Copy(updateServiceIn, service) + updateServiceIn.Service = service.ServiceName + updateServiceIn.Cluster = service.ClusterArn + updateServiceIn.TaskDefinition = aws.String(taskDefinitionName) + return updateServiceIn +} + +// UpdateDeployServiceAgents updates the deploy service agents with the specified teleportVersionTag. +func UpdateDeployServiceAgents(ctx context.Context, clt DeployServiceClient, teleportVersionTag string, ownershipTags AWSTags) error { + teleportFlavor := teleportOSS + if modules.GetModules().BuildType() == modules.BuildEnterprise { + teleportFlavor = teleportEnt + } + teleportImage := fmt.Sprintf("public.ecr.aws/gravitational/%s-distroless:%s", teleportFlavor, teleportVersionTag) + + clusterARNs, err := listManagedClusters(ctx, clt, ownershipTags) + if err != nil { + return trace.Wrap(err) + } + + var errs []error + for _, clusterARN := range clusterARNs { + if err := updateDeployServiceAgent(ctx, clt, clusterARN, teleportImage, ownershipTags); err != nil { + errs = append(errs, err) + } + } + return trace.NewAggregate(errs...) +} + +func updateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, clusterARN, teleportImage string, ownershipTags AWSTags) error { + service, err := getManagedService(ctx, clt, clusterARN, ownershipTags) + if err != nil { + return trace.Wrap(err) + } + + taskDefinition, err := getManagedTaskDefinition(ctx, clt, aws.ToString(service.TaskDefinition), ownershipTags) + if err != nil { + return trace.Wrap(err) + } + + currentTeleportImage, err := getTaskDefinitionTeleportImage(taskDefinition) + if err != nil { + return trace.Wrap(err) + } + + if currentTeleportImage == teleportImage { + return nil + } + + registerTaskDefinitionIn, err := generateTaskDefinitionWithImage(taskDefinition, teleportImage, ownershipTags.ToECSTags()) + if err != nil { + return trace.Wrap(err) + } + + registerTaskDefinitionOut, err := clt.RegisterTaskDefinition(ctx, registerTaskDefinitionIn) + if err != nil { + return trace.Wrap(err) + } + + // Update service with new task definition + _, err = updateServiceOrRollback(ctx, clt, service, registerTaskDefinitionOut.TaskDefinition) + if err != nil { + // If update failed, then rollback task definition + _, rollbackErr := clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + TaskDefinition: registerTaskDefinitionOut.TaskDefinition.TaskDefinitionArn, + }) + if rollbackErr != nil { + return trace.Wrap(err, "failed to rollback task definition: %v", rollbackErr) + } + return trace.Wrap(err) + } + + // Attempt to deregister previous task definition but ignore error on failure + clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + TaskDefinition: taskDefinition.TaskDefinitionArn, + }) + return nil +} diff --git a/lib/integrations/awsoidc/deployservice_update_test.go b/lib/integrations/awsoidc/deployservice_update_test.go new file mode 100644 index 0000000000000..66a9372505ca5 --- /dev/null +++ b/lib/integrations/awsoidc/deployservice_update_test.go @@ -0,0 +1,123 @@ +package awsoidc + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/gravitational/teleport/api/types" + "github.com/stretchr/testify/require" +) + +func TestGenerateServiceWithTaskDefinition(t *testing.T) { + service := &ecsTypes.Service{ + ServiceName: aws.String("service"), + ClusterArn: aws.String("cluster"), + TaskDefinition: aws.String("task-definition-v1"), + NetworkConfiguration: &ecsTypes.NetworkConfiguration{ + AwsvpcConfiguration: &ecsTypes.AwsVpcConfiguration{ + AssignPublicIp: ecsTypes.AssignPublicIpEnabled, + Subnets: []string{"subnet"}, + }, + }, + PropagateTags: ecsTypes.PropagateTagsService, + } + + expected := &ecs.UpdateServiceInput{ + Service: aws.String("service"), + Cluster: aws.String("cluster"), + TaskDefinition: aws.String("task-definition-v2"), + NetworkConfiguration: &ecsTypes.NetworkConfiguration{ + AwsvpcConfiguration: &ecsTypes.AwsVpcConfiguration{ + AssignPublicIp: ecsTypes.AssignPublicIpEnabled, + Subnets: []string{"subnet"}, + }, + }, + PropagateTags: ecsTypes.PropagateTagsService, + } + + require.Equal(t, expected, generateServiceWithTaskDefinition(service, "task-definition-v2")) +} + +func TestGenerateTaskDefinitionWithImage(t *testing.T) { + taskDefinition := &ecsTypes.TaskDefinition{ + Family: aws.String("example-task"), + RequiresCompatibilities: []ecsTypes.Compatibility{ + ecsTypes.CompatibilityFargate, + }, + Cpu: &taskCPU, + Memory: &taskMem, + + NetworkMode: ecsTypes.NetworkModeAwsvpc, + TaskRoleArn: aws.String("task-role-arn"), + ExecutionRoleArn: aws.String("task-role-arn"), + ContainerDefinitions: []ecsTypes.ContainerDefinition{{ + Environment: []ecsTypes.KeyValuePair{{ + Name: aws.String(types.InstallMethodAWSOIDCDeployServiceEnvVar), + Value: aws.String("true"), + }}, + Command: []string{ + "start", + "--config-string", + "config-bytes", + }, + EntryPoint: []string{"teleport"}, + Image: aws.String("image-v1"), + Name: &taskAgentContainerName, + LogConfiguration: &ecsTypes.LogConfiguration{ + LogDriver: ecsTypes.LogDriverAwslogs, + Options: map[string]string{ + "awslogs-group": "ecs-cluster", + "awslogs-region": "us-west-2", + "awslogs-create-group": "true", + "awslogs-stream-prefix": "service/example-task", + }, + }, + }}, + } + tags := []ecsTypes.Tag{ + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + } + + expected := &ecs.RegisterTaskDefinitionInput{ + Family: aws.String("example-task"), + RequiresCompatibilities: []ecsTypes.Compatibility{ + ecsTypes.CompatibilityFargate, + }, + Cpu: &taskCPU, + Memory: &taskMem, + + NetworkMode: ecsTypes.NetworkModeAwsvpc, + TaskRoleArn: aws.String("task-role-arn"), + ExecutionRoleArn: aws.String("task-role-arn"), + ContainerDefinitions: []ecsTypes.ContainerDefinition{{ + Environment: []ecsTypes.KeyValuePair{{ + Name: aws.String(types.InstallMethodAWSOIDCDeployServiceEnvVar), + Value: aws.String("true"), + }}, + Command: []string{ + "start", + "--config-string", + "config-bytes", + }, + EntryPoint: []string{"teleport"}, + Image: aws.String("image-v2"), + Name: &taskAgentContainerName, + LogConfiguration: &ecsTypes.LogConfiguration{ + LogDriver: ecsTypes.LogDriverAwslogs, + Options: map[string]string{ + "awslogs-group": "ecs-cluster", + "awslogs-region": "us-west-2", + "awslogs-create-group": "true", + "awslogs-stream-prefix": "service/example-task", + }, + }, + }}, + Tags: tags, + } + + input, err := generateTaskDefinitionWithImage(taskDefinition, "image-v2", tags) + require.NoError(t, err) + require.Equal(t, expected, input) +} diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go new file mode 100644 index 0000000000000..380cdc9cfcd40 --- /dev/null +++ b/lib/service/awsoidc.go @@ -0,0 +1,245 @@ +package service + +import ( + "context" + "errors" + "net/url" + "strings" + "time" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/auth/integration/integrationv1" + "github.com/gravitational/teleport/lib/automaticupgrades" + "github.com/gravitational/teleport/lib/integrations/awsoidc" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/interval" + + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "golang.org/x/time/rate" +) + +const ( + // updateDeployAgentsInterval specifies how frequently to check for available updates. + updateDeployAgentsInterval = time.Minute * 30 + + // updateDeployAgentsRateLimit specifies the time between updates across AWS regions. + updateDeployAgentsRateLimit = time.Second * 30 +) + +func (process *TeleportProcess) periodUpdateDeployServiceAgents() error { + if !process.Config.Auth.Enabled { + return nil + } + + // start process only after teleport process has started + if _, err := process.WaitForEvent(process.GracefulExitContext(), TeleportReadyEvent); err != nil { + return nil + } + process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) + + // Acquire the semaphore before attempting to update the deploy service agents. + // This task should only run on a single instance at a time. + lock, err := services.AcquireSemaphoreWithRetry(process.GracefulExitContext(), + services.AcquireSemaphoreWithRetryConfig{ + Service: process.GetAuthServer(), + Request: types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: "update_deploy_service_agents", + MaxLeases: 1, + Expires: process.Clock.Now().Add(updateDeployAgentsInterval), + }, + Retry: retryutils.LinearConfig{ + Step: time.Minute, + Max: updateDeployAgentsInterval, + }, + }) + if err != nil { + return trace.Wrap(err) + } + defer process.GetAuthServer().CancelSemaphoreLease(process.GracefulExitContext(), *lock) + + periodic := interval.New(interval.Config{ + Duration: updateDeployAgentsInterval, + Jitter: retryutils.NewSeventhJitter(), + }) + defer periodic.Stop() + + for { + if err := process.updateDeployServiceAgents(process.GracefulExitContext(), process.GetAuthServer()); err != nil { + process.log.Warningf("Update failed: %v. Retrying in ~%v", err, updateDeployAgentsInterval) + } + + select { + case <-periodic.Next(): + case <-process.GracefulExitContext().Done(): + return nil + } + } +} + +func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, authServer *auth.Server) error { + if !process.shouldUpdateDeployAgents() { + return nil + } + + teleportVersion, err := process.getStableTeleportVersion() + if err != nil { + return trace.Wrap(err) + } + + issuer, err := awsoidc.IssuerForCluster(ctx, authServer) + if err != nil { + return trace.Wrap(err) + } + + token, err := integrationv1.GenerateAWSOIDCToken(ctx, integrationv1.AWSOIDCTokenConfig{ + CAGetter: authServer, + Clock: process.Clock, + TTL: updateDeployAgentsInterval, + Issuer: issuer, + }) + if err != nil { + return trace.Wrap(err) + } + + clusterNameConfig, err := authServer.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + + var resources []types.Integration + var nextKey string + for { + igs, nextKey, err := authServer.ListIntegrations(ctx, 0, nextKey) + if err != nil { + return trace.Wrap(err) + } + resources = append(resources, igs...) + if nextKey == "" { + break + } + } + + limit := rate.NewLimiter(rate.Every(updateDeployAgentsRateLimit), 1) + for _, ig := range resources { + spec := ig.GetAWSOIDCIntegrationSpec() + if spec == nil { + continue + } + + for _, region := range awsoidc.AWSRegionsList { + if err := limit.Wait(ctx); err != nil { + return trace.Wrap(err) + } + + req := &awsoidc.AWSClientRequest{ + IntegrationName: ig.GetName(), + Token: token, + RoleARN: spec.RoleARN, + Region: region, + } + + deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, authServer) + if err != nil { + process.log.Warningf("Failed to update deploy service agents: %v", err) + continue + } + + ownershipTags := map[string]string{ + types.ClusterLabel: clusterNameConfig.GetClusterName(), + types.OriginLabel: types.OriginIntegrationAWSOIDC, + types.IntegrationLabel: ig.GetName(), + } + + err = awsoidc.UpdateDeployServiceAgents(ctx, deployServiceClient, teleportVersion, ownershipTags) + invalidTokenError := new(ststypes.InvalidIdentityTokenException) + if errors.As(err, &invalidTokenError) { + process.log.Debugf("Invalid identity token for region %v: %v", region, err) + continue + } + if err != nil { + process.log.Warningf("Failed to update deploy service agents: %v", err) + continue + } + } + } + return nil +} + +// shouldUpdateDeployAgents returns true if deploy agents should be updated. +func (process *TeleportProcess) shouldUpdateDeployAgents() bool { + cmc, err := process.GetAuthServer().GetClusterMaintenanceConfig(process.GracefulExitContext()) + if err != nil { + process.log.Debugf("Failed to get cluster maintenance config: %v", err) + return false + } + + var criticalEndpoint string + if automaticupgrades.GetChannel() != "" { + criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") + if err != nil { + process.log.Debugf("Failed to get critical upgrade endpoint: %v", err) + return false + } + } + + critical, err := automaticupgrades.Critical(process.GracefulExitContext(), criticalEndpoint) + if err != nil { + process.log.Debugf("Failed to get critical upgrade value: %v", err) + return false + } + + if withinUpgradeWindow(cmc, process.Clock) || critical { + return true + } + + return false +} + +func (process *TeleportProcess) getStableTeleportVersion() (string, error) { + var versionEndpoint string + var err error + if automaticupgrades.GetChannel() != "" { + versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version") + if err != nil { + return "", trace.Wrap(err) + } + } + + stableVersion, err := automaticupgrades.Version(process.GracefulExitContext(), versionEndpoint) + if err != nil { + return "", trace.Wrap(err) + } + // cloudStableVersion has vX.Y.Z format, however the container image tag does not include the `v`. + return strings.TrimPrefix(stableVersion, "v"), nil +} + +// withinUpgradeWindow returns true if the current time is within the configured +// upgrade window. +func withinUpgradeWindow(cmc types.ClusterMaintenanceConfig, clock clockwork.Clock) bool { + upgradeWindow, ok := cmc.GetAgentUpgradeWindow() + if !ok { + return false + } + + now := clock.Now() + if len(upgradeWindow.Weekdays) == 0 { + if int(upgradeWindow.UTCStartHour) == now.Hour() { + return true + } + } + + weekday := now.Weekday().String() + for _, upgradeWeekday := range upgradeWindow.Weekdays { + if weekday == upgradeWeekday { + if int(upgradeWindow.UTCStartHour) == now.Hour() { + return true + } + } + } + return false +} diff --git a/lib/service/awsoidc_test.go b/lib/service/awsoidc_test.go new file mode 100644 index 0000000000000..8e960d1410c3e --- /dev/null +++ b/lib/service/awsoidc_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "testing" + "time" + + "github.com/gravitational/teleport/api/types" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +func TestWithinUpgradeWindow(t *testing.T) { + t.Parallel() + + tests := []struct { + desc string + upgradeWindow types.AgentUpgradeWindow + date string + withinWindow bool + }{ + { + desc: "within upgrade window", + upgradeWindow: types.AgentUpgradeWindow{ + UTCStartHour: 8, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: true, + }, + { + desc: "not within upgrade window", + upgradeWindow: types.AgentUpgradeWindow{ + UTCStartHour: 8, + }, + date: "Mon, 02 Jan 2006 09:04:05 UTC", + withinWindow: false, + }, + { + desc: "within upgrade window weekday", + upgradeWindow: types.AgentUpgradeWindow{ + UTCStartHour: 8, + Weekdays: []string{"Monday"}, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: true, + }, + { + desc: "not within upgrade window weekday", + upgradeWindow: types.AgentUpgradeWindow{ + UTCStartHour: 8, + Weekdays: []string{"Tuesday"}, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: false, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + cmc := types.NewClusterMaintenanceConfig() + cmc.SetAgentUpgradeWindow(tt.upgradeWindow) + + date, err := time.Parse(time.RFC1123, tt.date) + require.NoError(t, err) + require.Equal(t, tt.withinWindow, withinUpgradeWindow(cmc, clockwork.NewFakeClockAt(date))) + }) + } +} diff --git a/lib/service/service.go b/lib/service/service.go index d83b8d2f46be8..1631396583a52 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -79,6 +79,7 @@ import ( "github.com/gravitational/teleport/lib/auth/keygen" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/dynamo" "github.com/gravitational/teleport/lib/backend/etcdbk" @@ -1194,6 +1195,10 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { // at any time with dynamic configuration process.RegisterFunc("common.upload.init", process.initUploaderService) + if automaticupgrades.IsEnabled() { + process.RegisterFunc("update.deploy.agents.auth", process.periodUpdateDeployServiceAgents) + } + if !serviceStarted { return nil, trace.BadParameter("all services failed to start") } diff --git a/lib/versioncontrol/upgradewindow/upgradewindow.go b/lib/versioncontrol/upgradewindow/upgradewindow.go index 51a05a7cd5b93..b3f89162d4248 100644 --- a/lib/versioncontrol/upgradewindow/upgradewindow.go +++ b/lib/versioncontrol/upgradewindow/upgradewindow.go @@ -43,6 +43,9 @@ const ( // unitScheduleFile is the name of the file to which the unit schedule is exported. unitScheduleFile = "schedule" + // unitEndpointFile is the name of the file which specifies a custom version server endpoint + unitEndpointFile = "endpoint" + // unitConfigDir is the configuration directory of the teleport-upgrade unit. unitConfigDir = "/etc/teleport-upgrade.d" ) diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index 374c60f082a5f..bda0fa5963e2d 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -398,7 +398,7 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter // This ensures the initial installed version is the same as the `teleport-ent-updater` would install. if settings.installUpdater { repoChannel = stableCloudChannelRepo - cloudStableVersion, err := automaticupgrades.Version(ctx, settings.automaticUpgradesVersionBaseURL) + cloudStableVersion, err := automaticupgrades.Version(ctx, "") if err != nil { return "", trace.Wrap(err) } From d263297d476b7bf5c37dfbefe8aee34afc819f79 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 15 Sep 2023 13:05:53 -0700 Subject: [PATCH 02/20] fix unit test --- lib/web/join_tokens.go | 6 +++--- lib/web/join_tokens_test.go | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index bda0fa5963e2d..7402108e38e1f 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -80,9 +80,9 @@ type scriptSettings struct { databaseInstallMode bool installUpdater bool - // automaticUpgradesVersionBaseURL is the base URL for getting the version when using the cloud/stable channel. + // automaticUpgradesVersionURL is the URL for getting the version when using the cloud/stable channel. // Optional. - automaticUpgradesVersionBaseURL string + automaticUpgradesVersionURL string } // automaticUpgrades returns whether automaticUpgrades should be enabled. @@ -398,7 +398,7 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter // This ensures the initial installed version is the same as the `teleport-ent-updater` would install. if settings.installUpdater { repoChannel = stableCloudChannelRepo - cloudStableVersion, err := automaticupgrades.Version(ctx, "") + cloudStableVersion, err := automaticupgrades.Version(ctx, settings.automaticUpgradesVersionURL) if err != nil { return "", trace.Wrap(err) } diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go index 1c4f706a5bdce..eb2c88e403e98 100644 --- a/lib/web/join_tokens_test.go +++ b/lib/web/join_tokens_test.go @@ -22,6 +22,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "regexp" "testing" @@ -942,7 +943,10 @@ func TestJoinScript(t *testing.T) { })) defer httpTestServer.Close() - script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersionBaseURL: httpTestServer.URL}, m) + versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version") + require.NoError(t, err) + + script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersionURL: versionURL}, m) require.NoError(t, err) // list of packages must include the updater From 241e070715a37c28a3172d28f21eb7f8352d280f Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 15 Sep 2023 14:25:41 -0700 Subject: [PATCH 03/20] Remove unused var --- lib/versioncontrol/upgradewindow/upgradewindow.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/versioncontrol/upgradewindow/upgradewindow.go b/lib/versioncontrol/upgradewindow/upgradewindow.go index b3f89162d4248..51a05a7cd5b93 100644 --- a/lib/versioncontrol/upgradewindow/upgradewindow.go +++ b/lib/versioncontrol/upgradewindow/upgradewindow.go @@ -43,9 +43,6 @@ const ( // unitScheduleFile is the name of the file to which the unit schedule is exported. unitScheduleFile = "schedule" - // unitEndpointFile is the name of the file which specifies a custom version server endpoint - unitEndpointFile = "endpoint" - // unitConfigDir is the configuration directory of the teleport-upgrade unit. unitConfigDir = "/etc/teleport-upgrade.d" ) From c8d0ddc98310f88ded76606c3c70876a3fc48f54 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Wed, 20 Sep 2023 20:56:50 -0700 Subject: [PATCH 04/20] Addres feedback --- lib/auth/integration/integrationv1/awsoidc.go | 2 +- lib/automaticupgrades/config.go | 2 + lib/automaticupgrades/version.go | 72 +++++++++---------- .../awsoidc/deployservice_update.go | 35 +-------- .../awsoidc/deployservice_update_test.go | 19 ++++- lib/service/awsoidc.go | 35 +++++++-- lib/service/awsoidc_test.go | 19 ++++- 7 files changed, 104 insertions(+), 80 deletions(-) diff --git a/lib/auth/integration/integrationv1/awsoidc.go b/lib/auth/integration/integrationv1/awsoidc.go index 56309df2f393a..7bda64efd1d36 100644 --- a/lib/auth/integration/integrationv1/awsoidc.go +++ b/lib/auth/integration/integrationv1/awsoidc.go @@ -88,7 +88,7 @@ func (c *AWSOIDCTokenConfig) checkAndSetDefaults() error { return nil } -// GenerateToken generates a token to be used when executing an AWS OIDC Integration action. +// GenerateAWSOIDCToken generates a token to be used when executing an AWS OIDC Integration action. func GenerateAWSOIDCToken(ctx context.Context, config AWSOIDCTokenConfig) (string, error) { if err := config.checkAndSetDefaults(); err != nil { return "", trace.Wrap(err) diff --git a/lib/automaticupgrades/config.go b/lib/automaticupgrades/config.go index 21460562ea4a5..56d0f4a297c33 100644 --- a/lib/automaticupgrades/config.go +++ b/lib/automaticupgrades/config.go @@ -51,6 +51,8 @@ func IsEnabled() bool { } // GetChannel returns the TELEPORT_AUTOMATIC_UPGRADES_CHANNEL value. +// Example of an acceptable value for TELEPORT_AUTOMATIC_UPGRADES_CHANNEL is: +// https://updates.releases.teleport.dev/v1/stable/cloud func GetChannel() string { return os.Getenv(automaticUpgradesChannelEnvar) } diff --git a/lib/automaticupgrades/version.go b/lib/automaticupgrades/version.go index 5949f3e7e65d1..2cfa5b8510dba 100644 --- a/lib/automaticupgrades/version.go +++ b/lib/automaticupgrades/version.go @@ -48,7 +48,40 @@ func Version(ctx context.Context, versionURL string) (string, error) { return "", trace.Wrap(err) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL, nil) + resp, err := sendRequest(ctx, versionURL) + if err != nil { + return "", trace.Wrap(err) + } + + return resp, nil +} + +// Critical returns true if a critical upgrade is available. +func Critical(ctx context.Context, criticalURL string) (bool, error) { + criticalURL, err := getCriticalURL(criticalURL) + if err != nil { + return false, trace.Wrap(err) + } + + critical, err := sendRequest(ctx, criticalURL) + if err != nil { + return false, trace.Wrap(err) + } + + // Expectes critical endpoint to return either the string "yes" or "no" + switch critical { + case "yes": + return true, nil + case "no": + return false, nil + default: + return false, trace.BadParameter("critical endpoint returned an unexpected value: %v", critical) + } +} + +// sendRequest sends a GET request to the reqURL and returns the response value +func sendRequest(ctx context.Context, reqURL string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { return "", trace.Wrap(err) } @@ -68,9 +101,7 @@ func Version(ctx context.Context, versionURL string) (string, error) { return "", trace.BadParameter("invalid status code %d, body: %s", resp.StatusCode, string(body)) } - versionString := strings.TrimSpace(string(body)) - - return versionString, trace.Wrap(err) + return strings.TrimSpace(string(body)), trace.Wrap(err) } // getVersionURL returns the versionURL or the default stable/cloud version url. @@ -85,39 +116,6 @@ func getVersionURL(versionURL string) (string, error) { return cloudStableVersionURL, nil } -// Critical returns true if a critical upgrade is available. -func Critical(ctx context.Context, criticalURL string) (bool, error) { - criticalURL, err := getCriticalURL(criticalURL) - if err != nil { - return false, trace.Wrap(err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, criticalURL, nil) - if err != nil { - return false, trace.Wrap(err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, trace.Wrap(err) - } - defer resp.Body.Close() - - body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) - if err != nil { - return false, trace.Wrap(err) - } - - if resp.StatusCode != http.StatusOK { - return false, trace.BadParameter("invalid status code %d, body: %s", resp.StatusCode, string(body)) - } - - critical := strings.TrimSpace(string(body)) - - // The critical endpoint returns either the string "yes" or "no" - return critical == "yes", nil -} - // getCriticalURL returns the criticalURL or the default stable/cloud critical url. func getCriticalURL(criticalURL string) (string, error) { if criticalURL != "" { diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 960b030685de3..33b1c7d2c4b08 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -25,45 +25,14 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/aws/aws-sdk-go/aws/awsutil" - "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/modules" ) // waitDuration specifies the amount of time to wait for a service to become healthy after an update. const waitDuration = time.Minute * 5 -// AWSRegionsList is the list of available AWS regions -// https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html -var AWSRegionsList = []string{ - "us-east-2", - "us-east-1", - "us-west-1", - "us-west-2", - "af-south-1", - "ap-east-1", - "ap-south-2", - "ap-southeast-3", - "ap-southeast-4", - "ap-south-1", - "ap-northeast-3", - "ap-northeast-2", - "ap-southeast-1", - "ap-southeast-2", - "ap-northeast-1", - "ca-central-1", - "eu-central-1", - "eu-west-1", - "eu-west-2", - "eu-south-1", - "eu-west-3", - "eu-south-2", - "eu-north-1", - "eu-central-2", - "me-south-1", - "me-central-1", - "sa-east-1", -} - func listManagedClusters(ctx context.Context, clt DeployServiceClient, ownershipTags AWSTags) (clusterARNs []string, err error) { listClustersOut, err := clt.ListClusters(ctx, &ecs.ListClustersInput{}) if err != nil { diff --git a/lib/integrations/awsoidc/deployservice_update_test.go b/lib/integrations/awsoidc/deployservice_update_test.go index 66a9372505ca5..9ef8b7c1fb176 100644 --- a/lib/integrations/awsoidc/deployservice_update_test.go +++ b/lib/integrations/awsoidc/deployservice_update_test.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package awsoidc import ( @@ -6,8 +22,9 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecs" ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/gravitational/teleport/api/types" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" ) func TestGenerateServiceWithTaskDefinition(t *testing.T) { diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 380cdc9cfcd40..61cb01870a162 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -1,3 +1,19 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package service import ( @@ -7,6 +23,11 @@ import ( "strings" "time" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "golang.org/x/time/rate" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" @@ -14,12 +35,8 @@ import ( "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/aws" "github.com/gravitational/teleport/lib/utils/interval" - - ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "golang.org/x/time/rate" ) const ( @@ -60,7 +77,11 @@ func (process *TeleportProcess) periodUpdateDeployServiceAgents() error { if err != nil { return trace.Wrap(err) } - defer process.GetAuthServer().CancelSemaphoreLease(process.GracefulExitContext(), *lock) + defer func() { + if err := process.GetAuthServer().CancelSemaphoreLease(process.GracefulExitContext(), *lock); err != nil { + process.log.WithError(err).Errorf("Failed to cancel lease: %v.", lock) + } + }() periodic := interval.New(interval.Config{ Duration: updateDeployAgentsInterval, @@ -131,7 +152,7 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a continue } - for _, region := range awsoidc.AWSRegionsList { + for _, region := range aws.GetKnownRegions() { if err := limit.Wait(ctx); err != nil { return trace.Wrap(err) } diff --git a/lib/service/awsoidc_test.go b/lib/service/awsoidc_test.go index 8e960d1410c3e..9c0cb1eb89f38 100644 --- a/lib/service/awsoidc_test.go +++ b/lib/service/awsoidc_test.go @@ -1,12 +1,29 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package service import ( "testing" "time" - "github.com/gravitational/teleport/api/types" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" ) func TestWithinUpgradeWindow(t *testing.T) { From ece42b34cea72442294c245e53478a7c29c334d3 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Thu, 21 Sep 2023 14:40:25 -0700 Subject: [PATCH 05/20] Use list of available AWS database regions --- .../awsoidc/deployservice_update.go | 161 +++++++----------- lib/service/awsoidc.go | 32 +++- 2 files changed, 93 insertions(+), 100 deletions(-) diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 33b1c7d2c4b08..1b131fd80d393 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -30,45 +30,84 @@ import ( "github.com/gravitational/teleport/lib/modules" ) -// waitDuration specifies the amount of time to wait for a service to become healthy after an update. -const waitDuration = time.Minute * 5 +// UpdateDeployServiceAgents updates the deploy service agents with the specified teleportVersionTag. +func UpdateDeployServiceAgents(ctx context.Context, clt DeployServiceClient, teleportClusterName, teleportVersionTag string, ownershipTags AWSTags) error { + teleportFlavor := teleportOSS + if modules.GetModules().BuildType() == modules.BuildEnterprise { + teleportFlavor = teleportEnt + } + teleportImage := fmt.Sprintf("public.ecr.aws/gravitational/%s-distroless:%s", teleportFlavor, teleportVersionTag) -func listManagedClusters(ctx context.Context, clt DeployServiceClient, ownershipTags AWSTags) (clusterARNs []string, err error) { - listClustersOut, err := clt.ListClusters(ctx, &ecs.ListClustersInput{}) + if err := updateDeployServiceAgent(ctx, clt, teleportClusterName, teleportImage, ownershipTags); err != nil { + return trace.Wrap(err) + } + return nil +} + +func updateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, teleportClusterName, teleportImage string, ownershipTags AWSTags) error { + service, err := getManagedService(ctx, clt, teleportClusterName, ownershipTags) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } - describeClustersOut, err := clt.DescribeClusters(ctx, &ecs.DescribeClustersInput{ - Clusters: listClustersOut.ClusterArns, - Include: []ecsTypes.ClusterField{ - ecsTypes.ClusterFieldTags, - }, - }) + taskDefinition, err := getManagedTaskDefinition(ctx, clt, aws.ToString(service.TaskDefinition), ownershipTags) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) + } + + currentTeleportImage, err := getTaskDefinitionTeleportImage(taskDefinition) + if err != nil { + return trace.Wrap(err) + } + + if currentTeleportImage == teleportImage { + return nil + } + + registerTaskDefinitionIn, err := generateTaskDefinitionWithImage(taskDefinition, teleportImage, ownershipTags.ToECSTags()) + if err != nil { + return trace.Wrap(err) + } + + registerTaskDefinitionOut, err := clt.RegisterTaskDefinition(ctx, registerTaskDefinitionIn) + if err != nil { + return trace.Wrap(err) } - for _, cluster := range describeClustersOut.Clusters { - if ownershipTags.MatchesECSTags(cluster.Tags) { - clusterARNs = append(clusterARNs, *cluster.ClusterArn) + // Update service with new task definition + _, err = updateServiceOrRollback(ctx, clt, service, registerTaskDefinitionOut.TaskDefinition) + if err != nil { + // If update failed, then rollback task definition + _, rollbackErr := clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + TaskDefinition: registerTaskDefinitionOut.TaskDefinition.TaskDefinitionArn, + }) + if rollbackErr != nil { + return trace.Wrap(err, "failed to rollback task definition: %v", rollbackErr) } + return trace.Wrap(err) } - return clusterARNs, nil -} -func getManagedService(ctx context.Context, clt DeployServiceClient, clusterARN string, ownershipTags AWSTags) (*ecsTypes.Service, error) { - listServicesOut, err := clt.ListServices(ctx, &ecs.ListServicesInput{ - Cluster: aws.String(clusterARN), - LaunchType: ecsTypes.LaunchTypeFargate, + // Attempt to deregister previous task definition but ignore error on failure + clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + TaskDefinition: taskDefinition.TaskDefinitionArn, }) - if err != nil { - return nil, trace.Wrap(err) + return nil +} + +// waitDuration specifies the amount of time to wait for a service to become healthy after an update. +const waitDuration = time.Minute * 5 + +func getManagedService(ctx context.Context, clt DeployServiceClient, teleportClusterName string, ownershipTags AWSTags) (*ecsTypes.Service, error) { + ecsClusterName := fmt.Sprintf("%s-teleport", normalizeECSResourceName(teleportClusterName)) + + var ecsServiceNames []string + for _, deploymentMode := range DeploymentModes { + ecsServiceNames = append(ecsServiceNames, fmt.Sprintf("%s-%s", ecsClusterName, deploymentMode)) } describeServicesOut, err := clt.DescribeServices(ctx, &ecs.DescribeServicesInput{ - Cluster: aws.String(clusterARN), - Services: listServicesOut.ServiceArns, + Cluster: aws.String(ecsClusterName), + Services: ecsServiceNames, Include: []ecsTypes.ServiceField{ecsTypes.ServiceFieldTags}, }) if err != nil { @@ -179,75 +218,3 @@ func generateServiceWithTaskDefinition(service *ecsTypes.Service, taskDefinition updateServiceIn.TaskDefinition = aws.String(taskDefinitionName) return updateServiceIn } - -// UpdateDeployServiceAgents updates the deploy service agents with the specified teleportVersionTag. -func UpdateDeployServiceAgents(ctx context.Context, clt DeployServiceClient, teleportVersionTag string, ownershipTags AWSTags) error { - teleportFlavor := teleportOSS - if modules.GetModules().BuildType() == modules.BuildEnterprise { - teleportFlavor = teleportEnt - } - teleportImage := fmt.Sprintf("public.ecr.aws/gravitational/%s-distroless:%s", teleportFlavor, teleportVersionTag) - - clusterARNs, err := listManagedClusters(ctx, clt, ownershipTags) - if err != nil { - return trace.Wrap(err) - } - - var errs []error - for _, clusterARN := range clusterARNs { - if err := updateDeployServiceAgent(ctx, clt, clusterARN, teleportImage, ownershipTags); err != nil { - errs = append(errs, err) - } - } - return trace.NewAggregate(errs...) -} - -func updateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, clusterARN, teleportImage string, ownershipTags AWSTags) error { - service, err := getManagedService(ctx, clt, clusterARN, ownershipTags) - if err != nil { - return trace.Wrap(err) - } - - taskDefinition, err := getManagedTaskDefinition(ctx, clt, aws.ToString(service.TaskDefinition), ownershipTags) - if err != nil { - return trace.Wrap(err) - } - - currentTeleportImage, err := getTaskDefinitionTeleportImage(taskDefinition) - if err != nil { - return trace.Wrap(err) - } - - if currentTeleportImage == teleportImage { - return nil - } - - registerTaskDefinitionIn, err := generateTaskDefinitionWithImage(taskDefinition, teleportImage, ownershipTags.ToECSTags()) - if err != nil { - return trace.Wrap(err) - } - - registerTaskDefinitionOut, err := clt.RegisterTaskDefinition(ctx, registerTaskDefinitionIn) - if err != nil { - return trace.Wrap(err) - } - - // Update service with new task definition - _, err = updateServiceOrRollback(ctx, clt, service, registerTaskDefinitionOut.TaskDefinition) - if err != nil { - // If update failed, then rollback task definition - _, rollbackErr := clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ - TaskDefinition: registerTaskDefinitionOut.TaskDefinition.TaskDefinitionArn, - }) - if rollbackErr != nil { - return trace.Wrap(err, "failed to rollback task definition: %v", rollbackErr) - } - return trace.Wrap(err) - } - - // Attempt to deregister previous task definition but ignore error on failure - clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ - TaskDefinition: taskDefinition.TaskDefinitionArn, - }) - return nil -} diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 61cb01870a162..ff66c624c29f9 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -35,7 +35,6 @@ import ( "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/utils/aws" "github.com/gravitational/teleport/lib/utils/interval" ) @@ -145,6 +144,11 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a } } + awsRegions, err := process.listAWSDatabaseRegions() + if err != nil { + return trace.Wrap(err) + } + limit := rate.NewLimiter(rate.Every(updateDeployAgentsRateLimit), 1) for _, ig := range resources { spec := ig.GetAWSOIDCIntegrationSpec() @@ -152,7 +156,7 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a continue } - for _, region := range aws.GetKnownRegions() { + for _, region := range awsRegions { if err := limit.Wait(ctx); err != nil { return trace.Wrap(err) } @@ -176,7 +180,7 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a types.IntegrationLabel: ig.GetName(), } - err = awsoidc.UpdateDeployServiceAgents(ctx, deployServiceClient, teleportVersion, ownershipTags) + err = awsoidc.UpdateDeployServiceAgents(ctx, deployServiceClient, clusterNameConfig.GetClusterName(), teleportVersion, ownershipTags) invalidTokenError := new(ststypes.InvalidIdentityTokenException) if errors.As(err, &invalidTokenError) { process.log.Debugf("Invalid identity token for region %v: %v", region, err) @@ -191,6 +195,28 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a return nil } +// listAWSDatabaseRegions returns the list of AWS regions containing a connected database. +func (process *TeleportProcess) listAWSDatabaseRegions() ([]string, error) { + databases, err := process.GetAuthServer().GetDatabases(process.GracefulExitContext()) + if err != nil { + return nil, trace.Wrap(err) + } + + regions := make(map[string]interface{}) + for _, database := range databases { + if database.IsAWSHosted() && database.IsRDS() { + regions[database.GetAWS().Region] = nil + } + } + + var result []string + for region := range regions { + result = append(result, region) + } + + return result, nil +} + // shouldUpdateDeployAgents returns true if deploy agents should be updated. func (process *TeleportProcess) shouldUpdateDeployAgents() bool { cmc, err := process.GetAuthServer().GetClusterMaintenanceConfig(process.GracefulExitContext()) From b4690c2ac61bfc9df1a042deadc220ab9228c0b9 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 22 Sep 2023 14:54:43 -0700 Subject: [PATCH 06/20] Run update task on proxy instances --- lib/authz/permissions.go | 1 + lib/automaticupgrades/version.go | 2 +- lib/cloud/aws/policy_statements.go | 3 +- lib/integrations/awsoidc/deployservice.go | 8 - lib/service/awsoidc.go | 187 +++++++++------------- lib/service/service.go | 5 +- 6 files changed, 83 insertions(+), 123 deletions(-) diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index d72db19084e26..c2e93857ce6c9 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -616,6 +616,7 @@ func roleSpecForProxy(clusterName string) types.RoleSpecV6 { types.NewRule(types.KindDatabaseService, services.RO()), types.NewRule(types.KindSAMLIdPServiceProvider, services.RO()), types.NewRule(types.KindUserGroup, services.RO()), + types.NewRule(types.KindClusterMaintenanceConfig, services.RO()), types.NewRule(types.KindIntegration, append(services.RO(), types.VerbUse)), // this rule allows cloud proxies to read // plugins of `openai` type, since Assist uses the OpenAI API and runs in Proxy. diff --git a/lib/automaticupgrades/version.go b/lib/automaticupgrades/version.go index 2cfa5b8510dba..d1f6c393f1605 100644 --- a/lib/automaticupgrades/version.go +++ b/lib/automaticupgrades/version.go @@ -68,7 +68,7 @@ func Critical(ctx context.Context, criticalURL string) (bool, error) { return false, trace.Wrap(err) } - // Expectes critical endpoint to return either the string "yes" or "no" + // Expects critical endpoint to return either the string "yes" or "no" switch critical { case "yes": return true, nil diff --git a/lib/cloud/aws/policy_statements.go b/lib/cloud/aws/policy_statements.go index 931c5d1b9450b..25c7bd71848bd 100644 --- a/lib/cloud/aws/policy_statements.go +++ b/lib/cloud/aws/policy_statements.go @@ -50,8 +50,7 @@ func StatementForECSManageService() *Statement { Actions: []string{ "ecs:DescribeClusters", "ecs:CreateCluster", "ecs:PutClusterCapacityProviders", "ecs:DescribeServices", "ecs:CreateService", "ecs:UpdateService", - "ecs:RegisterTaskDefinition", "ecs:ListClusters", "ecs:ListServices", - "ecs:DescribeTaskDefinition", "ecs:DeregisterTaskDefinition", + "ecs:RegisterTaskDefinition", "ecs:DescribeTaskDefinition", "ecs:DeregisterTaskDefinition", // EC2 DescribeSecurityGroups is required so that the user can list the SG and then pick which ones they want to apply to the ECS Service. "ec2:DescribeSecurityGroups", diff --git a/lib/integrations/awsoidc/deployservice.go b/lib/integrations/awsoidc/deployservice.go index 47ace675266a1..4d385c48084a4 100644 --- a/lib/integrations/awsoidc/deployservice.go +++ b/lib/integrations/awsoidc/deployservice.go @@ -267,10 +267,6 @@ type DeployServiceResponse struct { // DeployServiceClient describes the required methods to Deploy a Teleport Service. type DeployServiceClient interface { - // ListClusters lists ECS Clusters - // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.ListClusters - ListClusters(ctx context.Context, params *ecs.ListClustersInput, optFns ...func(*ecs.Options)) (*ecs.ListClustersOutput, error) - // DescribeClusters lists ECS Clusters. // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DescribeClusters DescribeClusters(ctx context.Context, params *ecs.DescribeClustersInput, optFns ...func(*ecs.Options)) (*ecs.DescribeClustersOutput, error) @@ -283,10 +279,6 @@ type DeployServiceClient interface { // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.PutClusterCapacityProviders PutClusterCapacityProviders(ctx context.Context, params *ecs.PutClusterCapacityProvidersInput, optFns ...func(*ecs.Options)) (*ecs.PutClusterCapacityProvidersOutput, error) - // ListServices returns a list of services - // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.ListServices - ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) - // DescribeServices lists the matching Services of a given Cluster. // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DescribeServices DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index ff66c624c29f9..aee41b26b7e3a 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -18,23 +18,21 @@ package service import ( "context" - "errors" + "fmt" "net/url" "strings" "time" - ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "golang.org/x/time/rate" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/auth/integration/integrationv1" "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/integrations/awsoidc" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils/interval" ) @@ -47,7 +45,7 @@ const ( ) func (process *TeleportProcess) periodUpdateDeployServiceAgents() error { - if !process.Config.Auth.Enabled { + if !process.Config.Proxy.Enabled { return nil } @@ -57,30 +55,14 @@ func (process *TeleportProcess) periodUpdateDeployServiceAgents() error { } process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) - // Acquire the semaphore before attempting to update the deploy service agents. - // This task should only run on a single instance at a time. - lock, err := services.AcquireSemaphoreWithRetry(process.GracefulExitContext(), - services.AcquireSemaphoreWithRetryConfig{ - Service: process.GetAuthServer(), - Request: types.AcquireSemaphoreRequest{ - SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: "update_deploy_service_agents", - MaxLeases: 1, - Expires: process.Clock.Now().Add(updateDeployAgentsInterval), - }, - Retry: retryutils.LinearConfig{ - Step: time.Minute, - Max: updateDeployAgentsInterval, - }, - }) + resp, err := process.getInstanceClient().Ping(process.GracefulExitContext()) if err != nil { return trace.Wrap(err) } - defer func() { - if err := process.GetAuthServer().CancelSemaphoreLease(process.GracefulExitContext(), *lock); err != nil { - process.log.WithError(err).Errorf("Failed to cancel lease: %v.", lock) - } - }() + + if !resp.ServerFeatures.AutomaticUpgrades { + return nil + } periodic := interval.New(interval.Config{ Duration: updateDeployAgentsInterval, @@ -89,7 +71,7 @@ func (process *TeleportProcess) periodUpdateDeployServiceAgents() error { defer periodic.Stop() for { - if err := process.updateDeployServiceAgents(process.GracefulExitContext(), process.GetAuthServer()); err != nil { + if err := process.updateDeployServiceAgents(process.GracefulExitContext(), process.getInstanceClient()); err != nil { process.log.Warningf("Update failed: %v. Retrying in ~%v", err, updateDeployAgentsInterval) } @@ -101,40 +83,68 @@ func (process *TeleportProcess) periodUpdateDeployServiceAgents() error { } } -func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, authServer *auth.Server) error { - if !process.shouldUpdateDeployAgents() { +func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, authClient *auth.Client) error { + cmc, err := authClient.GetClusterMaintenanceConfig(ctx) + if err != nil { + return trace.Wrap(err) + } + + var criticalEndpoint string + if automaticupgrades.GetChannel() != "" { + criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") + if err != nil { + return trace.Wrap(err) + } + } + + critical, err := automaticupgrades.Critical(process.GracefulExitContext(), criticalEndpoint) + if err != nil { + return trace.Wrap(err) + } + + if !withinUpgradeWindow(cmc, process.Clock) && !critical { return nil } - teleportVersion, err := process.getStableTeleportVersion() + teleportVersion, err := getStableTeleportVersion(ctx) if err != nil { return trace.Wrap(err) } - issuer, err := awsoidc.IssuerForCluster(ctx, authServer) + issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) if err != nil { return trace.Wrap(err) } - token, err := integrationv1.GenerateAWSOIDCToken(ctx, integrationv1.AWSOIDCTokenConfig{ - CAGetter: authServer, - Clock: process.Clock, - TTL: updateDeployAgentsInterval, - Issuer: issuer, + token, err := authClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ + Issuer: issuer, }) if err != nil { return trace.Wrap(err) } - clusterNameConfig, err := authServer.GetClusterName() + clusterNameConfig, err := authClient.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + clusterName := clusterNameConfig.GetClusterName() + + databases, err := authClient.GetDatabases(ctx) if err != nil { return trace.Wrap(err) } + awsRegions := make(map[string]interface{}) + for _, database := range databases { + if database.IsAWSHosted() && database.IsRDS() { + awsRegions[database.GetAWS().Region] = nil + } + } + var resources []types.Integration var nextKey string for { - igs, nextKey, err := authServer.ListIntegrations(ctx, 0, nextKey) + igs, nextKey, err := authClient.ListIntegrations(ctx, 0, nextKey) if err != nil { return trace.Wrap(err) } @@ -144,19 +154,15 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a } } - awsRegions, err := process.listAWSDatabaseRegions() - if err != nil { - return trace.Wrap(err) - } - limit := rate.NewLimiter(rate.Every(updateDeployAgentsRateLimit), 1) for _, ig := range resources { spec := ig.GetAWSOIDCIntegrationSpec() if spec == nil { continue } + integrationName := ig.GetName() - for _, region := range awsRegions { + for region := range awsRegions { if err := limit.Wait(ctx); err != nil { return trace.Wrap(err) } @@ -168,86 +174,51 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a Region: region, } - deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, authServer) + deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, authClient) if err != nil { process.log.Warningf("Failed to update deploy service agents: %v", err) continue } ownershipTags := map[string]string{ - types.ClusterLabel: clusterNameConfig.GetClusterName(), + types.ClusterLabel: clusterName, types.OriginLabel: types.OriginIntegrationAWSOIDC, - types.IntegrationLabel: ig.GetName(), + types.IntegrationLabel: integrationName, } - err = awsoidc.UpdateDeployServiceAgents(ctx, deployServiceClient, clusterNameConfig.GetClusterName(), teleportVersion, ownershipTags) - invalidTokenError := new(ststypes.InvalidIdentityTokenException) - if errors.As(err, &invalidTokenError) { - process.log.Debugf("Invalid identity token for region %v: %v", region, err) - continue - } + // Acquire a lease for the region + integration before attempting to update the deploy service agent. + // If the lease cannot be acquired, the update is already being handled by another instance. + semLock, err := authClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: fmt.Sprintf("update_deploy_service_agents_%s_%s", region, integrationName), + MaxLeases: 1, + Expires: process.Clock.Now().Add(updateDeployAgentsInterval), + Holder: "update_deploy_service_agents", + }) + if err != nil { - process.log.Warningf("Failed to update deploy service agents: %v", err) - continue + if strings.Contains(err.Error(), teleport.MaxLeases) { + process.log.Debug("Deploy service agent update is already being processed") + continue + } + return trace.Wrap(err) } - } - } - return nil -} - -// listAWSDatabaseRegions returns the list of AWS regions containing a connected database. -func (process *TeleportProcess) listAWSDatabaseRegions() ([]string, error) { - databases, err := process.GetAuthServer().GetDatabases(process.GracefulExitContext()) - if err != nil { - return nil, trace.Wrap(err) - } - - regions := make(map[string]interface{}) - for _, database := range databases { - if database.IsAWSHosted() && database.IsRDS() { - regions[database.GetAWS().Region] = nil - } - } - - var result []string - for region := range regions { - result = append(result, region) - } - - return result, nil -} -// shouldUpdateDeployAgents returns true if deploy agents should be updated. -func (process *TeleportProcess) shouldUpdateDeployAgents() bool { - cmc, err := process.GetAuthServer().GetClusterMaintenanceConfig(process.GracefulExitContext()) - if err != nil { - process.log.Debugf("Failed to get cluster maintenance config: %v", err) - return false - } + if err := awsoidc.UpdateDeployServiceAgents(ctx, deployServiceClient, clusterNameConfig.GetClusterName(), teleportVersion, ownershipTags); err != nil { + process.log.Warningf("Failed to update deploy service agents: %v", err) - var criticalEndpoint string - if automaticupgrades.GetChannel() != "" { - criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") - if err != nil { - process.log.Debugf("Failed to get critical upgrade endpoint: %v", err) - return false + // Release the semaphore lease on failure so that another instance may attempt the update + if err := authClient.CancelSemaphoreLease(ctx, *semLock); err != nil { + process.log.WithError(err).Error("Failed to cancel semaphore lease") + } + } } } - - critical, err := automaticupgrades.Critical(process.GracefulExitContext(), criticalEndpoint) - if err != nil { - process.log.Debugf("Failed to get critical upgrade value: %v", err) - return false - } - - if withinUpgradeWindow(cmc, process.Clock) || critical { - return true - } - - return false + return nil } -func (process *TeleportProcess) getStableTeleportVersion() (string, error) { +// getStableTeleportVersion returns the current stable version of teleport +func getStableTeleportVersion(ctx context.Context) (string, error) { var versionEndpoint string var err error if automaticupgrades.GetChannel() != "" { @@ -257,7 +228,7 @@ func (process *TeleportProcess) getStableTeleportVersion() (string, error) { } } - stableVersion, err := automaticupgrades.Version(process.GracefulExitContext(), versionEndpoint) + stableVersion, err := automaticupgrades.Version(ctx, versionEndpoint) if err != nil { return "", trace.Wrap(err) } diff --git a/lib/service/service.go b/lib/service/service.go index d69373ac06071..8e01a67068425 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -79,7 +79,6 @@ import ( "github.com/gravitational/teleport/lib/auth/keygen" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/authz" - "github.com/gravitational/teleport/lib/automaticupgrades" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/dynamo" "github.com/gravitational/teleport/lib/backend/etcdbk" @@ -1195,9 +1194,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { // at any time with dynamic configuration process.RegisterFunc("common.upload.init", process.initUploaderService) - if automaticupgrades.IsEnabled() { - process.RegisterFunc("update.deploy.agents.auth", process.periodUpdateDeployServiceAgents) - } + process.RegisterFunc("update.deploy.agents", process.periodUpdateDeployServiceAgents) if !serviceStarted { return nil, trace.BadParameter("all services failed to start") From d5c57c98fe643a9ea2ed9b5cbb7e68c20d17de55 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 22 Sep 2023 15:18:32 -0700 Subject: [PATCH 07/20] Revert GenerateAWSOIDCToken --- lib/auth/integration/integrationv1/awsoidc.go | 79 ++++--------------- lib/service/awsoidc.go | 14 ++-- 2 files changed, 21 insertions(+), 72 deletions(-) diff --git a/lib/auth/integration/integrationv1/awsoidc.go b/lib/auth/integration/integrationv1/awsoidc.go index 7bda64efd1d36..28c6a0827b8a1 100644 --- a/lib/auth/integration/integrationv1/awsoidc.go +++ b/lib/auth/integration/integrationv1/awsoidc.go @@ -21,7 +21,6 @@ import ( "time" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" integrationpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1" "github.com/gravitational/teleport/api/types" @@ -30,9 +29,6 @@ import ( "github.com/gravitational/teleport/lib/services" ) -// defaultTokenTTL is the default TTL for AWS OIDC tokens -const defaultTokenTTL = time.Minute - // GenerateAWSOIDCToken generates a token to be used when executing an AWS OIDC Integration action. func (s *Service) GenerateAWSOIDCToken(ctx context.Context, req *integrationpb.GenerateAWSOIDCTokenRequest) (*integrationpb.GenerateAWSOIDCTokenResponse, error) { _, err := authz.AuthorizeWithVerbs(ctx, s.logger, s.authorizer, true, types.KindIntegration, types.VerbUse) @@ -45,89 +41,42 @@ func (s *Service) GenerateAWSOIDCToken(ctx context.Context, req *integrationpb.G return nil, trace.Wrap(err) } - token, err := GenerateAWSOIDCToken(ctx, AWSOIDCTokenConfig{ - CAGetter: s.caGetter, - Clock: s.clock, - Issuer: req.Issuer, - Username: username, - }) + clusterName, err := s.caGetter.GetDomainName() if err != nil { return nil, trace.Wrap(err) } - return &integrationpb.GenerateAWSOIDCTokenResponse{ - Token: token, - }, nil -} - -// AWSOIDCTokenConfig contains configuration to be used when generating an AWS OIDC token. -type AWSOIDCTokenConfig struct { - CAGetter CAGetter - Clock clockwork.Clock - // TTL is the time to live for the token - TTL time.Duration - // Issuer is the issuer of the token. - Issuer string - // Username is the Teleport identity. - Username string -} - -func (c *AWSOIDCTokenConfig) checkAndSetDefaults() error { - if c.CAGetter == nil { - return trace.BadParameter("ca getter is required") - } - if c.Clock == nil { - c.Clock = clockwork.NewRealClock() - } - if c.TTL == 0 { - c.TTL = defaultTokenTTL - } - if c.Issuer == "" { - return trace.BadParameter("issuer is required") - } - return nil -} - -// GenerateAWSOIDCToken generates a token to be used when executing an AWS OIDC Integration action. -func GenerateAWSOIDCToken(ctx context.Context, config AWSOIDCTokenConfig) (string, error) { - if err := config.checkAndSetDefaults(); err != nil { - return "", trace.Wrap(err) - } - - clusterName, err := config.CAGetter.GetDomainName() - if err != nil { - return "", trace.Wrap(err) - } - - ca, err := config.CAGetter.GetCertAuthority(ctx, types.CertAuthID{ + ca, err := s.caGetter.GetCertAuthority(ctx, types.CertAuthID{ Type: types.OIDCIdPCA, DomainName: clusterName, }, true) if err != nil { - return "", trace.Wrap(err) + return nil, trace.Wrap(err) } // Extract the JWT signing key and sign the claims. - signer, err := config.CAGetter.GetKeyStore().GetJWTSigner(ctx, ca) + signer, err := s.caGetter.GetKeyStore().GetJWTSigner(ctx, ca) if err != nil { - return "", trace.Wrap(err) + return nil, trace.Wrap(err) } - privateKey, err := services.GetJWTSigner(signer, ca.GetClusterName(), config.Clock) + privateKey, err := services.GetJWTSigner(signer, ca.GetClusterName(), s.clock) if err != nil { - return "", trace.Wrap(err) + return nil, trace.Wrap(err) } token, err := privateKey.SignAWSOIDC(jwt.SignParams{ - Username: config.Username, + Username: username, Audience: types.IntegrationAWSOIDCAudience, Subject: types.IntegrationAWSOIDCSubject, - Issuer: config.Issuer, - Expires: config.Clock.Now().Add(config.TTL), + Issuer: req.Issuer, + Expires: s.clock.Now().Add(time.Minute), }) if err != nil { - return "", trace.Wrap(err) + return nil, trace.Wrap(err) } - return token, nil + return &integrationpb.GenerateAWSOIDCTokenResponse{ + Token: token, + }, nil } diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index aee41b26b7e3a..3ad1d6d75ca25 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -116,13 +116,6 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a return trace.Wrap(err) } - token, err := authClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ - Issuer: issuer, - }) - if err != nil { - return trace.Wrap(err) - } - clusterNameConfig, err := authClient.GetClusterName() if err != nil { return trace.Wrap(err) @@ -167,6 +160,13 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a return trace.Wrap(err) } + token, err := authClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ + Issuer: issuer, + }) + if err != nil { + return trace.Wrap(err) + } + req := &awsoidc.AWSClientRequest{ IntegrationName: ig.GetName(), Token: token, From f028a2420a224127a6d13a7d3cc135900594e486 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Sat, 23 Sep 2023 18:38:51 -0700 Subject: [PATCH 08/20] Move const to start of file --- lib/integrations/awsoidc/deployservice_update.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 1b131fd80d393..0620e12247afc 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -30,6 +30,9 @@ import ( "github.com/gravitational/teleport/lib/modules" ) +// waitDuration specifies the amount of time to wait for a service to become healthy after an update. +const waitDuration = time.Minute * 5 + // UpdateDeployServiceAgents updates the deploy service agents with the specified teleportVersionTag. func UpdateDeployServiceAgents(ctx context.Context, clt DeployServiceClient, teleportClusterName, teleportVersionTag string, ownershipTags AWSTags) error { teleportFlavor := teleportOSS @@ -94,9 +97,6 @@ func updateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, tele return nil } -// waitDuration specifies the amount of time to wait for a service to become healthy after an update. -const waitDuration = time.Minute * 5 - func getManagedService(ctx context.Context, clt DeployServiceClient, teleportClusterName string, ownershipTags AWSTags) (*ecsTypes.Service, error) { ecsClusterName := fmt.Sprintf("%s-teleport", normalizeECSResourceName(teleportClusterName)) From d801171edfcf9a65dcbe626699dfdd5d1ef09d16 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Tue, 26 Sep 2023 16:08:54 -0700 Subject: [PATCH 09/20] Address feedback --- api/client/client.go | 16 ++++ api/types/maintenance.go | 29 +++++++ api/types/maintenance_test.go | 57 +++++++++++++ lib/integrations/awsoidc/deployservice.go | 46 +++++++--- .../awsoidc/deployservice_update.go | 27 ++---- lib/service/awsoidc.go | 85 +++++-------------- lib/service/awsoidc_test.go | 84 ------------------ lib/service/service.go | 2 +- 8 files changed, 161 insertions(+), 185 deletions(-) delete mode 100644 lib/service/awsoidc_test.go diff --git a/api/client/client.go b/api/client/client.go index 9bac7060c88a0..b78566455d84b 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -3894,6 +3894,22 @@ func (c *Client) ListIntegrations(ctx context.Context, pageSize int, nextKey str return integrations, resp.GetNextKey(), nil } +// ListAllIntegrations returns the list of all Integrations. +func (c *Client) ListAllIntegrations(ctx context.Context) ([]types.Integration, error) { + var result []types.Integration + var nextKey string + for { + integrations, nextKey, err := c.ListIntegrations(ctx, 0, nextKey) + if err != nil { + return nil, trace.Wrap(err) + } + result = append(result, integrations...) + if nextKey == "" { + return result, nil + } + } +} + // GetIntegration returns an Integration by its name. func (c *Client) GetIntegration(ctx context.Context, name string) (types.Integration, error) { ig, err := c.integrationsClient().GetIntegration(ctx, &integrationpb.GetIntegrationRequest{ diff --git a/api/types/maintenance.go b/api/types/maintenance.go index a71976d297187..9cab6a9ad4765 100644 --- a/api/types/maintenance.go +++ b/api/types/maintenance.go @@ -147,6 +147,10 @@ type ClusterMaintenanceConfig interface { // SetAgentUpgradeWindow sets the agent upgrade window. SetAgentUpgradeWindow(win AgentUpgradeWindow) + // WithinUpgradeWindow returns true if the time is within the configured + // upgrade window. + WithinUpgradeWindow(t time.Time) bool + CheckAndSetDefaults() error } @@ -229,3 +233,28 @@ func (m *ClusterMaintenanceConfigV1) GetAgentUpgradeWindow() (win AgentUpgradeWi func (m *ClusterMaintenanceConfigV1) SetAgentUpgradeWindow(win AgentUpgradeWindow) { m.Spec.AgentUpgrades = &win } + +// WithinUpgradeWindow returns true if the time is within the configured +// upgrade window. +func (m *ClusterMaintenanceConfigV1) WithinUpgradeWindow(t time.Time) bool { + upgradeWindow, ok := m.GetAgentUpgradeWindow() + if !ok { + return false + } + + if len(upgradeWindow.Weekdays) == 0 { + if int(upgradeWindow.UTCStartHour) == t.Hour() { + return true + } + } + + weekday := t.Weekday().String() + for _, upgradeWeekday := range upgradeWindow.Weekdays { + if weekday == upgradeWeekday { + if int(upgradeWindow.UTCStartHour) == t.Hour() { + return true + } + } + } + return false +} diff --git a/api/types/maintenance_test.go b/api/types/maintenance_test.go index 990d1f9670f38..203006a8dee37 100644 --- a/api/types/maintenance_test.go +++ b/api/types/maintenance_test.go @@ -214,3 +214,60 @@ func TestWeekdayParser(t *testing.T) { require.Equal(t, tt.expect, day) } } + +func TestWithinUpgradeWindow(t *testing.T) { + t.Parallel() + + tests := []struct { + desc string + upgradeWindow AgentUpgradeWindow + date string + withinWindow bool + }{ + { + desc: "within upgrade window", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: true, + }, + { + desc: "not within upgrade window", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + }, + date: "Mon, 02 Jan 2006 09:04:05 UTC", + withinWindow: false, + }, + { + desc: "within upgrade window weekday", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + Weekdays: []string{"Monday"}, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: true, + }, + { + desc: "not within upgrade window weekday", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + Weekdays: []string{"Tuesday"}, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: false, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + cmc := NewClusterMaintenanceConfig() + cmc.SetAgentUpgradeWindow(tt.upgradeWindow) + + date, err := time.Parse(time.RFC1123, tt.date) + require.NoError(t, err) + require.Equal(t, tt.withinWindow, cmc.WithinUpgradeWindow(date)) + }) + } +} diff --git a/lib/integrations/awsoidc/deployservice.go b/lib/integrations/awsoidc/deployservice.go index 4d385c48084a4..77060ed908939 100644 --- a/lib/integrations/awsoidc/deployservice.go +++ b/lib/integrations/awsoidc/deployservice.go @@ -58,11 +58,10 @@ var ( ) const ( - // teleportOSS is the prefix for the image name when deploying the OSS version of Teleport - teleportOSS = "teleport" - - // teleportEnt is the prefix for the image name when deploying the Enterprise version of Teleport - teleportEnt = "teleport-ent" + // distrolessTeleportOSS is the distroless image of the OSS version of Teleport + distrolessTeleportOSS = "public.ecr.aws/gravitational/teleport-distroless" + // distrolessTeleportEnt is the distroless image of the Enterprise version of Teleport + distrolessTeleportEnt = "public.ecr.aws/gravitational/teleport-ent-distroless" // clusterStatusActive is the string representing an ACTIVE ECS Cluster. clusterStatusActive = "ACTIVE" @@ -181,12 +180,26 @@ func normalizeECSResourceName(name string) string { return replacer.Replace(name) } +// normalizeECSClusterName returns the normalized ECS Cluster Name +func normalizeECSClusterName(teleportClusterName string) string { + return normalizeECSResourceName(fmt.Sprintf("%s-teleport", teleportClusterName)) +} + +// normalizeECSServiceName returns the normalized ECS Service Name +func normalizeECSServiceName(teleportClusterName, deploymentMode string) string { + return normalizeECSResourceName(fmt.Sprintf("%s-teleport-%s", teleportClusterName, deploymentMode)) +} + +// normalizeECSTaskName returns the normalized ECS TaskDefinition Family +func normalizeECSTaskName(teleportClusterName, deploymentMode string) string { + return normalizeECSResourceName(fmt.Sprintf("%s-teleport-%s", teleportClusterName, deploymentMode)) +} + // CheckAndSetDefaults checks if the required fields are present. func (r *DeployServiceRequest) CheckAndSetDefaults() error { if r.TeleportClusterName == "" { return trace.BadParameter("teleport cluster name is required") } - baseResourceName := normalizeECSResourceName(r.TeleportClusterName) if r.TeleportVersionTag == "" { r.TeleportVersionTag = teleport.Version @@ -217,17 +230,17 @@ func (r *DeployServiceRequest) CheckAndSetDefaults() error { } if r.ClusterName == nil || *r.ClusterName == "" { - clusterName := fmt.Sprintf("%s-teleport", baseResourceName) + clusterName := normalizeECSClusterName(r.TeleportClusterName) r.ClusterName = &clusterName } if r.ServiceName == nil || *r.ServiceName == "" { - serviceName := fmt.Sprintf("%s-teleport-%s", baseResourceName, r.DeploymentMode) + serviceName := normalizeECSServiceName(r.TeleportClusterName, r.DeploymentMode) r.ServiceName = &serviceName } if r.TaskName == nil || *r.TaskName == "" { - taskName := fmt.Sprintf("%s-teleport-%s", baseResourceName, r.DeploymentMode) + taskName := normalizeECSTaskName(r.TeleportClusterName, r.DeploymentMode) r.TaskName = &taskName } @@ -441,11 +454,7 @@ func DeployService(ctx context.Context, clt DeployServiceClient, req DeployServi // upsertTask ensures a TaskDefinition with TaskName exists func upsertTask(ctx context.Context, clt DeployServiceClient, req DeployServiceRequest, configB64 string) (*ecsTypes.TaskDefinition, error) { - teleportFlavor := teleportOSS - if modules.GetModules().BuildType() == modules.BuildEnterprise { - teleportFlavor = teleportEnt - } - taskAgentContainerImage := fmt.Sprintf("public.ecr.aws/gravitational/%s-distroless:%s", teleportFlavor, req.TeleportVersionTag) + taskAgentContainerImage := getDistrolessTeleportImage(req.TeleportVersionTag) taskDefOut, err := clt.RegisterTaskDefinition(ctx, &ecs.RegisterTaskDefinitionInput{ Family: req.TaskName, @@ -697,3 +706,12 @@ func upsertService(ctx context.Context, clt DeployServiceClient, req DeployServi return createServiceOut.Service, nil } + +// getDistrolessTeleportImage returns the distroless teleport image string +func getDistrolessTeleportImage(version string) string { + teleportImage := distrolessTeleportOSS + if modules.GetModules().BuildType() == modules.BuildEnterprise { + teleportImage = distrolessTeleportEnt + } + return fmt.Sprintf("%s:%s", teleportImage, version) +} diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 0620e12247afc..16bc2862327ad 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -18,7 +18,6 @@ package awsoidc import ( "context" - "fmt" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -26,28 +25,14 @@ import ( ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/gravitational/trace" - - "github.com/gravitational/teleport/lib/modules" ) // waitDuration specifies the amount of time to wait for a service to become healthy after an update. const waitDuration = time.Minute * 5 -// UpdateDeployServiceAgents updates the deploy service agents with the specified teleportVersionTag. -func UpdateDeployServiceAgents(ctx context.Context, clt DeployServiceClient, teleportClusterName, teleportVersionTag string, ownershipTags AWSTags) error { - teleportFlavor := teleportOSS - if modules.GetModules().BuildType() == modules.BuildEnterprise { - teleportFlavor = teleportEnt - } - teleportImage := fmt.Sprintf("public.ecr.aws/gravitational/%s-distroless:%s", teleportFlavor, teleportVersionTag) - - if err := updateDeployServiceAgent(ctx, clt, teleportClusterName, teleportImage, ownershipTags); err != nil { - return trace.Wrap(err) - } - return nil -} - -func updateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, teleportClusterName, teleportImage string, ownershipTags AWSTags) error { +// UpdateDeployServiceAgent updates the deploy service agent with the specified teleportVersionTag. +func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, teleportClusterName, teleportVersionTag string, ownershipTags AWSTags) error { + teleportImage := getDistrolessTeleportImage(teleportVersionTag) service, err := getManagedService(ctx, clt, teleportClusterName, ownershipTags) if err != nil { return trace.Wrap(err) @@ -98,15 +83,13 @@ func updateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, tele } func getManagedService(ctx context.Context, clt DeployServiceClient, teleportClusterName string, ownershipTags AWSTags) (*ecsTypes.Service, error) { - ecsClusterName := fmt.Sprintf("%s-teleport", normalizeECSResourceName(teleportClusterName)) - var ecsServiceNames []string for _, deploymentMode := range DeploymentModes { - ecsServiceNames = append(ecsServiceNames, fmt.Sprintf("%s-%s", ecsClusterName, deploymentMode)) + ecsServiceNames = append(ecsServiceNames, normalizeECSServiceName(teleportClusterName, deploymentMode)) } describeServicesOut, err := clt.DescribeServices(ctx, &ecs.DescribeServicesInput{ - Cluster: aws.String(ecsClusterName), + Cluster: aws.String(normalizeECSClusterName(teleportClusterName)), Services: ecsServiceNames, Include: []ecsTypes.ServiceField{ecsTypes.ServiceFieldTags}, }) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 3ad1d6d75ca25..b600c7a81f1d0 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -24,7 +24,6 @@ import ( "time" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" "golang.org/x/time/rate" "github.com/gravitational/teleport" @@ -44,14 +43,14 @@ const ( updateDeployAgentsRateLimit = time.Second * 30 ) -func (process *TeleportProcess) periodUpdateDeployServiceAgents() error { +func (process *TeleportProcess) periodicUpdateDeployServiceAgents() error { if !process.Config.Proxy.Enabled { return nil } // start process only after teleport process has started if _, err := process.WaitForEvent(process.GracefulExitContext(), TeleportReadyEvent); err != nil { - return nil + return trace.Wrap(err) } process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) @@ -89,27 +88,37 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a return trace.Wrap(err) } + // If criticalEndpoint or versionEndpoint are empty, the default stable/cloud endpoint will be used var criticalEndpoint string + var versionEndpoint string if automaticupgrades.GetChannel() != "" { criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") if err != nil { return trace.Wrap(err) } + versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version") + if err != nil { + return trace.Wrap(err) + } } - critical, err := automaticupgrades.Critical(process.GracefulExitContext(), criticalEndpoint) + critical, err := automaticupgrades.Critical(ctx, criticalEndpoint) if err != nil { return trace.Wrap(err) } - if !withinUpgradeWindow(cmc, process.Clock) && !critical { + // Upgrade should only be attempted if the current time is within the configured + // upgrade window, or if a critical upgrade is available + if !cmc.WithinUpgradeWindow(process.Clock.Now()) && !critical { return nil } - teleportVersion, err := getStableTeleportVersion(ctx) + stableVersion, err := automaticupgrades.Version(ctx, versionEndpoint) if err != nil { return trace.Wrap(err) } + // cloudStableVersion has vX.Y.Z format, however the container image tag does not include the `v`. + cloudStableVersion := strings.TrimPrefix(stableVersion, "v") issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) if err != nil { @@ -134,21 +143,13 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a } } - var resources []types.Integration - var nextKey string - for { - igs, nextKey, err := authClient.ListIntegrations(ctx, 0, nextKey) - if err != nil { - return trace.Wrap(err) - } - resources = append(resources, igs...) - if nextKey == "" { - break - } + integrations, err := authClient.ListAllIntegrations(ctx) + if err != nil { + return trace.Wrap(err) } limit := rate.NewLimiter(rate.Every(updateDeployAgentsRateLimit), 1) - for _, ig := range resources { + for _, ig := range integrations { spec := ig.GetAWSOIDCIntegrationSpec() if spec == nil { continue @@ -204,7 +205,8 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a return trace.Wrap(err) } - if err := awsoidc.UpdateDeployServiceAgents(ctx, deployServiceClient, clusterNameConfig.GetClusterName(), teleportVersion, ownershipTags); err != nil { + process.log.Debugf("Updating Deploy Service Agents in AWS region: %s", region) + if err := awsoidc.UpdateDeployServiceAgent(ctx, deployServiceClient, clusterNameConfig.GetClusterName(), cloudStableVersion, ownershipTags); err != nil { process.log.Warningf("Failed to update deploy service agents: %v", err) // Release the semaphore lease on failure so that another instance may attempt the update @@ -216,48 +218,3 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a } return nil } - -// getStableTeleportVersion returns the current stable version of teleport -func getStableTeleportVersion(ctx context.Context) (string, error) { - var versionEndpoint string - var err error - if automaticupgrades.GetChannel() != "" { - versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version") - if err != nil { - return "", trace.Wrap(err) - } - } - - stableVersion, err := automaticupgrades.Version(ctx, versionEndpoint) - if err != nil { - return "", trace.Wrap(err) - } - // cloudStableVersion has vX.Y.Z format, however the container image tag does not include the `v`. - return strings.TrimPrefix(stableVersion, "v"), nil -} - -// withinUpgradeWindow returns true if the current time is within the configured -// upgrade window. -func withinUpgradeWindow(cmc types.ClusterMaintenanceConfig, clock clockwork.Clock) bool { - upgradeWindow, ok := cmc.GetAgentUpgradeWindow() - if !ok { - return false - } - - now := clock.Now() - if len(upgradeWindow.Weekdays) == 0 { - if int(upgradeWindow.UTCStartHour) == now.Hour() { - return true - } - } - - weekday := now.Weekday().String() - for _, upgradeWeekday := range upgradeWindow.Weekdays { - if weekday == upgradeWeekday { - if int(upgradeWindow.UTCStartHour) == now.Hour() { - return true - } - } - } - return false -} diff --git a/lib/service/awsoidc_test.go b/lib/service/awsoidc_test.go deleted file mode 100644 index 9c0cb1eb89f38..0000000000000 --- a/lib/service/awsoidc_test.go +++ /dev/null @@ -1,84 +0,0 @@ -/* -Copyright 2023 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package service - -import ( - "testing" - "time" - - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/api/types" -) - -func TestWithinUpgradeWindow(t *testing.T) { - t.Parallel() - - tests := []struct { - desc string - upgradeWindow types.AgentUpgradeWindow - date string - withinWindow bool - }{ - { - desc: "within upgrade window", - upgradeWindow: types.AgentUpgradeWindow{ - UTCStartHour: 8, - }, - date: "Mon, 02 Jan 2006 08:04:05 UTC", - withinWindow: true, - }, - { - desc: "not within upgrade window", - upgradeWindow: types.AgentUpgradeWindow{ - UTCStartHour: 8, - }, - date: "Mon, 02 Jan 2006 09:04:05 UTC", - withinWindow: false, - }, - { - desc: "within upgrade window weekday", - upgradeWindow: types.AgentUpgradeWindow{ - UTCStartHour: 8, - Weekdays: []string{"Monday"}, - }, - date: "Mon, 02 Jan 2006 08:04:05 UTC", - withinWindow: true, - }, - { - desc: "not within upgrade window weekday", - upgradeWindow: types.AgentUpgradeWindow{ - UTCStartHour: 8, - Weekdays: []string{"Tuesday"}, - }, - date: "Mon, 02 Jan 2006 08:04:05 UTC", - withinWindow: false, - }, - } - - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - cmc := types.NewClusterMaintenanceConfig() - cmc.SetAgentUpgradeWindow(tt.upgradeWindow) - - date, err := time.Parse(time.RFC1123, tt.date) - require.NoError(t, err) - require.Equal(t, tt.withinWindow, withinUpgradeWindow(cmc, clockwork.NewFakeClockAt(date))) - }) - } -} diff --git a/lib/service/service.go b/lib/service/service.go index 8e01a67068425..52948d24e4f2a 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1194,7 +1194,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { // at any time with dynamic configuration process.RegisterFunc("common.upload.init", process.initUploaderService) - process.RegisterFunc("update.deploy.agents", process.periodUpdateDeployServiceAgents) + process.RegisterFunc("update.aws-oidc.deploy.agents", process.periodicUpdateDeployServiceAgents) if !serviceStarted { return nil, trace.BadParameter("all services failed to start") From a31d614ba51e6802929b8004a3b2c6742a03ee15 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Wed, 27 Sep 2023 19:50:55 -0700 Subject: [PATCH 10/20] Create separate DeployServiceUpdater struct --- .../awsoidc/deployservice_update.go | 40 +++- lib/service/awsoidc.go | 207 ++++++++++++------ lib/service/service.go | 6 +- 3 files changed, 183 insertions(+), 70 deletions(-) diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 16bc2862327ad..be0a8f4da7097 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -30,15 +30,42 @@ import ( // waitDuration specifies the amount of time to wait for a service to become healthy after an update. const waitDuration = time.Minute * 5 +// UpdateServiceRequest contains the required fields to update a Teleport Service. +type UpdateServiceRequest struct { + // TeleportClusterName specifies the teleport cluster name + TeleportClusterName string + // TeleportVersionTag specifies the desired teleport version in the format "13.4.0" + TeleportVersionTag string + // OwnershipTags specifies ownership tags + OwnershipTags AWSTags +} + +// CheckAndSetDefaults checks and sets default config values. +func (req *UpdateServiceRequest) CheckAndSetDefaults() error { + if req.TeleportClusterName == "" { + return trace.BadParameter("teleport cluster name required") + } + + if req.TeleportVersionTag == "" { + return trace.BadParameter("teleport version tag required") + } + + return nil +} + // UpdateDeployServiceAgent updates the deploy service agent with the specified teleportVersionTag. -func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, teleportClusterName, teleportVersionTag string, ownershipTags AWSTags) error { - teleportImage := getDistrolessTeleportImage(teleportVersionTag) - service, err := getManagedService(ctx, clt, teleportClusterName, ownershipTags) +func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, req UpdateServiceRequest) error { + if err := req.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + teleportImage := getDistrolessTeleportImage(req.TeleportVersionTag) + service, err := getManagedService(ctx, clt, req.TeleportClusterName, req.OwnershipTags) if err != nil { return trace.Wrap(err) } - taskDefinition, err := getManagedTaskDefinition(ctx, clt, aws.ToString(service.TaskDefinition), ownershipTags) + taskDefinition, err := getManagedTaskDefinition(ctx, clt, aws.ToString(service.TaskDefinition), req.OwnershipTags) if err != nil { return trace.Wrap(err) } @@ -52,7 +79,7 @@ func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, tele return nil } - registerTaskDefinitionIn, err := generateTaskDefinitionWithImage(taskDefinition, teleportImage, ownershipTags.ToECSTags()) + registerTaskDefinitionIn, err := generateTaskDefinitionWithImage(taskDefinition, teleportImage, req.OwnershipTags.ToECSTags()) if err != nil { return trace.Wrap(err) } @@ -96,6 +123,9 @@ func getManagedService(ctx context.Context, clt DeployServiceClient, teleportClu if err != nil { return nil, trace.Wrap(err) } + if len(describeServicesOut.Services) == 0 { + return nil, trace.NotFound("service not found") + } if len(describeServicesOut.Services) != 1 { return nil, trace.BadParameter("expected 1 service, but got %d", len(describeServicesOut.Services)) } diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index b600c7a81f1d0..c03f86819daab 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -24,7 +24,8 @@ import ( "time" "github.com/gravitational/trace" - "golang.org/x/time/rate" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -38,21 +39,13 @@ import ( const ( // updateDeployAgentsInterval specifies how frequently to check for available updates. updateDeployAgentsInterval = time.Minute * 30 - - // updateDeployAgentsRateLimit specifies the time between updates across AWS regions. - updateDeployAgentsRateLimit = time.Second * 30 ) -func (process *TeleportProcess) periodicUpdateDeployServiceAgents() error { - if !process.Config.Proxy.Enabled { - return nil - } - +func (process *TeleportProcess) initDeployServiceUpdater() error { // start process only after teleport process has started if _, err := process.WaitForEvent(process.GracefulExitContext(), TeleportReadyEvent); err != nil { return trace.Wrap(err) } - process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) resp, err := process.getInstanceClient().Ping(process.GracefulExitContext()) if err != nil { @@ -63,6 +56,109 @@ func (process *TeleportProcess) periodicUpdateDeployServiceAgents() error { return nil } + process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) + + // If criticalEndpoint or versionEndpoint are empty, the default stable/cloud endpoint will be used + var criticalEndpoint string + var versionEndpoint string + if automaticupgrades.GetChannel() != "" { + criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") + if err != nil { + return trace.Wrap(err) + } + versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version") + if err != nil { + return trace.Wrap(err) + } + } + + issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) + if err != nil { + return trace.Wrap(err) + } + + clusterNameConfig, err := process.getInstanceClient().GetClusterName() + if err != nil { + return trace.Wrap(err) + } + + updater, err := NewDeployServiceUpdater(DeployServiceUpdaterConfig{ + Log: process.log.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "deployserviceupdater")), + AuthClient: process.getInstanceClient(), + Clock: process.Clock, + TeleportClusterName: clusterNameConfig.GetClusterName(), + AWSOIDCProviderAddr: issuer, + CriticalEndpoint: criticalEndpoint, + VersionEndpoint: versionEndpoint, + }) + if err != nil { + return trace.Wrap(err) + } + + return trace.Wrap(updater.Run(process.GracefulExitContext())) +} + +// DeployServiceUpdaterConfig specifies updater configs +type DeployServiceUpdaterConfig struct { + // Log is the logger + Log *logrus.Entry + // AuthClient is the auth api client + AuthClient *auth.Client + // Clock is the local clock + Clock clockwork.Clock + // TeleportClusterName specifies the teleport cluster name + TeleportClusterName string + // AWSOIDCProvderAddr specifies the aws oidc provider address used to generate AWS OIDC tokens + AWSOIDCProviderAddr string + // CriticalEndpoint specifies the endpoint to check for critical updates + CriticalEndpoint string + // VersionEndpoint specifies the endpoint to check for current teleport version + VersionEndpoint string +} + +// CheckAndSetDefaults checks and sets default config values. +func (cfg *DeployServiceUpdaterConfig) CheckAndSetDefaults() error { + if cfg.AuthClient == nil { + return trace.BadParameter("auth client required") + } + + if cfg.TeleportClusterName == "" { + return trace.BadParameter("teleport cluster name required") + } + + if cfg.AWSOIDCProviderAddr == "" { + return trace.BadParameter("aws oidc provider address required") + } + + if cfg.Log == nil { + cfg.Log = logrus.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "deployserviceupdater")) + } + + if cfg.Clock == nil { + cfg.Clock = clockwork.NewRealClock() + } + + return nil +} + +// DeployServiceUpdater periodically updates deploy service agents +type DeployServiceUpdater struct { + DeployServiceUpdaterConfig +} + +// NewDeployServiceUpdater returns a new DeployServiceUpdater +func NewDeployServiceUpdater(config DeployServiceUpdaterConfig) (*DeployServiceUpdater, error) { + if err := config.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &DeployServiceUpdater{ + DeployServiceUpdaterConfig: config, + }, nil +} + +// Run periodically updates the deploy service agents +func (updater *DeployServiceUpdater) Run(ctx context.Context) error { periodic := interval.New(interval.Config{ Duration: updateDeployAgentsInterval, Jitter: retryutils.NewSeventhJitter(), @@ -70,72 +166,51 @@ func (process *TeleportProcess) periodicUpdateDeployServiceAgents() error { defer periodic.Stop() for { - if err := process.updateDeployServiceAgents(process.GracefulExitContext(), process.getInstanceClient()); err != nil { - process.log.Warningf("Update failed: %v. Retrying in ~%v", err, updateDeployAgentsInterval) + if err := updater.updateDeployServiceAgents(ctx); err != nil { + updater.Log.WithError(err).Warningf("Update failed. Retrying in ~%v.", updateDeployAgentsInterval) } select { case <-periodic.Next(): - case <-process.GracefulExitContext().Done(): + case <-ctx.Done(): return nil } } } -func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, authClient *auth.Client) error { - cmc, err := authClient.GetClusterMaintenanceConfig(ctx) +func (updater *DeployServiceUpdater) updateDeployServiceAgents(ctx context.Context) error { + cmc, err := updater.AuthClient.GetClusterMaintenanceConfig(ctx) if err != nil { return trace.Wrap(err) } - // If criticalEndpoint or versionEndpoint are empty, the default stable/cloud endpoint will be used - var criticalEndpoint string - var versionEndpoint string - if automaticupgrades.GetChannel() != "" { - criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") - if err != nil { - return trace.Wrap(err) - } - versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version") - if err != nil { - return trace.Wrap(err) - } - } - - critical, err := automaticupgrades.Critical(ctx, criticalEndpoint) + critical, err := automaticupgrades.Critical(ctx, updater.CriticalEndpoint) if err != nil { return trace.Wrap(err) } // Upgrade should only be attempted if the current time is within the configured // upgrade window, or if a critical upgrade is available - if !cmc.WithinUpgradeWindow(process.Clock.Now()) && !critical { + if !cmc.WithinUpgradeWindow(updater.Clock.Now()) && !critical { return nil } - stableVersion, err := automaticupgrades.Version(ctx, versionEndpoint) + stableVersion, err := automaticupgrades.Version(ctx, updater.VersionEndpoint) if err != nil { return trace.Wrap(err) } // cloudStableVersion has vX.Y.Z format, however the container image tag does not include the `v`. cloudStableVersion := strings.TrimPrefix(stableVersion, "v") - issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) - if err != nil { - return trace.Wrap(err) - } - - clusterNameConfig, err := authClient.GetClusterName() - if err != nil { - return trace.Wrap(err) - } - clusterName := clusterNameConfig.GetClusterName() - - databases, err := authClient.GetDatabases(ctx) + databases, err := updater.AuthClient.GetDatabases(ctx) if err != nil { return trace.Wrap(err) } + // The updater needs to iterate over all integrations and aws regions to check + // for deploy service agents to update. In order to reduce the number of api + // calls, the aws regions are first reduced to only the regions containing + // an RDS database. awsRegions := make(map[string]interface{}) for _, database := range databases { if database.IsAWSHosted() && database.IsRDS() { @@ -143,12 +218,11 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a } } - integrations, err := authClient.ListAllIntegrations(ctx) + integrations, err := updater.AuthClient.ListAllIntegrations(ctx) if err != nil { return trace.Wrap(err) } - limit := rate.NewLimiter(rate.Every(updateDeployAgentsRateLimit), 1) for _, ig := range integrations { spec := ig.GetAWSOIDCIntegrationSpec() if spec == nil { @@ -157,12 +231,8 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a integrationName := ig.GetName() for region := range awsRegions { - if err := limit.Wait(ctx); err != nil { - return trace.Wrap(err) - } - - token, err := authClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ - Issuer: issuer, + token, err := updater.AuthClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ + Issuer: updater.AWSOIDCProviderAddr, }) if err != nil { return trace.Wrap(err) @@ -175,43 +245,54 @@ func (process *TeleportProcess) updateDeployServiceAgents(ctx context.Context, a Region: region, } - deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, authClient) + // The deploy service client is initialized using AWS OIDC integration + deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, updater.AuthClient) if err != nil { - process.log.Warningf("Failed to update deploy service agents: %v", err) + updater.Log.WithError(err).Warning("Failed to update deploy service agents.") continue } ownershipTags := map[string]string{ - types.ClusterLabel: clusterName, + types.ClusterLabel: updater.TeleportClusterName, types.OriginLabel: types.OriginIntegrationAWSOIDC, types.IntegrationLabel: integrationName, } // Acquire a lease for the region + integration before attempting to update the deploy service agent. // If the lease cannot be acquired, the update is already being handled by another instance. - semLock, err := authClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ + semLock, err := updater.AuthClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ SemaphoreKind: types.SemaphoreKindConnection, SemaphoreName: fmt.Sprintf("update_deploy_service_agents_%s_%s", region, integrationName), MaxLeases: 1, - Expires: process.Clock.Now().Add(updateDeployAgentsInterval), + Expires: updater.Clock.Now().Add(updateDeployAgentsInterval), Holder: "update_deploy_service_agents", }) if err != nil { if strings.Contains(err.Error(), teleport.MaxLeases) { - process.log.Debug("Deploy service agent update is already being processed") + updater.Log.Debug("Deploy service agent update is already being processed") continue } return trace.Wrap(err) } - process.log.Debugf("Updating Deploy Service Agents in AWS region: %s", region) - if err := awsoidc.UpdateDeployServiceAgent(ctx, deployServiceClient, clusterNameConfig.GetClusterName(), cloudStableVersion, ownershipTags); err != nil { - process.log.Warningf("Failed to update deploy service agents: %v", err) + updater.Log.Debugf("Updating Deploy Service Agents in AWS region: %s", region) + if err := awsoidc.UpdateDeployServiceAgent(ctx, deployServiceClient, awsoidc.UpdateServiceRequest{ + TeleportClusterName: updater.TeleportClusterName, + TeleportVersionTag: cloudStableVersion, + OwnershipTags: ownershipTags, + }); err != nil { + if trace.IsNotFound(err) { + // The updater checks each integration/region combination, so + // there will be regions where there is no ECS cluster deployed + // for the integration + continue + } + updater.Log.WithError(err).Warning("Failed to update deploy service agents.") // Release the semaphore lease on failure so that another instance may attempt the update - if err := authClient.CancelSemaphoreLease(ctx, *semLock); err != nil { - process.log.WithError(err).Error("Failed to cancel semaphore lease") + if err := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); err != nil { + updater.Log.WithError(err).Error("Failed to cancel semaphore lease.") } } } diff --git a/lib/service/service.go b/lib/service/service.go index 52948d24e4f2a..aafa9d6bae5d5 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1014,6 +1014,10 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { process.log.Infof("Configured upgrade window exporter for external upgrader. kind=%s", upgraderKind) } + if process.Config.Proxy.Enabled { + process.RegisterFunc("update.aws-oidc.deploy.agents", process.initDeployServiceUpdater) + } + serviceStarted := false if !cfg.DiagnosticAddr.IsEmpty() { @@ -1194,8 +1198,6 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { // at any time with dynamic configuration process.RegisterFunc("common.upload.init", process.initUploaderService) - process.RegisterFunc("update.aws-oidc.deploy.agents", process.periodicUpdateDeployServiceAgents) - if !serviceStarted { return nil, trace.BadParameter("all services failed to start") } From aeee42d1e107e12653732bd211ab2edee257a4bb Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Thu, 28 Sep 2023 21:59:57 -0700 Subject: [PATCH 11/20] Address feedback - Perform updates in parallel - Add additional logging - Add additional documentation --- .../awsoidc/deployservice_update.go | 13 +- lib/service/awsoidc.go | 177 ++++++++++-------- 2 files changed, 113 insertions(+), 77 deletions(-) diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index be0a8f4da7097..810e7b8fa8dba 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -25,6 +25,7 @@ import ( ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" ) // waitDuration specifies the amount of time to wait for a service to become healthy after an update. @@ -75,6 +76,8 @@ func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, req return trace.Wrap(err) } + // There is no need to update the ecs service if the ecs service is already + // running the latest stable version of teleport. if currentTeleportImage == teleportImage { return nil } @@ -97,15 +100,19 @@ func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, req TaskDefinition: registerTaskDefinitionOut.TaskDefinition.TaskDefinitionArn, }) if rollbackErr != nil { - return trace.Wrap(err, "failed to rollback task definition: %v", rollbackErr) + return trace.NewAggregate(err, trace.Wrap(rollbackErr, "failed to rollback task definition")) } return trace.Wrap(err) } // Attempt to deregister previous task definition but ignore error on failure - clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + _, err = clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ TaskDefinition: taskDefinition.TaskDefinitionArn, }) + if err != nil { + logrus.WithError(err).Warning("Failed to deregister task definition.") + } + return nil } @@ -124,7 +131,7 @@ func getManagedService(ctx context.Context, clt DeployServiceClient, teleportClu return nil, trace.Wrap(err) } if len(describeServicesOut.Services) == 0 { - return nil, trace.NotFound("service not found") + return nil, trace.NotFound("services %v not found", ecsServiceNames) } if len(describeServicesOut.Services) != 1 { return nil, trace.BadParameter("expected 1 service, but got %d", len(describeServicesOut.Services)) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index c03f86819daab..9207f05d9444f 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -21,6 +21,7 @@ import ( "fmt" "net/url" "strings" + "sync" "time" "github.com/gravitational/trace" @@ -32,6 +33,7 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/automaticupgrades" + awslib "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/utils/interval" ) @@ -39,6 +41,9 @@ import ( const ( // updateDeployAgentsInterval specifies how frequently to check for available updates. updateDeployAgentsInterval = time.Minute * 30 + + // maxConcurrentUpdates specifies the maximum number of concurrent updates + maxConcurrentUpdates = 3 ) func (process *TeleportProcess) initDeployServiceUpdater() error { @@ -56,8 +61,6 @@ func (process *TeleportProcess) initDeployServiceUpdater() error { return nil } - process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) - // If criticalEndpoint or versionEndpoint are empty, the default stable/cloud endpoint will be used var criticalEndpoint string var versionEndpoint string @@ -83,7 +86,7 @@ func (process *TeleportProcess) initDeployServiceUpdater() error { } updater, err := NewDeployServiceUpdater(DeployServiceUpdaterConfig{ - Log: process.log.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "deployserviceupdater")), + Log: process.log.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "aws_oidc_deploy_service_updater")), AuthClient: process.getInstanceClient(), Clock: process.Clock, TeleportClusterName: clusterNameConfig.GetClusterName(), @@ -95,6 +98,7 @@ func (process *TeleportProcess) initDeployServiceUpdater() error { return trace.Wrap(err) } + process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) return trace.Wrap(updater.Run(process.GracefulExitContext())) } @@ -131,7 +135,7 @@ func (cfg *DeployServiceUpdaterConfig) CheckAndSetDefaults() error { } if cfg.Log == nil { - cfg.Log = logrus.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "deployserviceupdater")) + cfg.Log = logrus.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "aws_oidc_deploy_service_updater")) } if cfg.Clock == nil { @@ -199,8 +203,8 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgents(ctx context.Conte if err != nil { return trace.Wrap(err) } - // cloudStableVersion has vX.Y.Z format, however the container image tag does not include the `v`. - cloudStableVersion := strings.TrimPrefix(stableVersion, "v") + // stableVersion has vX.Y.Z format, however the container image tag does not include the `v`. + stableVersion = strings.TrimPrefix(stableVersion, "v") databases, err := updater.AuthClient.GetDatabases(ctx) if err != nil { @@ -223,79 +227,104 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgents(ctx context.Conte return trace.Wrap(err) } + // Perform updates in parallel across regions. + var sem = make(chan interface{}, maxConcurrentUpdates) + var wg sync.WaitGroup for _, ig := range integrations { - spec := ig.GetAWSOIDCIntegrationSpec() - if spec == nil { - continue - } - integrationName := ig.GetName() - for region := range awsRegions { - token, err := updater.AuthClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ - Issuer: updater.AWSOIDCProviderAddr, - }) - if err != nil { - return trace.Wrap(err) - } - - req := &awsoidc.AWSClientRequest{ - IntegrationName: ig.GetName(), - Token: token, - RoleARN: spec.RoleARN, - Region: region, - } - - // The deploy service client is initialized using AWS OIDC integration - deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, updater.AuthClient) - if err != nil { - updater.Log.WithError(err).Warning("Failed to update deploy service agents.") - continue - } - - ownershipTags := map[string]string{ - types.ClusterLabel: updater.TeleportClusterName, - types.OriginLabel: types.OriginIntegrationAWSOIDC, - types.IntegrationLabel: integrationName, - } - - // Acquire a lease for the region + integration before attempting to update the deploy service agent. - // If the lease cannot be acquired, the update is already being handled by another instance. - semLock, err := updater.AuthClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ - SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: fmt.Sprintf("update_deploy_service_agents_%s_%s", region, integrationName), - MaxLeases: 1, - Expires: updater.Clock.Now().Add(updateDeployAgentsInterval), - Holder: "update_deploy_service_agents", - }) - - if err != nil { - if strings.Contains(err.Error(), teleport.MaxLeases) { - updater.Log.Debug("Deploy service agent update is already being processed") - continue - } - return trace.Wrap(err) - } - - updater.Log.Debugf("Updating Deploy Service Agents in AWS region: %s", region) - if err := awsoidc.UpdateDeployServiceAgent(ctx, deployServiceClient, awsoidc.UpdateServiceRequest{ - TeleportClusterName: updater.TeleportClusterName, - TeleportVersionTag: cloudStableVersion, - OwnershipTags: ownershipTags, - }); err != nil { - if trace.IsNotFound(err) { - // The updater checks each integration/region combination, so - // there will be regions where there is no ECS cluster deployed - // for the integration - continue + sem <- nil + wg.Add(1) + go func(ig types.Integration, region string) { + if err := updater.updateDeployServiceAgent(ctx, ig, region, stableVersion); err != nil { + updater.Log.WithError(err).Warning("Failed to update deploy service agent.") } - updater.Log.WithError(err).Warning("Failed to update deploy service agents.") + wg.Done() + <-sem + }(ig, region) + } + } + wg.Wait() - // Release the semaphore lease on failure so that another instance may attempt the update - if err := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); err != nil { - updater.Log.WithError(err).Error("Failed to cancel semaphore lease.") - } - } + return nil +} + +func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Context, integration types.Integration, awsRegion, teleportVersion string) error { + // Do not attempt update if integration is not an aws oidc integration. + if integration.GetAWSOIDCIntegrationSpec() == nil { + return nil + } + + token, err := updater.AuthClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ + Issuer: updater.AWSOIDCProviderAddr, + }) + if err != nil { + return trace.Wrap(err) + } + + req := &awsoidc.AWSClientRequest{ + IntegrationName: integration.GetName(), + Token: token, + RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN, + Region: awsRegion, + } + + // The deploy service client is initialized using AWS OIDC integration. + deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, updater.AuthClient) + if err != nil { + return trace.Wrap(err) + } + + // ownershipTags are used to identify if the ecs resources are managed by the + // teleport integration. + ownershipTags := map[string]string{ + types.ClusterLabel: updater.TeleportClusterName, + types.OriginLabel: types.OriginIntegrationAWSOIDC, + types.IntegrationLabel: integration.GetName(), + } + + // Acquire a lease for the region + integration before attempting to update the deploy service agent. + // If the lease cannot be acquired, the update is already being handled by another instance. + semLock, err := updater.AuthClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: fmt.Sprintf("update_deploy_service_agents_%s_%s_BERNARD", awsRegion, integration.GetName()), + MaxLeases: 1, + Expires: updater.Clock.Now().Add(updateDeployAgentsInterval), + Holder: "update_deploy_service_agents", + }) + if err != nil { + if strings.Contains(err.Error(), teleport.MaxLeases) { + updater.Log.WithError(err).Debug("Deploy service agent update is already being processed.") + return nil + } + return trace.Wrap(err) + } + + updater.Log.Debugf("Updating Deploy Service Agents for integration %s in AWS region: %s", integration.GetName(), awsRegion) + if err := awsoidc.UpdateDeployServiceAgent(ctx, deployServiceClient, awsoidc.UpdateServiceRequest{ + TeleportClusterName: updater.TeleportClusterName, + TeleportVersionTag: teleportVersion, + OwnershipTags: ownershipTags, + }); err != nil { + // Release the semaphore lease on failure so that another instance may attempt the update + if cancelErr := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); cancelErr != nil { + updater.Log.WithError(cancelErr).Error("Failed to cancel semaphore lease.") } + + switch { + case trace.IsNotFound(err): + // The updater checks each integration/region combination, so + // there will be regions where there is no ECS cluster deployed + // for the integration. + updater.Log.WithError(err).Debugf("Integration %s does not manage any services within region %s.", integration.GetName(), awsRegion) + return nil + case trace.IsAccessDenied(awslib.ConvertIAMv2Error(trace.Unwrap(err))): + // The aws oidc role may lack permissions due to changes in teleport. + // In this situation users should be notified that they will need to + // re-run the deploy service iam configuration script and update the + // permissions. + updater.Log.WithError(err).Warning("Re-run deploy service configuration script to update permissions.") + } + return trace.Wrap(err) } return nil } From 406b0e11331291cd4c44a02f0135d8d8f26ddcbc Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Thu, 28 Sep 2023 22:13:08 -0700 Subject: [PATCH 12/20] debug --- lib/service/awsoidc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 9207f05d9444f..cda4bb6d25ab6 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -286,7 +286,7 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Contex // If the lease cannot be acquired, the update is already being handled by another instance. semLock, err := updater.AuthClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: fmt.Sprintf("update_deploy_service_agents_%s_%s_BERNARD", awsRegion, integration.GetName()), + SemaphoreName: fmt.Sprintf("update_deploy_service_agents_%s_%s", awsRegion, integration.GetName()), MaxLeases: 1, Expires: updater.Clock.Now().Add(updateDeployAgentsInterval), Holder: "update_deploy_service_agents", From 0f8dc750ed21ee8ded66b69fd653b6869063ee46 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 29 Sep 2023 15:19:37 -0700 Subject: [PATCH 13/20] Address feedback - Check OwnershipTags - Use semaphore pkg - Release semaphore lease on success --- .../awsoidc/deployservice_update.go | 4 +++ lib/service/awsoidc.go | 28 +++++++++---------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 810e7b8fa8dba..4dc79f55576c2 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -51,6 +51,10 @@ func (req *UpdateServiceRequest) CheckAndSetDefaults() error { return trace.BadParameter("teleport version tag required") } + if req.OwnershipTags == nil { + req.OwnershipTags = make(AWSTags) + } + return nil } diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index cda4bb6d25ab6..4df01a8495005 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -21,12 +21,12 @@ import ( "fmt" "net/url" "strings" - "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -228,24 +228,23 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgents(ctx context.Conte } // Perform updates in parallel across regions. - var sem = make(chan interface{}, maxConcurrentUpdates) - var wg sync.WaitGroup + sem := semaphore.NewWeighted(maxConcurrentUpdates) for _, ig := range integrations { for region := range awsRegions { - sem <- nil - wg.Add(1) + if err := sem.Acquire(ctx, 1); err != nil { + return trace.Wrap(err) + } go func(ig types.Integration, region string) { + defer sem.Release(1) if err := updater.updateDeployServiceAgent(ctx, ig, region, stableVersion); err != nil { - updater.Log.WithError(err).Warning("Failed to update deploy service agent.") + updater.Log.WithError(err).Warningf("Failed to update deploy service agent for integration %s in region %s.", ig.GetName(), region) } - wg.Done() - <-sem }(ig, region) } } - wg.Wait() - return nil + // Wait for all updates to finish. + return trace.Wrap(sem.Acquire(ctx, maxConcurrentUpdates)) } func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Context, integration types.Integration, awsRegion, teleportVersion string) error { @@ -298,6 +297,11 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Contex } return trace.Wrap(err) } + defer func() { + if err := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); err != nil { + updater.Log.WithError(err).Error("Failed to cancel semaphore lease.") + } + }() updater.Log.Debugf("Updating Deploy Service Agents for integration %s in AWS region: %s", integration.GetName(), awsRegion) if err := awsoidc.UpdateDeployServiceAgent(ctx, deployServiceClient, awsoidc.UpdateServiceRequest{ @@ -305,10 +309,6 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Contex TeleportVersionTag: teleportVersion, OwnershipTags: ownershipTags, }); err != nil { - // Release the semaphore lease on failure so that another instance may attempt the update - if cancelErr := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); cancelErr != nil { - updater.Log.WithError(cancelErr).Error("Failed to cancel semaphore lease.") - } switch { case trace.IsNotFound(err): From 2cc89a3660d17beb3703a729121a53d1c9a91c24 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Fri, 29 Sep 2023 15:23:54 -0700 Subject: [PATCH 14/20] Make OwnershipTags explicitly required --- lib/integrations/awsoidc/deployservice_update.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 4dc79f55576c2..6ac386926aa07 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -52,7 +52,7 @@ func (req *UpdateServiceRequest) CheckAndSetDefaults() error { } if req.OwnershipTags == nil { - req.OwnershipTags = make(AWSTags) + return trace.BadParameter("ownership tags required") } return nil From fec081d114770b57bd29ec4d4acf5bfde990b3fd Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Tue, 3 Oct 2023 23:43:33 -0700 Subject: [PATCH 15/20] Add cluster alert --- api/types/constants.go | 4 ++ lib/authz/permissions.go | 1 + lib/service/awsoidc.go | 83 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/api/types/constants.go b/api/types/constants.go index 5d39e87c3ee41..95aaba5078c8d 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -836,6 +836,10 @@ const ( // AlertLicenseExpired is an internal label that indicates that the license has expired. AlertLicenseExpired = TeleportInternalLabelPrefix + "license-expired-warning" + // AlertAWSOIDCAccessDenied is an internal label that indicates that the aws oidc + // integration permissions need to be reconfigured. + AlertAWSOIDCAccessDenied = TeleportInternalLabelPrefix + "aws-oidc-access-denied" + // TeleportInternalDiscoveryGroupName is the label used to store the name of the discovery group // that the discovered resource is owned by. It is used to differentiate resources // that belong to different discovery services that operate on different sets of resources. diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index c2e93857ce6c9..f09c974bc2cad 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -617,6 +617,7 @@ func roleSpecForProxy(clusterName string) types.RoleSpecV6 { types.NewRule(types.KindSAMLIdPServiceProvider, services.RO()), types.NewRule(types.KindUserGroup, services.RO()), types.NewRule(types.KindClusterMaintenanceConfig, services.RO()), + types.NewRule(types.KindClusterAlert, services.RW()), types.NewRule(types.KindIntegration, append(services.RO(), types.VerbUse)), // this rule allows cloud proxies to read // plugins of `openai` type, since Assist uses the OpenAI API and runs in Proxy. diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 4df01a8495005..8ff554af3e7e6 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "net/url" + "path" "strings" "time" @@ -44,6 +45,12 @@ const ( // maxConcurrentUpdates specifies the maximum number of concurrent updates maxConcurrentUpdates = 3 + + // awsOIDCAccessDenied specifies the id for an aws oidc access denied alert + awsOIDCAccessDenied = "aws_oidc_access_denied" + + // awsOIDCTTL specifies the time to live for an access denied error + awsOIDCAlertTTL = time.Hour * 24 ) func (process *TeleportProcess) initDeployServiceUpdater() error { @@ -323,8 +330,84 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Contex // re-run the deploy service iam configuration script and update the // permissions. updater.Log.WithError(err).Warning("Re-run deploy service configuration script to update permissions.") + if err := updater.ensureClusterAlert(ctx, integration, awsRegion); err != nil { + updater.Log.WithError(err).Warning("Failed to ensure cluster alert.") + } } return trace.Wrap(err) } return nil } + +// ensureClusterAlert ensures a cluster alert is created if deploy service permissions +// need to be reconfigured. +func (updater *DeployServiceUpdater) ensureClusterAlert(ctx context.Context, integration types.Integration, awsRegion string) error { + // Acquire semaphore lease before attempting to create a cluster alert. + // If the lease cannot be acquired, a cluster alert is already being created by another instance. + semLock, err := updater.AuthClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: "aws_oidc_access_denied_alert", + MaxLeases: 1, + Expires: updater.Clock.Now().Add(updateDeployAgentsInterval), + Holder: "aws_oidc_access_denied_alert", + }) + if err != nil { + if strings.Contains(err.Error(), teleport.MaxLeases) { + updater.Log.WithError(err).Debug("Cluster alert is already being created.") + return nil + } + return trace.Wrap(err) + } + defer func() { + if err := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); err != nil { + updater.Log.WithError(err).Error("Failed to cancel semaphore lease.") + } + }() + + alertLabels := map[string]string{ + types.AlertAWSOIDCAccessDenied: "yes", + types.AlertOnLogin: "yes", + types.AlertVerbPermit: fmt.Sprintf("%s:%s", types.KindInstance, types.VerbRead), + } + + alerts, err := updater.AuthClient.GetClusterAlerts(ctx, types.GetClusterAlertsRequest{ + Labels: alertLabels, + WithAcknowledged: true, + WithUntargeted: true, + }) + if err != nil { + return trace.Wrap(err) + } + + // Do not create another alert if an alert has already been created + if len(alerts) > 0 { + return nil + } + + clusterConfig, err := updater.AuthClient.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + + scriptURL, err := url.Parse(fmt.Sprintf("https://%swebapi/scripts/integrations/configure/deployservice-iam.sh", clusterConfig.GetClusterName())) + if err != nil { + return trace.Wrap(err) + } + values := scriptURL.Query() + values.Add("integrationName", integration.GetName()) + values.Add("awsRegion", awsRegion) + values.Add("taskRole", url.QueryEscape("")) // The task role needs to be supplied by the user + values.Add("taskRole", path.Base(integration.GetAWSOIDCIntegrationSpec().RoleARN)) + scriptURL.RawQuery = values.Encode() + + message := fmt.Sprintf("Re-run AWS OIDC integration configuration script to update permissions: %s", scriptURL.String()) + + alert, err := types.NewClusterAlert(awsOIDCAccessDenied, message, types.WithAlertSeverity(types.AlertSeverity_LOW)) + if err != nil { + return trace.Wrap(err) + } + alert.Metadata.Labels = alertLabels + alert.SetExpiry(updater.Clock.Now().Add(awsOIDCAlertTTL)) + + return trace.Wrap(updater.AuthClient.UpsertClusterAlert(ctx, alert)) +} From 3978c827a5bec1dd2033985b0986ec8e74fca7db Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Wed, 4 Oct 2023 15:53:28 -0700 Subject: [PATCH 16/20] Fix typo and update message --- lib/service/awsoidc.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 8ff554af3e7e6..d9e4025783770 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -389,18 +389,19 @@ func (updater *DeployServiceUpdater) ensureClusterAlert(ctx context.Context, int return trace.Wrap(err) } - scriptURL, err := url.Parse(fmt.Sprintf("https://%swebapi/scripts/integrations/configure/deployservice-iam.sh", clusterConfig.GetClusterName())) + scriptURL, err := url.Parse(fmt.Sprintf("https://%s/webapi/scripts/integrations/configure/deployservice-iam.sh", clusterConfig.GetClusterName())) if err != nil { return trace.Wrap(err) } values := scriptURL.Query() values.Add("integrationName", integration.GetName()) values.Add("awsRegion", awsRegion) - values.Add("taskRole", url.QueryEscape("")) // The task role needs to be supplied by the user - values.Add("taskRole", path.Base(integration.GetAWSOIDCIntegrationSpec().RoleARN)) + values.Add("role", path.Base(integration.GetAWSOIDCIntegrationSpec().RoleARN)) + values.Add("taskRole", "TASK_ROLE") // The task role needs to be supplied by the user scriptURL.RawQuery = values.Encode() - message := fmt.Sprintf("Re-run AWS OIDC integration configuration script to update permissions: %s", scriptURL.String()) + cmd := fmt.Sprintf(`bash -c "$(curl '%s')"`, scriptURL) + message := fmt.Sprintf("Open Amazon CloudShell and copy/paste the following command to reconfigure integration. Replace TASK_ROLE with desired deploy service task role name: %s", cmd) alert, err := types.NewClusterAlert(awsOIDCAccessDenied, message, types.WithAlertSeverity(types.AlertSeverity_LOW)) if err != nil { From cf4c36e8a90cd97a82607013632b75af5f29071f Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Wed, 4 Oct 2023 16:47:07 -0700 Subject: [PATCH 17/20] Revert cluster alert --- api/types/constants.go | 4 -- lib/authz/permissions.go | 1 - lib/service/awsoidc.go | 84 ---------------------------------------- 3 files changed, 89 deletions(-) diff --git a/api/types/constants.go b/api/types/constants.go index 95aaba5078c8d..5d39e87c3ee41 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -836,10 +836,6 @@ const ( // AlertLicenseExpired is an internal label that indicates that the license has expired. AlertLicenseExpired = TeleportInternalLabelPrefix + "license-expired-warning" - // AlertAWSOIDCAccessDenied is an internal label that indicates that the aws oidc - // integration permissions need to be reconfigured. - AlertAWSOIDCAccessDenied = TeleportInternalLabelPrefix + "aws-oidc-access-denied" - // TeleportInternalDiscoveryGroupName is the label used to store the name of the discovery group // that the discovered resource is owned by. It is used to differentiate resources // that belong to different discovery services that operate on different sets of resources. diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index f09c974bc2cad..c2e93857ce6c9 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -617,7 +617,6 @@ func roleSpecForProxy(clusterName string) types.RoleSpecV6 { types.NewRule(types.KindSAMLIdPServiceProvider, services.RO()), types.NewRule(types.KindUserGroup, services.RO()), types.NewRule(types.KindClusterMaintenanceConfig, services.RO()), - types.NewRule(types.KindClusterAlert, services.RW()), types.NewRule(types.KindIntegration, append(services.RO(), types.VerbUse)), // this rule allows cloud proxies to read // plugins of `openai` type, since Assist uses the OpenAI API and runs in Proxy. diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index d9e4025783770..4df01a8495005 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "net/url" - "path" "strings" "time" @@ -45,12 +44,6 @@ const ( // maxConcurrentUpdates specifies the maximum number of concurrent updates maxConcurrentUpdates = 3 - - // awsOIDCAccessDenied specifies the id for an aws oidc access denied alert - awsOIDCAccessDenied = "aws_oidc_access_denied" - - // awsOIDCTTL specifies the time to live for an access denied error - awsOIDCAlertTTL = time.Hour * 24 ) func (process *TeleportProcess) initDeployServiceUpdater() error { @@ -330,85 +323,8 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Contex // re-run the deploy service iam configuration script and update the // permissions. updater.Log.WithError(err).Warning("Re-run deploy service configuration script to update permissions.") - if err := updater.ensureClusterAlert(ctx, integration, awsRegion); err != nil { - updater.Log.WithError(err).Warning("Failed to ensure cluster alert.") - } } return trace.Wrap(err) } return nil } - -// ensureClusterAlert ensures a cluster alert is created if deploy service permissions -// need to be reconfigured. -func (updater *DeployServiceUpdater) ensureClusterAlert(ctx context.Context, integration types.Integration, awsRegion string) error { - // Acquire semaphore lease before attempting to create a cluster alert. - // If the lease cannot be acquired, a cluster alert is already being created by another instance. - semLock, err := updater.AuthClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ - SemaphoreKind: types.SemaphoreKindConnection, - SemaphoreName: "aws_oidc_access_denied_alert", - MaxLeases: 1, - Expires: updater.Clock.Now().Add(updateDeployAgentsInterval), - Holder: "aws_oidc_access_denied_alert", - }) - if err != nil { - if strings.Contains(err.Error(), teleport.MaxLeases) { - updater.Log.WithError(err).Debug("Cluster alert is already being created.") - return nil - } - return trace.Wrap(err) - } - defer func() { - if err := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); err != nil { - updater.Log.WithError(err).Error("Failed to cancel semaphore lease.") - } - }() - - alertLabels := map[string]string{ - types.AlertAWSOIDCAccessDenied: "yes", - types.AlertOnLogin: "yes", - types.AlertVerbPermit: fmt.Sprintf("%s:%s", types.KindInstance, types.VerbRead), - } - - alerts, err := updater.AuthClient.GetClusterAlerts(ctx, types.GetClusterAlertsRequest{ - Labels: alertLabels, - WithAcknowledged: true, - WithUntargeted: true, - }) - if err != nil { - return trace.Wrap(err) - } - - // Do not create another alert if an alert has already been created - if len(alerts) > 0 { - return nil - } - - clusterConfig, err := updater.AuthClient.GetClusterName() - if err != nil { - return trace.Wrap(err) - } - - scriptURL, err := url.Parse(fmt.Sprintf("https://%s/webapi/scripts/integrations/configure/deployservice-iam.sh", clusterConfig.GetClusterName())) - if err != nil { - return trace.Wrap(err) - } - values := scriptURL.Query() - values.Add("integrationName", integration.GetName()) - values.Add("awsRegion", awsRegion) - values.Add("role", path.Base(integration.GetAWSOIDCIntegrationSpec().RoleARN)) - values.Add("taskRole", "TASK_ROLE") // The task role needs to be supplied by the user - scriptURL.RawQuery = values.Encode() - - cmd := fmt.Sprintf(`bash -c "$(curl '%s')"`, scriptURL) - message := fmt.Sprintf("Open Amazon CloudShell and copy/paste the following command to reconfigure integration. Replace TASK_ROLE with desired deploy service task role name: %s", cmd) - - alert, err := types.NewClusterAlert(awsOIDCAccessDenied, message, types.WithAlertSeverity(types.AlertSeverity_LOW)) - if err != nil { - return trace.Wrap(err) - } - alert.Metadata.Labels = alertLabels - alert.SetExpiry(updater.Clock.Now().Add(awsOIDCAlertTTL)) - - return trace.Wrap(updater.AuthClient.UpsertClusterAlert(ctx, alert)) -} From ffde7f24449bf2feb47b4e057bfb728c2114175b Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Wed, 4 Oct 2023 17:14:58 -0700 Subject: [PATCH 18/20] Update err messages --- lib/integrations/awsoidc/deployservice_update.go | 8 +++++--- lib/service/awsoidc.go | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go index 6ac386926aa07..964a4f69add50 100644 --- a/lib/integrations/awsoidc/deployservice_update.go +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -99,7 +99,9 @@ func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, req // Update service with new task definition _, err = updateServiceOrRollback(ctx, clt, service, registerTaskDefinitionOut.TaskDefinition) if err != nil { - // If update failed, then rollback task definition + // If update failed, then rollback task definition. + // The update will be re-attempted during the next interval if it is still + // within the upgrade window or the critical upgrade flag is still enabled. _, rollbackErr := clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ TaskDefinition: registerTaskDefinitionOut.TaskDefinition.TaskDefinitionArn, }) @@ -199,7 +201,7 @@ func updateServiceOrRollback(ctx context.Context, clt DeployServiceClient, servi // then rollback service with previous task definition rollbackServiceOut, rollbackErr := clt.UpdateService(ctx, generateServiceWithTaskDefinition(service, aws.ToString(service.TaskDefinition))) if rollbackErr != nil { - return nil, trace.Wrap(err, "failed to rollback service: %v", err) + return nil, trace.NewAggregate(err, trace.Wrap(rollbackErr, "failed to rollback service")) } rollbackErr = serviceStableWaiter.Wait(ctx, &ecs.DescribeServicesInput{ @@ -207,7 +209,7 @@ func updateServiceOrRollback(ctx context.Context, clt DeployServiceClient, servi Cluster: updateServiceOut.Service.ClusterArn, }, waitDuration) if rollbackErr != nil { - return nil, trace.Wrap(err, "failed to rollback service: %v", err) + return nil, trace.NewAggregate(err, trace.Wrap(rollbackErr, "failed to rollback service")) } return nil, trace.Wrap(err) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 4df01a8495005..8cb8e1f2f2a94 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -315,14 +315,14 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Contex // The updater checks each integration/region combination, so // there will be regions where there is no ECS cluster deployed // for the integration. - updater.Log.WithError(err).Debugf("Integration %s does not manage any services within region %s.", integration.GetName(), awsRegion) + updater.Log.Debugf("Integration %s does not manage any services within region %s.", integration.GetName(), awsRegion) return nil case trace.IsAccessDenied(awslib.ConvertIAMv2Error(trace.Unwrap(err))): // The aws oidc role may lack permissions due to changes in teleport. // In this situation users should be notified that they will need to // re-run the deploy service iam configuration script and update the // permissions. - updater.Log.WithError(err).Warning("Re-run deploy service configuration script to update permissions.") + updater.Log.WithError(err).Warning("Update integration role and add missing permissions.") } return trace.Wrap(err) } From e381d260fd3c51e8282a0929fb829de4ce8d275c Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Tue, 10 Oct 2023 10:45:29 -0700 Subject: [PATCH 19/20] Check minimum compatible server version --- lib/service/awsoidc.go | 34 ++++++++++++++++++++------ lib/utils/ver.go | 12 +++++++++ lib/utils/ver_test.go | 55 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 8 deletions(-) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 8cb8e1f2f2a94..81269d58c341e 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -35,6 +35,7 @@ import ( "github.com/gravitational/teleport/lib/automaticupgrades" awslib "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/integrations/awsoidc" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/interval" ) @@ -86,13 +87,14 @@ func (process *TeleportProcess) initDeployServiceUpdater() error { } updater, err := NewDeployServiceUpdater(DeployServiceUpdaterConfig{ - Log: process.log.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "aws_oidc_deploy_service_updater")), - AuthClient: process.getInstanceClient(), - Clock: process.Clock, - TeleportClusterName: clusterNameConfig.GetClusterName(), - AWSOIDCProviderAddr: issuer, - CriticalEndpoint: criticalEndpoint, - VersionEndpoint: versionEndpoint, + Log: process.log.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "aws_oidc_deploy_service_updater")), + AuthClient: process.getInstanceClient(), + Clock: process.Clock, + TeleportClusterName: clusterNameConfig.GetClusterName(), + TeleportClusterVersion: resp.GetServerVersion(), + AWSOIDCProviderAddr: issuer, + CriticalEndpoint: criticalEndpoint, + VersionEndpoint: versionEndpoint, }) if err != nil { return trace.Wrap(err) @@ -112,6 +114,8 @@ type DeployServiceUpdaterConfig struct { Clock clockwork.Clock // TeleportClusterName specifies the teleport cluster name TeleportClusterName string + // TeleportClusterVersion specifies the teleport cluster version + TeleportClusterVersion string // AWSOIDCProvderAddr specifies the aws oidc provider address used to generate AWS OIDC tokens AWSOIDCProviderAddr string // CriticalEndpoint specifies the endpoint to check for critical updates @@ -130,6 +134,10 @@ func (cfg *DeployServiceUpdaterConfig) CheckAndSetDefaults() error { return trace.BadParameter("teleport cluster name required") } + if cfg.TeleportClusterVersion == "" { + return trace.BadParameter("teleport cluster version required") + } + if cfg.AWSOIDCProviderAddr == "" { return trace.BadParameter("aws oidc provider address required") } @@ -206,6 +214,18 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgents(ctx context.Conte // stableVersion has vX.Y.Z format, however the container image tag does not include the `v`. stableVersion = strings.TrimPrefix(stableVersion, "v") + // minServerVersion specifies the minimum version of the cluster required for + // updated agents to remain compatible with the cluster. + minServerVersion, err := utils.MajorSemver(stableVersion) + if err != nil { + return trace.Wrap(err) + } + + if !utils.MeetsVersion(updater.TeleportClusterVersion, minServerVersion) { + updater.Log.Debugf("Server does not meet the minimum required version for agents to be updated: %v.", minServerVersion) + return nil + } + databases, err := updater.AuthClient.GetDatabases(ctx) if err != nil { return trace.Wrap(err) diff --git a/lib/utils/ver.go b/lib/utils/ver.go index 34f796853e23a..9b4965ae19e9a 100644 --- a/lib/utils/ver.go +++ b/lib/utils/ver.go @@ -17,6 +17,8 @@ limitations under the License. package utils import ( + "fmt" + "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" ) @@ -68,6 +70,16 @@ func MinVerWithoutPreRelease(currentVersion, minVersion string) (bool, error) { return !currentSemver.LessThan(*minSemver), nil } +// MajorSemver returns the major version as a semver string. +// Ex: 13.4.3 -> 13.0.0 +func MajorSemver(version string) (string, error) { + ver, err := semver.NewVersion(version) + if err != nil { + return "", trace.Wrap(err) + } + return fmt.Sprintf("%d.0.0", ver.Major), nil +} + func versionStringToSemver(ver1, ver2 string) (*semver.Version, *semver.Version, error) { v1Semver, err := semver.NewVersion(ver1) if err != nil { diff --git a/lib/utils/ver_test.go b/lib/utils/ver_test.go index a78c0e30b96cc..17ee9ae54c989 100644 --- a/lib/utils/ver_test.go +++ b/lib/utils/ver_test.go @@ -16,7 +16,11 @@ limitations under the License. package utils -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/require" +) func TestMeetsVersion_emptyOrInvalid(t *testing.T) { // See TestVersions for more comprehensive tests. @@ -33,6 +37,55 @@ func TestMeetsVersion_emptyOrInvalid(t *testing.T) { } } +func TestMajorSemver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + version string + expected string + want bool + wantErr bool + }{ + { + name: "simple semver", + version: "13.4.2", + expected: "13.0.0", + want: true, + }, + { + name: "ignores suffix", + version: "13.4.2-dev", + expected: "13.0.0", + want: true, + }, + { + name: "empty version is rejected", + version: "", + expected: "", + wantErr: true, + }, + { + name: "incorrect version is rejected", + version: "13.4", // missing patch + expected: "13.0.0", + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := MajorSemver(tt.version) + if tt.wantErr { + require.Error(t, err) + } else { + require.Equal(t, tt.expected, got) + } + }) + } +} + func TestMinVerWithoutPreRelease(t *testing.T) { t.Parallel() From d27caf2834c89fdef16932e405307f7aa87f2c0d Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Tue, 10 Oct 2023 14:42:12 -0700 Subject: [PATCH 20/20] Update log msg --- lib/service/awsoidc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go index 81269d58c341e..0ad4d007af523 100644 --- a/lib/service/awsoidc.go +++ b/lib/service/awsoidc.go @@ -222,7 +222,7 @@ func (updater *DeployServiceUpdater) updateDeployServiceAgents(ctx context.Conte } if !utils.MeetsVersion(updater.TeleportClusterVersion, minServerVersion) { - updater.Log.Debugf("Server does not meet the minimum required version for agents to be updated: %v.", minServerVersion) + updater.Log.Debugf("Skipping update. %v agents will not be compatible with a %v cluster.", stableVersion, updater.TeleportClusterVersion) return nil }