diff --git a/api/client/client.go b/api/client/client.go index 9dbb6b5a6da47..5a403934e7929 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -3907,6 +3907,22 @@ func (c *Client) ListIntegrations(ctx context.Context, pageSize int, nextKey str return integrations, resp.GetNextKey(), nil } +// ListAllIntegrations returns the list of all Integrations. +func (c *Client) ListAllIntegrations(ctx context.Context) ([]types.Integration, error) { + var result []types.Integration + var nextKey string + for { + integrations, nextKey, err := c.ListIntegrations(ctx, 0, nextKey) + if err != nil { + return nil, trace.Wrap(err) + } + result = append(result, integrations...) + if nextKey == "" { + return result, nil + } + } +} + // GetIntegration returns an Integration by its name. func (c *Client) GetIntegration(ctx context.Context, name string) (types.Integration, error) { ig, err := c.integrationsClient().GetIntegration(ctx, &integrationpb.GetIntegrationRequest{ diff --git a/api/types/maintenance.go b/api/types/maintenance.go index a71976d297187..9cab6a9ad4765 100644 --- a/api/types/maintenance.go +++ b/api/types/maintenance.go @@ -147,6 +147,10 @@ type ClusterMaintenanceConfig interface { // SetAgentUpgradeWindow sets the agent upgrade window. SetAgentUpgradeWindow(win AgentUpgradeWindow) + // WithinUpgradeWindow returns true if the time is within the configured + // upgrade window. + WithinUpgradeWindow(t time.Time) bool + CheckAndSetDefaults() error } @@ -229,3 +233,28 @@ func (m *ClusterMaintenanceConfigV1) GetAgentUpgradeWindow() (win AgentUpgradeWi func (m *ClusterMaintenanceConfigV1) SetAgentUpgradeWindow(win AgentUpgradeWindow) { m.Spec.AgentUpgrades = &win } + +// WithinUpgradeWindow returns true if the time is within the configured +// upgrade window. +func (m *ClusterMaintenanceConfigV1) WithinUpgradeWindow(t time.Time) bool { + upgradeWindow, ok := m.GetAgentUpgradeWindow() + if !ok { + return false + } + + if len(upgradeWindow.Weekdays) == 0 { + if int(upgradeWindow.UTCStartHour) == t.Hour() { + return true + } + } + + weekday := t.Weekday().String() + for _, upgradeWeekday := range upgradeWindow.Weekdays { + if weekday == upgradeWeekday { + if int(upgradeWindow.UTCStartHour) == t.Hour() { + return true + } + } + } + return false +} diff --git a/api/types/maintenance_test.go b/api/types/maintenance_test.go index 990d1f9670f38..203006a8dee37 100644 --- a/api/types/maintenance_test.go +++ b/api/types/maintenance_test.go @@ -214,3 +214,60 @@ func TestWeekdayParser(t *testing.T) { require.Equal(t, tt.expect, day) } } + +func TestWithinUpgradeWindow(t *testing.T) { + t.Parallel() + + tests := []struct { + desc string + upgradeWindow AgentUpgradeWindow + date string + withinWindow bool + }{ + { + desc: "within upgrade window", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: true, + }, + { + desc: "not within upgrade window", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + }, + date: "Mon, 02 Jan 2006 09:04:05 UTC", + withinWindow: false, + }, + { + desc: "within upgrade window weekday", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + Weekdays: []string{"Monday"}, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: true, + }, + { + desc: "not within upgrade window weekday", + upgradeWindow: AgentUpgradeWindow{ + UTCStartHour: 8, + Weekdays: []string{"Tuesday"}, + }, + date: "Mon, 02 Jan 2006 08:04:05 UTC", + withinWindow: false, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + cmc := NewClusterMaintenanceConfig() + cmc.SetAgentUpgradeWindow(tt.upgradeWindow) + + date, err := time.Parse(time.RFC1123, tt.date) + require.NoError(t, err) + require.Equal(t, tt.withinWindow, cmc.WithinUpgradeWindow(date)) + }) + } +} diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index 7907abb684704..6d52298059bd1 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -616,6 +616,7 @@ func roleSpecForProxy(clusterName string) types.RoleSpecV6 { types.NewRule(types.KindDatabaseService, services.RO()), types.NewRule(types.KindSAMLIdPServiceProvider, services.RO()), types.NewRule(types.KindUserGroup, services.RO()), + types.NewRule(types.KindClusterMaintenanceConfig, services.RO()), types.NewRule(types.KindIntegration, append(services.RO(), types.VerbUse)), // this rule allows cloud proxies to read // plugins of `openai` type, since Assist uses the OpenAI API and runs in Proxy. diff --git a/lib/automaticupgrades/config.go b/lib/automaticupgrades/config.go index 1333df6935283..56d0f4a297c33 100644 --- a/lib/automaticupgrades/config.go +++ b/lib/automaticupgrades/config.go @@ -26,6 +26,9 @@ import ( const ( // automaticUpgradesEnvar defines the env var to lookup when deciding whether to enable AutomaticUpgrades feature. automaticUpgradesEnvar = "TELEPORT_AUTOMATIC_UPGRADES" + + // automaticUpgradesChannelEnvar defines a customer automatic upgrades version release channel. + automaticUpgradesChannelEnvar = "TELEPORT_AUTOMATIC_UPGRADES_CHANNEL" ) // IsEnabled reads the TELEPORT_AUTOMATIC_UPGRADES and returns whether Automatic Upgrades are enabled or disabled. @@ -46,3 +49,10 @@ func IsEnabled() bool { return automaticUpgrades } + +// GetChannel returns the TELEPORT_AUTOMATIC_UPGRADES_CHANNEL value. +// Example of an acceptable value for TELEPORT_AUTOMATIC_UPGRADES_CHANNEL is: +// https://updates.releases.teleport.dev/v1/stable/cloud +func GetChannel() string { + return os.Getenv(automaticUpgradesChannelEnvar) +} diff --git a/lib/automaticupgrades/version.go b/lib/automaticupgrades/version.go index 607d32ef9bbd9..d1f6c393f1605 100644 --- a/lib/automaticupgrades/version.go +++ b/lib/automaticupgrades/version.go @@ -34,22 +34,54 @@ const ( // stableCloudVersionPath is the URL path that returns the current stable/cloud version. stableCloudVersionPath = "/v1/stable/cloud/version" + + // stableCloudCriticalPath is the URL path that returns the stable/cloud critical flag. + stableCloudCriticalPath = "/v1/stable/cloud/critical" ) // Version returns the version that should be used for installing Teleport Services // This is used when installing agents using scripts. // Even when Teleport Auth/Proxy is using vX, the agents must always respect this version. -func Version(ctx context.Context, baseURL string) (string, error) { - if baseURL == "" { - baseURL = stableCloudVersionBaseURL +func Version(ctx context.Context, versionURL string) (string, error) { + versionURL, err := getVersionURL(versionURL) + if err != nil { + return "", trace.Wrap(err) } - fullURL, err := url.JoinPath(baseURL, stableCloudVersionPath) + resp, err := sendRequest(ctx, versionURL) if err != nil { return "", trace.Wrap(err) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + return resp, nil +} + +// Critical returns true if a critical upgrade is available. +func Critical(ctx context.Context, criticalURL string) (bool, error) { + criticalURL, err := getCriticalURL(criticalURL) + if err != nil { + return false, trace.Wrap(err) + } + + critical, err := sendRequest(ctx, criticalURL) + if err != nil { + return false, trace.Wrap(err) + } + + // Expects critical endpoint to return either the string "yes" or "no" + switch critical { + case "yes": + return true, nil + case "no": + return false, nil + default: + return false, trace.BadParameter("critical endpoint returned an unexpected value: %v", critical) + } +} + +// sendRequest sends a GET request to the reqURL and returns the response value +func sendRequest(ctx context.Context, reqURL string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { return "", trace.Wrap(err) } @@ -69,7 +101,29 @@ func Version(ctx context.Context, baseURL string) (string, error) { return "", trace.BadParameter("invalid status code %d, body: %s", resp.StatusCode, string(body)) } - versionString := strings.TrimSpace(string(body)) + return strings.TrimSpace(string(body)), trace.Wrap(err) +} + +// getVersionURL returns the versionURL or the default stable/cloud version url. +func getVersionURL(versionURL string) (string, error) { + if versionURL != "" { + return versionURL, nil + } + cloudStableVersionURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudVersionPath) + if err != nil { + return "", trace.Wrap(err) + } + return cloudStableVersionURL, nil +} - return versionString, trace.Wrap(err) +// getCriticalURL returns the criticalURL or the default stable/cloud critical url. +func getCriticalURL(criticalURL string) (string, error) { + if criticalURL != "" { + return criticalURL, nil + } + cloudStableCriticalURL, err := url.JoinPath(stableCloudVersionBaseURL, stableCloudCriticalPath) + if err != nil { + return "", trace.Wrap(err) + } + return cloudStableCriticalURL, nil } diff --git a/lib/automaticupgrades/version_test.go b/lib/automaticupgrades/version_test.go index e540dbc1cb34b..4ba8d299201e3 100644 --- a/lib/automaticupgrades/version_test.go +++ b/lib/automaticupgrades/version_test.go @@ -20,6 +20,7 @@ import ( "context" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/gravitational/trace" @@ -68,19 +69,133 @@ func TestVersion(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.URL.Path, "/v1/stable/cloud/version") + assert.Equal(t, "/v1/stable/cloud/version", r.URL.Path) w.WriteHeader(tt.mockStatusCode) w.Write([]byte(tt.mockResponseString)) })) defer httpTestServer.Close() - v, err := Version(ctx, httpTestServer.URL) + versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version") + require.NoError(t, err) + + v, err := Version(ctx, versionURL) tt.errCheck(t, err) if err != nil { return } - require.Equal(t, v, tt.expectedVersion) + require.Equal(t, tt.expectedVersion, v) + }) + } +} + +func TestCritical(t *testing.T) { + ctx := context.Background() + + isBadParameterErr := func(tt require.TestingT, err error, i ...any) { + require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) + } + + for _, tt := range []struct { + name string + mockStatusCode int + mockResponseString string + errCheck require.ErrorAssertionFunc + expectedCritical bool + }{ + { + name: "critical available", + mockStatusCode: http.StatusOK, + mockResponseString: "yes\n", + errCheck: require.NoError, + expectedCritical: true, + }, + { + name: "critical is not available", + mockStatusCode: http.StatusOK, + mockResponseString: "no\n", + errCheck: require.NoError, + expectedCritical: false, + }, + { + name: "invalid status code (500)", + mockStatusCode: http.StatusInternalServerError, + errCheck: isBadParameterErr, + }, + { + name: "invalid status code (403)", + mockStatusCode: http.StatusForbidden, + errCheck: isBadParameterErr, + }, + } { + t.Run(tt.name, func(t *testing.T) { + httpTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/stable/cloud/critical", r.URL.Path) + w.WriteHeader(tt.mockStatusCode) + w.Write([]byte(tt.mockResponseString)) + })) + defer httpTestServer.Close() + + criticalURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/critical") + require.NoError(t, err) + + v, err := Critical(ctx, criticalURL) + tt.errCheck(t, err) + if err != nil { + return + } + + require.Equal(t, tt.expectedCritical, v) + }) + } +} + +func TestGetVersionURL(t *testing.T) { + for _, tt := range []struct { + name string + versionURL string + expectedURL string + }{ + { + name: "default stable/cloud version url", + versionURL: "", + expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/version", + }, + { + name: "custom version url", + versionURL: "https://custom.dev/version", + expectedURL: "https://custom.dev/version", + }, + } { + t.Run(tt.name, func(t *testing.T) { + v, err := getVersionURL(tt.versionURL) + require.NoError(t, err) + require.Equal(t, tt.expectedURL, v) + }) + } +} + +func TestGetCriticalURL(t *testing.T) { + for _, tt := range []struct { + name string + criticalURL string + expectedURL string + }{ + { + name: "default stable/cloud critical url", + criticalURL: "", + expectedURL: "https://updates.releases.teleport.dev/v1/stable/cloud/critical", + }, + { + name: "custom critical url", + criticalURL: "https://custom.dev/critical", + expectedURL: "https://custom.dev/critical", + }, + } { + t.Run(tt.name, func(t *testing.T) { + v, err := getCriticalURL(tt.criticalURL) + require.NoError(t, err) + require.Equal(t, tt.expectedURL, v) }) } } diff --git a/lib/cloud/aws/policy_statements.go b/lib/cloud/aws/policy_statements.go index 2dc33b954e621..25c7bd71848bd 100644 --- a/lib/cloud/aws/policy_statements.go +++ b/lib/cloud/aws/policy_statements.go @@ -50,7 +50,7 @@ func StatementForECSManageService() *Statement { Actions: []string{ "ecs:DescribeClusters", "ecs:CreateCluster", "ecs:PutClusterCapacityProviders", "ecs:DescribeServices", "ecs:CreateService", "ecs:UpdateService", - "ecs:RegisterTaskDefinition", + "ecs:RegisterTaskDefinition", "ecs:DescribeTaskDefinition", "ecs:DeregisterTaskDefinition", // EC2 DescribeSecurityGroups is required so that the user can list the SG and then pick which ones they want to apply to the ECS Service. "ec2:DescribeSecurityGroups", diff --git a/lib/integrations/awsoidc/deployservice.go b/lib/integrations/awsoidc/deployservice.go index 2a89717e94657..77060ed908939 100644 --- a/lib/integrations/awsoidc/deployservice.go +++ b/lib/integrations/awsoidc/deployservice.go @@ -58,6 +58,11 @@ var ( ) const ( + // distrolessTeleportOSS is the distroless image of the OSS version of Teleport + distrolessTeleportOSS = "public.ecr.aws/gravitational/teleport-distroless" + // distrolessTeleportEnt is the distroless image of the Enterprise version of Teleport + distrolessTeleportEnt = "public.ecr.aws/gravitational/teleport-ent-distroless" + // clusterStatusActive is the string representing an ACTIVE ECS Cluster. clusterStatusActive = "ACTIVE" // clusterStatusInactive is the string representing an INACTIVE ECS Cluster. @@ -175,12 +180,26 @@ func normalizeECSResourceName(name string) string { return replacer.Replace(name) } +// normalizeECSClusterName returns the normalized ECS Cluster Name +func normalizeECSClusterName(teleportClusterName string) string { + return normalizeECSResourceName(fmt.Sprintf("%s-teleport", teleportClusterName)) +} + +// normalizeECSServiceName returns the normalized ECS Service Name +func normalizeECSServiceName(teleportClusterName, deploymentMode string) string { + return normalizeECSResourceName(fmt.Sprintf("%s-teleport-%s", teleportClusterName, deploymentMode)) +} + +// normalizeECSTaskName returns the normalized ECS TaskDefinition Family +func normalizeECSTaskName(teleportClusterName, deploymentMode string) string { + return normalizeECSResourceName(fmt.Sprintf("%s-teleport-%s", teleportClusterName, deploymentMode)) +} + // CheckAndSetDefaults checks if the required fields are present. func (r *DeployServiceRequest) CheckAndSetDefaults() error { if r.TeleportClusterName == "" { return trace.BadParameter("teleport cluster name is required") } - baseResourceName := normalizeECSResourceName(r.TeleportClusterName) if r.TeleportVersionTag == "" { r.TeleportVersionTag = teleport.Version @@ -211,17 +230,17 @@ func (r *DeployServiceRequest) CheckAndSetDefaults() error { } if r.ClusterName == nil || *r.ClusterName == "" { - clusterName := fmt.Sprintf("%s-teleport", baseResourceName) + clusterName := normalizeECSClusterName(r.TeleportClusterName) r.ClusterName = &clusterName } if r.ServiceName == nil || *r.ServiceName == "" { - serviceName := fmt.Sprintf("%s-teleport-%s", baseResourceName, r.DeploymentMode) + serviceName := normalizeECSServiceName(r.TeleportClusterName, r.DeploymentMode) r.ServiceName = &serviceName } if r.TaskName == nil || *r.TaskName == "" { - taskName := fmt.Sprintf("%s-teleport-%s", baseResourceName, r.DeploymentMode) + taskName := normalizeECSTaskName(r.TeleportClusterName, r.DeploymentMode) r.TaskName = &taskName } @@ -285,10 +304,18 @@ type DeployServiceClient interface { // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.CreateService CreateService(ctx context.Context, params *ecs.CreateServiceInput, optFns ...func(*ecs.Options)) (*ecs.CreateServiceOutput, error) + // DescribeTaskDefinition describes the task definition. + // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DescribeTaskDefinition + DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) + // RegisterTaskDefinition registers a new task definition from the supplied family and containerDefinitions. // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.RegisterTaskDefinition RegisterTaskDefinition(ctx context.Context, params *ecs.RegisterTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.RegisterTaskDefinitionOutput, error) + // DeregisterTaskDefinition deregisters the task definition. + // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ecs@v1.27.1#Client.DeregisterTaskDefinition + DeregisterTaskDefinition(ctx context.Context, params *ecs.DeregisterTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DeregisterTaskDefinitionOutput, error) + // TokenService are the required methods to manage the IAM Join Token. // When the deployed service connects to the cluster, it will use the IAM Join method. // Before deploying the service, it must ensure that the token exists and has the appropriate token rul. @@ -427,11 +454,7 @@ func DeployService(ctx context.Context, clt DeployServiceClient, req DeployServi // upsertTask ensures a TaskDefinition with TaskName exists func upsertTask(ctx context.Context, clt DeployServiceClient, req DeployServiceRequest, configB64 string) (*ecsTypes.TaskDefinition, error) { - teleportFlavor := types.PackageNameOSS - if modules.GetModules().BuildType() == modules.BuildEnterprise { - teleportFlavor = types.PackageNameEnt - } - taskAgentContainerImage := fmt.Sprintf("public.ecr.aws/gravitational/%s-distroless:%s", teleportFlavor, req.TeleportVersionTag) + taskAgentContainerImage := getDistrolessTeleportImage(req.TeleportVersionTag) taskDefOut, err := clt.RegisterTaskDefinition(ctx, &ecs.RegisterTaskDefinitionInput{ Family: req.TaskName, @@ -450,11 +473,16 @@ func upsertTask(ctx context.Context, clt DeployServiceClient, req DeployServiceR Value: aws.String("true"), }}, Command: []string{ + // --rewrite 15:3 rewrites SIGTERM -> SIGQUIT. This enables graceful shutdown of teleport + "--rewrite", + "15:3", + "--", + "teleport", "start", "--config-string", configB64, }, - EntryPoint: []string{"teleport"}, + EntryPoint: []string{"/usr/bin/dumb-init"}, Image: &taskAgentContainerImage, Name: &taskAgentContainerName, LogConfiguration: &ecsTypes.LogConfiguration{ @@ -678,3 +706,12 @@ func upsertService(ctx context.Context, clt DeployServiceClient, req DeployServi return createServiceOut.Service, nil } + +// getDistrolessTeleportImage returns the distroless teleport image string +func getDistrolessTeleportImage(version string) string { + teleportImage := distrolessTeleportOSS + if modules.GetModules().BuildType() == modules.BuildEnterprise { + teleportImage = distrolessTeleportEnt + } + return fmt.Sprintf("%s:%s", teleportImage, version) +} diff --git a/lib/integrations/awsoidc/deployservice_update.go b/lib/integrations/awsoidc/deployservice_update.go new file mode 100644 index 0000000000000..964a4f69add50 --- /dev/null +++ b/lib/integrations/awsoidc/deployservice_update.go @@ -0,0 +1,246 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package awsoidc + +import ( + "context" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" +) + +// waitDuration specifies the amount of time to wait for a service to become healthy after an update. +const waitDuration = time.Minute * 5 + +// UpdateServiceRequest contains the required fields to update a Teleport Service. +type UpdateServiceRequest struct { + // TeleportClusterName specifies the teleport cluster name + TeleportClusterName string + // TeleportVersionTag specifies the desired teleport version in the format "13.4.0" + TeleportVersionTag string + // OwnershipTags specifies ownership tags + OwnershipTags AWSTags +} + +// CheckAndSetDefaults checks and sets default config values. +func (req *UpdateServiceRequest) CheckAndSetDefaults() error { + if req.TeleportClusterName == "" { + return trace.BadParameter("teleport cluster name required") + } + + if req.TeleportVersionTag == "" { + return trace.BadParameter("teleport version tag required") + } + + if req.OwnershipTags == nil { + return trace.BadParameter("ownership tags required") + } + + return nil +} + +// UpdateDeployServiceAgent updates the deploy service agent with the specified teleportVersionTag. +func UpdateDeployServiceAgent(ctx context.Context, clt DeployServiceClient, req UpdateServiceRequest) error { + if err := req.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + teleportImage := getDistrolessTeleportImage(req.TeleportVersionTag) + service, err := getManagedService(ctx, clt, req.TeleportClusterName, req.OwnershipTags) + if err != nil { + return trace.Wrap(err) + } + + taskDefinition, err := getManagedTaskDefinition(ctx, clt, aws.ToString(service.TaskDefinition), req.OwnershipTags) + if err != nil { + return trace.Wrap(err) + } + + currentTeleportImage, err := getTaskDefinitionTeleportImage(taskDefinition) + if err != nil { + return trace.Wrap(err) + } + + // There is no need to update the ecs service if the ecs service is already + // running the latest stable version of teleport. + if currentTeleportImage == teleportImage { + return nil + } + + registerTaskDefinitionIn, err := generateTaskDefinitionWithImage(taskDefinition, teleportImage, req.OwnershipTags.ToECSTags()) + if err != nil { + return trace.Wrap(err) + } + + registerTaskDefinitionOut, err := clt.RegisterTaskDefinition(ctx, registerTaskDefinitionIn) + if err != nil { + return trace.Wrap(err) + } + + // Update service with new task definition + _, err = updateServiceOrRollback(ctx, clt, service, registerTaskDefinitionOut.TaskDefinition) + if err != nil { + // If update failed, then rollback task definition. + // The update will be re-attempted during the next interval if it is still + // within the upgrade window or the critical upgrade flag is still enabled. + _, rollbackErr := clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + TaskDefinition: registerTaskDefinitionOut.TaskDefinition.TaskDefinitionArn, + }) + if rollbackErr != nil { + return trace.NewAggregate(err, trace.Wrap(rollbackErr, "failed to rollback task definition")) + } + return trace.Wrap(err) + } + + // Attempt to deregister previous task definition but ignore error on failure + _, err = clt.DeregisterTaskDefinition(ctx, &ecs.DeregisterTaskDefinitionInput{ + TaskDefinition: taskDefinition.TaskDefinitionArn, + }) + if err != nil { + logrus.WithError(err).Warning("Failed to deregister task definition.") + } + + return nil +} + +func getManagedService(ctx context.Context, clt DeployServiceClient, teleportClusterName string, ownershipTags AWSTags) (*ecsTypes.Service, error) { + var ecsServiceNames []string + for _, deploymentMode := range DeploymentModes { + ecsServiceNames = append(ecsServiceNames, normalizeECSServiceName(teleportClusterName, deploymentMode)) + } + + describeServicesOut, err := clt.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Cluster: aws.String(normalizeECSClusterName(teleportClusterName)), + Services: ecsServiceNames, + Include: []ecsTypes.ServiceField{ecsTypes.ServiceFieldTags}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + if len(describeServicesOut.Services) == 0 { + return nil, trace.NotFound("services %v not found", ecsServiceNames) + } + if len(describeServicesOut.Services) != 1 { + return nil, trace.BadParameter("expected 1 service, but got %d", len(describeServicesOut.Services)) + } + service := describeServicesOut.Services[0] + + if !ownershipTags.MatchesECSTags(service.Tags) { + return nil, trace.Errorf("ECS Service %q already exists but is not managed by Teleport. "+ + "Add the following tags to allow Teleport to manage this service: %s", aws.ToString(service.ServiceName), ownershipTags) + } + // If the LaunchType is the required one, than we can update the current Service. + // Otherwise we have to delete it. + if service.LaunchType != ecsTypes.LaunchTypeFargate { + return nil, trace.Errorf("ECS Service %q already exists but has an invalid LaunchType %q. Delete the Service and try again.", aws.ToString(service.ServiceName), service.LaunchType) + } + + return &service, nil +} + +func getManagedTaskDefinition(ctx context.Context, clt DeployServiceClient, taskDefinitionName string, ownershipTags AWSTags) (*ecsTypes.TaskDefinition, error) { + describeTaskDefinitionOut, err := clt.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{ + TaskDefinition: aws.String(taskDefinitionName), + Include: []ecsTypes.TaskDefinitionField{ecsTypes.TaskDefinitionFieldTags}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + if !ownershipTags.MatchesECSTags(describeTaskDefinitionOut.Tags) { + return nil, trace.Errorf("ECS Task Definition %q already exists but is not managed by Teleport. "+ + "Add the following tags to allow Teleport to manage this task definition: %s", taskDefinitionName, ownershipTags) + } + return describeTaskDefinitionOut.TaskDefinition, nil +} + +func getTaskDefinitionTeleportImage(taskDefinition *ecsTypes.TaskDefinition) (string, error) { + if len(taskDefinition.ContainerDefinitions) != 1 { + return "", trace.BadParameter("expected 1 task container definition, but got %d", len(taskDefinition.ContainerDefinitions)) + } + return aws.ToString(taskDefinition.ContainerDefinitions[0].Image), nil +} + +// updateServiceOrRollback attempts to update the service with the specified task definition. +// The service will be rolled back if the service fails to become healthy. +func updateServiceOrRollback(ctx context.Context, clt DeployServiceClient, service *ecsTypes.Service, taskDefinition *ecsTypes.TaskDefinition) (*ecsTypes.Service, error) { + // Update service with new task definition + updateServiceOut, err := clt.UpdateService(ctx, generateServiceWithTaskDefinition(service, aws.ToString(taskDefinition.TaskDefinitionArn))) + if err != nil { + return nil, trace.Wrap(err) + } + + serviceStableWaiter := ecs.NewServicesStableWaiter(clt) + err = serviceStableWaiter.Wait(ctx, &ecs.DescribeServicesInput{ + Services: []string{aws.ToString(updateServiceOut.Service.ServiceName)}, + Cluster: updateServiceOut.Service.ClusterArn, + }, waitDuration) + if err == nil { + return updateServiceOut.Service, nil + } + + // If the service fails to reach a stable state within the allowed wait time, + // then rollback service with previous task definition + rollbackServiceOut, rollbackErr := clt.UpdateService(ctx, generateServiceWithTaskDefinition(service, aws.ToString(service.TaskDefinition))) + if rollbackErr != nil { + return nil, trace.NewAggregate(err, trace.Wrap(rollbackErr, "failed to rollback service")) + } + + rollbackErr = serviceStableWaiter.Wait(ctx, &ecs.DescribeServicesInput{ + Services: []string{aws.ToString(rollbackServiceOut.Service.ServiceName)}, + Cluster: updateServiceOut.Service.ClusterArn, + }, waitDuration) + if rollbackErr != nil { + return nil, trace.NewAggregate(err, trace.Wrap(rollbackErr, "failed to rollback service")) + } + + return nil, trace.Wrap(err) +} + +// generateTaskDefinitionWithImage returns new register task definition input with the desired teleport image +func generateTaskDefinitionWithImage(taskDefinition *ecsTypes.TaskDefinition, teleportImage string, tags []ecsTypes.Tag) (*ecs.RegisterTaskDefinitionInput, error) { + if len(taskDefinition.ContainerDefinitions) != 1 { + return nil, trace.BadParameter("expected 1 task container definition, but got %d", len(taskDefinition.ContainerDefinitions)) + } + + // Copy container definition and replace the teleport image with desired version + newContainerDefinition := new(ecsTypes.ContainerDefinition) + awsutil.Copy(newContainerDefinition, &taskDefinition.ContainerDefinitions[0]) + newContainerDefinition.Image = aws.String(teleportImage) + + // Copy task definition and replace container definitions + registerTaskDefinitionIn := new(ecs.RegisterTaskDefinitionInput) + awsutil.Copy(registerTaskDefinitionIn, taskDefinition) + registerTaskDefinitionIn.ContainerDefinitions = []ecsTypes.ContainerDefinition{*newContainerDefinition} + registerTaskDefinitionIn.Tags = tags + + return registerTaskDefinitionIn, nil +} + +// generateServiceWithTaskDefinition returns new update service input with the desired task definition +func generateServiceWithTaskDefinition(service *ecsTypes.Service, taskDefinitionName string) *ecs.UpdateServiceInput { + updateServiceIn := new(ecs.UpdateServiceInput) + awsutil.Copy(updateServiceIn, service) + updateServiceIn.Service = service.ServiceName + updateServiceIn.Cluster = service.ClusterArn + updateServiceIn.TaskDefinition = aws.String(taskDefinitionName) + return updateServiceIn +} diff --git a/lib/integrations/awsoidc/deployservice_update_test.go b/lib/integrations/awsoidc/deployservice_update_test.go new file mode 100644 index 0000000000000..9ef8b7c1fb176 --- /dev/null +++ b/lib/integrations/awsoidc/deployservice_update_test.go @@ -0,0 +1,140 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package awsoidc + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +func TestGenerateServiceWithTaskDefinition(t *testing.T) { + service := &ecsTypes.Service{ + ServiceName: aws.String("service"), + ClusterArn: aws.String("cluster"), + TaskDefinition: aws.String("task-definition-v1"), + NetworkConfiguration: &ecsTypes.NetworkConfiguration{ + AwsvpcConfiguration: &ecsTypes.AwsVpcConfiguration{ + AssignPublicIp: ecsTypes.AssignPublicIpEnabled, + Subnets: []string{"subnet"}, + }, + }, + PropagateTags: ecsTypes.PropagateTagsService, + } + + expected := &ecs.UpdateServiceInput{ + Service: aws.String("service"), + Cluster: aws.String("cluster"), + TaskDefinition: aws.String("task-definition-v2"), + NetworkConfiguration: &ecsTypes.NetworkConfiguration{ + AwsvpcConfiguration: &ecsTypes.AwsVpcConfiguration{ + AssignPublicIp: ecsTypes.AssignPublicIpEnabled, + Subnets: []string{"subnet"}, + }, + }, + PropagateTags: ecsTypes.PropagateTagsService, + } + + require.Equal(t, expected, generateServiceWithTaskDefinition(service, "task-definition-v2")) +} + +func TestGenerateTaskDefinitionWithImage(t *testing.T) { + taskDefinition := &ecsTypes.TaskDefinition{ + Family: aws.String("example-task"), + RequiresCompatibilities: []ecsTypes.Compatibility{ + ecsTypes.CompatibilityFargate, + }, + Cpu: &taskCPU, + Memory: &taskMem, + + NetworkMode: ecsTypes.NetworkModeAwsvpc, + TaskRoleArn: aws.String("task-role-arn"), + ExecutionRoleArn: aws.String("task-role-arn"), + ContainerDefinitions: []ecsTypes.ContainerDefinition{{ + Environment: []ecsTypes.KeyValuePair{{ + Name: aws.String(types.InstallMethodAWSOIDCDeployServiceEnvVar), + Value: aws.String("true"), + }}, + Command: []string{ + "start", + "--config-string", + "config-bytes", + }, + EntryPoint: []string{"teleport"}, + Image: aws.String("image-v1"), + Name: &taskAgentContainerName, + LogConfiguration: &ecsTypes.LogConfiguration{ + LogDriver: ecsTypes.LogDriverAwslogs, + Options: map[string]string{ + "awslogs-group": "ecs-cluster", + "awslogs-region": "us-west-2", + "awslogs-create-group": "true", + "awslogs-stream-prefix": "service/example-task", + }, + }, + }}, + } + tags := []ecsTypes.Tag{ + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + } + + expected := &ecs.RegisterTaskDefinitionInput{ + Family: aws.String("example-task"), + RequiresCompatibilities: []ecsTypes.Compatibility{ + ecsTypes.CompatibilityFargate, + }, + Cpu: &taskCPU, + Memory: &taskMem, + + NetworkMode: ecsTypes.NetworkModeAwsvpc, + TaskRoleArn: aws.String("task-role-arn"), + ExecutionRoleArn: aws.String("task-role-arn"), + ContainerDefinitions: []ecsTypes.ContainerDefinition{{ + Environment: []ecsTypes.KeyValuePair{{ + Name: aws.String(types.InstallMethodAWSOIDCDeployServiceEnvVar), + Value: aws.String("true"), + }}, + Command: []string{ + "start", + "--config-string", + "config-bytes", + }, + EntryPoint: []string{"teleport"}, + Image: aws.String("image-v2"), + Name: &taskAgentContainerName, + LogConfiguration: &ecsTypes.LogConfiguration{ + LogDriver: ecsTypes.LogDriverAwslogs, + Options: map[string]string{ + "awslogs-group": "ecs-cluster", + "awslogs-region": "us-west-2", + "awslogs-create-group": "true", + "awslogs-stream-prefix": "service/example-task", + }, + }, + }}, + Tags: tags, + } + + input, err := generateTaskDefinitionWithImage(taskDefinition, "image-v2", tags) + require.NoError(t, err) + require.Equal(t, expected, input) +} diff --git a/lib/service/awsoidc.go b/lib/service/awsoidc.go new file mode 100644 index 0000000000000..0ad4d007af523 --- /dev/null +++ b/lib/service/awsoidc.go @@ -0,0 +1,350 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package service + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/automaticupgrades" + awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/integrations/awsoidc" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/interval" +) + +const ( + // updateDeployAgentsInterval specifies how frequently to check for available updates. + updateDeployAgentsInterval = time.Minute * 30 + + // maxConcurrentUpdates specifies the maximum number of concurrent updates + maxConcurrentUpdates = 3 +) + +func (process *TeleportProcess) initDeployServiceUpdater() error { + // start process only after teleport process has started + if _, err := process.WaitForEvent(process.GracefulExitContext(), TeleportReadyEvent); err != nil { + return trace.Wrap(err) + } + + resp, err := process.getInstanceClient().Ping(process.GracefulExitContext()) + if err != nil { + return trace.Wrap(err) + } + + if !resp.ServerFeatures.AutomaticUpgrades { + return nil + } + + // If criticalEndpoint or versionEndpoint are empty, the default stable/cloud endpoint will be used + var criticalEndpoint string + var versionEndpoint string + if automaticupgrades.GetChannel() != "" { + criticalEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "critical") + if err != nil { + return trace.Wrap(err) + } + versionEndpoint, err = url.JoinPath(automaticupgrades.GetChannel(), "version") + if err != nil { + return trace.Wrap(err) + } + } + + issuer, err := awsoidc.IssuerFromPublicAddress(process.proxyPublicAddr().Addr) + if err != nil { + return trace.Wrap(err) + } + + clusterNameConfig, err := process.getInstanceClient().GetClusterName() + if err != nil { + return trace.Wrap(err) + } + + updater, err := NewDeployServiceUpdater(DeployServiceUpdaterConfig{ + Log: process.log.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "aws_oidc_deploy_service_updater")), + AuthClient: process.getInstanceClient(), + Clock: process.Clock, + TeleportClusterName: clusterNameConfig.GetClusterName(), + TeleportClusterVersion: resp.GetServerVersion(), + AWSOIDCProviderAddr: issuer, + CriticalEndpoint: criticalEndpoint, + VersionEndpoint: versionEndpoint, + }) + if err != nil { + return trace.Wrap(err) + } + + process.log.Infof("The new service has started successfully. Checking for deploy service updates every %v.", updateDeployAgentsInterval) + return trace.Wrap(updater.Run(process.GracefulExitContext())) +} + +// DeployServiceUpdaterConfig specifies updater configs +type DeployServiceUpdaterConfig struct { + // Log is the logger + Log *logrus.Entry + // AuthClient is the auth api client + AuthClient *auth.Client + // Clock is the local clock + Clock clockwork.Clock + // TeleportClusterName specifies the teleport cluster name + TeleportClusterName string + // TeleportClusterVersion specifies the teleport cluster version + TeleportClusterVersion string + // AWSOIDCProvderAddr specifies the aws oidc provider address used to generate AWS OIDC tokens + AWSOIDCProviderAddr string + // CriticalEndpoint specifies the endpoint to check for critical updates + CriticalEndpoint string + // VersionEndpoint specifies the endpoint to check for current teleport version + VersionEndpoint string +} + +// CheckAndSetDefaults checks and sets default config values. +func (cfg *DeployServiceUpdaterConfig) CheckAndSetDefaults() error { + if cfg.AuthClient == nil { + return trace.BadParameter("auth client required") + } + + if cfg.TeleportClusterName == "" { + return trace.BadParameter("teleport cluster name required") + } + + if cfg.TeleportClusterVersion == "" { + return trace.BadParameter("teleport cluster version required") + } + + if cfg.AWSOIDCProviderAddr == "" { + return trace.BadParameter("aws oidc provider address required") + } + + if cfg.Log == nil { + cfg.Log = logrus.WithField(trace.Component, teleport.Component(teleport.ComponentProxy, "aws_oidc_deploy_service_updater")) + } + + if cfg.Clock == nil { + cfg.Clock = clockwork.NewRealClock() + } + + return nil +} + +// DeployServiceUpdater periodically updates deploy service agents +type DeployServiceUpdater struct { + DeployServiceUpdaterConfig +} + +// NewDeployServiceUpdater returns a new DeployServiceUpdater +func NewDeployServiceUpdater(config DeployServiceUpdaterConfig) (*DeployServiceUpdater, error) { + if err := config.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &DeployServiceUpdater{ + DeployServiceUpdaterConfig: config, + }, nil +} + +// Run periodically updates the deploy service agents +func (updater *DeployServiceUpdater) Run(ctx context.Context) error { + periodic := interval.New(interval.Config{ + Duration: updateDeployAgentsInterval, + Jitter: retryutils.NewSeventhJitter(), + }) + defer periodic.Stop() + + for { + if err := updater.updateDeployServiceAgents(ctx); err != nil { + updater.Log.WithError(err).Warningf("Update failed. Retrying in ~%v.", updateDeployAgentsInterval) + } + + select { + case <-periodic.Next(): + case <-ctx.Done(): + return nil + } + } +} + +func (updater *DeployServiceUpdater) updateDeployServiceAgents(ctx context.Context) error { + cmc, err := updater.AuthClient.GetClusterMaintenanceConfig(ctx) + if err != nil { + return trace.Wrap(err) + } + + critical, err := automaticupgrades.Critical(ctx, updater.CriticalEndpoint) + if err != nil { + return trace.Wrap(err) + } + + // Upgrade should only be attempted if the current time is within the configured + // upgrade window, or if a critical upgrade is available + if !cmc.WithinUpgradeWindow(updater.Clock.Now()) && !critical { + return nil + } + + stableVersion, err := automaticupgrades.Version(ctx, updater.VersionEndpoint) + if err != nil { + return trace.Wrap(err) + } + // stableVersion has vX.Y.Z format, however the container image tag does not include the `v`. + stableVersion = strings.TrimPrefix(stableVersion, "v") + + // minServerVersion specifies the minimum version of the cluster required for + // updated agents to remain compatible with the cluster. + minServerVersion, err := utils.MajorSemver(stableVersion) + if err != nil { + return trace.Wrap(err) + } + + if !utils.MeetsVersion(updater.TeleportClusterVersion, minServerVersion) { + updater.Log.Debugf("Skipping update. %v agents will not be compatible with a %v cluster.", stableVersion, updater.TeleportClusterVersion) + return nil + } + + databases, err := updater.AuthClient.GetDatabases(ctx) + if err != nil { + return trace.Wrap(err) + } + + // The updater needs to iterate over all integrations and aws regions to check + // for deploy service agents to update. In order to reduce the number of api + // calls, the aws regions are first reduced to only the regions containing + // an RDS database. + awsRegions := make(map[string]interface{}) + for _, database := range databases { + if database.IsAWSHosted() && database.IsRDS() { + awsRegions[database.GetAWS().Region] = nil + } + } + + integrations, err := updater.AuthClient.ListAllIntegrations(ctx) + if err != nil { + return trace.Wrap(err) + } + + // Perform updates in parallel across regions. + sem := semaphore.NewWeighted(maxConcurrentUpdates) + for _, ig := range integrations { + for region := range awsRegions { + if err := sem.Acquire(ctx, 1); err != nil { + return trace.Wrap(err) + } + go func(ig types.Integration, region string) { + defer sem.Release(1) + if err := updater.updateDeployServiceAgent(ctx, ig, region, stableVersion); err != nil { + updater.Log.WithError(err).Warningf("Failed to update deploy service agent for integration %s in region %s.", ig.GetName(), region) + } + }(ig, region) + } + } + + // Wait for all updates to finish. + return trace.Wrap(sem.Acquire(ctx, maxConcurrentUpdates)) +} + +func (updater *DeployServiceUpdater) updateDeployServiceAgent(ctx context.Context, integration types.Integration, awsRegion, teleportVersion string) error { + // Do not attempt update if integration is not an aws oidc integration. + if integration.GetAWSOIDCIntegrationSpec() == nil { + return nil + } + + token, err := updater.AuthClient.GenerateAWSOIDCToken(ctx, types.GenerateAWSOIDCTokenRequest{ + Issuer: updater.AWSOIDCProviderAddr, + }) + if err != nil { + return trace.Wrap(err) + } + + req := &awsoidc.AWSClientRequest{ + IntegrationName: integration.GetName(), + Token: token, + RoleARN: integration.GetAWSOIDCIntegrationSpec().RoleARN, + Region: awsRegion, + } + + // The deploy service client is initialized using AWS OIDC integration. + deployServiceClient, err := awsoidc.NewDeployServiceClient(ctx, req, updater.AuthClient) + if err != nil { + return trace.Wrap(err) + } + + // ownershipTags are used to identify if the ecs resources are managed by the + // teleport integration. + ownershipTags := map[string]string{ + types.ClusterLabel: updater.TeleportClusterName, + types.OriginLabel: types.OriginIntegrationAWSOIDC, + types.IntegrationLabel: integration.GetName(), + } + + // Acquire a lease for the region + integration before attempting to update the deploy service agent. + // If the lease cannot be acquired, the update is already being handled by another instance. + semLock, err := updater.AuthClient.AcquireSemaphore(ctx, types.AcquireSemaphoreRequest{ + SemaphoreKind: types.SemaphoreKindConnection, + SemaphoreName: fmt.Sprintf("update_deploy_service_agents_%s_%s", awsRegion, integration.GetName()), + MaxLeases: 1, + Expires: updater.Clock.Now().Add(updateDeployAgentsInterval), + Holder: "update_deploy_service_agents", + }) + if err != nil { + if strings.Contains(err.Error(), teleport.MaxLeases) { + updater.Log.WithError(err).Debug("Deploy service agent update is already being processed.") + return nil + } + return trace.Wrap(err) + } + defer func() { + if err := updater.AuthClient.CancelSemaphoreLease(ctx, *semLock); err != nil { + updater.Log.WithError(err).Error("Failed to cancel semaphore lease.") + } + }() + + updater.Log.Debugf("Updating Deploy Service Agents for integration %s in AWS region: %s", integration.GetName(), awsRegion) + if err := awsoidc.UpdateDeployServiceAgent(ctx, deployServiceClient, awsoidc.UpdateServiceRequest{ + TeleportClusterName: updater.TeleportClusterName, + TeleportVersionTag: teleportVersion, + OwnershipTags: ownershipTags, + }); err != nil { + + switch { + case trace.IsNotFound(err): + // The updater checks each integration/region combination, so + // there will be regions where there is no ECS cluster deployed + // for the integration. + updater.Log.Debugf("Integration %s does not manage any services within region %s.", integration.GetName(), awsRegion) + return nil + case trace.IsAccessDenied(awslib.ConvertIAMv2Error(trace.Unwrap(err))): + // The aws oidc role may lack permissions due to changes in teleport. + // In this situation users should be notified that they will need to + // re-run the deploy service iam configuration script and update the + // permissions. + updater.Log.WithError(err).Warning("Update integration role and add missing permissions.") + } + return trace.Wrap(err) + } + return nil +} diff --git a/lib/service/service.go b/lib/service/service.go index dcb3d7a1ee7ef..81b09ed1fe99e 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1014,6 +1014,10 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { process.log.Infof("Configured upgrade window exporter for external upgrader. kind=%s", upgraderKind) } + if process.Config.Proxy.Enabled { + process.RegisterFunc("update.aws-oidc.deploy.agents", process.initDeployServiceUpdater) + } + serviceStarted := false if !cfg.DiagnosticAddr.IsEmpty() { diff --git a/lib/utils/ver.go b/lib/utils/ver.go index 42c8d8263c1d6..9b4965ae19e9a 100644 --- a/lib/utils/ver.go +++ b/lib/utils/ver.go @@ -17,10 +17,24 @@ limitations under the License. package utils import ( + "fmt" + "github.com/coreos/go-semver/semver" "github.com/gravitational/trace" ) +// MeetsVersion returns true if gotVer is empty or at least minVer. +func MeetsVersion(gotVer, minVer string) bool { + if gotVer == "" { + return true // Ignore empty versions. + } + + err := CheckVersion(gotVer, minVer) + + // Non BadParameter errors are semver parsing errors. + return !trace.IsBadParameter(err) +} + // CheckVersion compares a version with a minimum version supported. func CheckVersion(currentVersion, minVersion string) error { currentSemver, minSemver, err := versionStringToSemver(currentVersion, minVersion) @@ -56,6 +70,16 @@ func MinVerWithoutPreRelease(currentVersion, minVersion string) (bool, error) { return !currentSemver.LessThan(*minSemver), nil } +// MajorSemver returns the major version as a semver string. +// Ex: 13.4.3 -> 13.0.0 +func MajorSemver(version string) (string, error) { + ver, err := semver.NewVersion(version) + if err != nil { + return "", trace.Wrap(err) + } + return fmt.Sprintf("%d.0.0", ver.Major), nil +} + func versionStringToSemver(ver1, ver2 string) (*semver.Version, *semver.Version, error) { v1Semver, err := semver.NewVersion(ver1) if err != nil { diff --git a/lib/utils/ver_test.go b/lib/utils/ver_test.go index 56fa2a288bdf2..17ee9ae54c989 100644 --- a/lib/utils/ver_test.go +++ b/lib/utils/ver_test.go @@ -16,7 +16,75 @@ limitations under the License. package utils -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMeetsVersion_emptyOrInvalid(t *testing.T) { + // See TestVersions for more comprehensive tests. + + if !MeetsVersion("", "v1.2.3") { + t.Error("MeetsVersion with an empty gotVer should always succeed") + } + + if !MeetsVersion("banana", "v1.2.3") { + t.Error("MeetsVersion with an invalid version should always succeed") + } + if !MeetsVersion("v1.2.3", "banana") { + t.Error("MeetsVersion with an invalid version should always succeed") + } +} + +func TestMajorSemver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + version string + expected string + want bool + wantErr bool + }{ + { + name: "simple semver", + version: "13.4.2", + expected: "13.0.0", + want: true, + }, + { + name: "ignores suffix", + version: "13.4.2-dev", + expected: "13.0.0", + want: true, + }, + { + name: "empty version is rejected", + version: "", + expected: "", + wantErr: true, + }, + { + name: "incorrect version is rejected", + version: "13.4", // missing patch + expected: "13.0.0", + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := MajorSemver(tt.version) + if tt.wantErr { + require.Error(t, err) + } else { + require.Equal(t, tt.expected, got) + } + }) + } +} func TestMinVerWithoutPreRelease(t *testing.T) { t.Parallel() diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index 3bc6ab2c754df..9423b24b662da 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -77,9 +77,9 @@ type scriptSettings struct { databaseInstallMode bool installUpdater bool - // automaticUpgradesVersionBaseURL is the base URL for getting the version when using the cloud/stable channel. + // automaticUpgradesVersionURL is the URL for getting the version when using the cloud/stable channel. // Optional. - automaticUpgradesVersionBaseURL string + automaticUpgradesVersionURL string } // automaticUpgrades returns whether automaticUpgrades should be enabled. @@ -395,7 +395,7 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter // This ensures the initial installed version is the same as the `teleport-ent-updater` would install. if settings.installUpdater { repoChannel = stableCloudChannelRepo - cloudStableVersion, err := automaticupgrades.Version(ctx, settings.automaticUpgradesVersionBaseURL) + cloudStableVersion, err := automaticupgrades.Version(ctx, settings.automaticUpgradesVersionURL) if err != nil { return "", trace.Wrap(err) } diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go index 1c4f706a5bdce..eb2c88e403e98 100644 --- a/lib/web/join_tokens_test.go +++ b/lib/web/join_tokens_test.go @@ -22,6 +22,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "regexp" "testing" @@ -942,7 +943,10 @@ func TestJoinScript(t *testing.T) { })) defer httpTestServer.Close() - script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersionBaseURL: httpTestServer.URL}, m) + versionURL, err := url.JoinPath(httpTestServer.URL, "/v1/stable/cloud/version") + require.NoError(t, err) + + script, err := getJoinScript(context.Background(), scriptSettings{token: validToken, installUpdater: true, automaticUpgradesVersionURL: versionURL}, m) require.NoError(t, err) // list of packages must include the updater