diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index c9e9cd7a97ef8..aec4a11cebcbe 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -76,7 +76,7 @@ var ( // against a static list of known valid endpoints. We will need to update this // list as AWS adds new regions. func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error { - valid := slices.Contains(validSTSEndpoints, stsHost) + valid := slices.Contains(aws.GetValidSTSEndpoints(), stsHost) if !valid { return trace.AccessDenied("IAM join request uses unknown STS host %q. "+ "This could mean that the Teleport Node attempting to join the cluster is "+ @@ -86,10 +86,10 @@ func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error { "Following is the list of valid STS endpoints known to this auth server. "+ "If a legitimate STS endpoint is not included, please file an issue at "+ "https://github.com/gravitational/teleport. %v", - stsHost, validSTSEndpoints) + stsHost, aws.GetValidSTSEndpoints()) } - if cfg.fips && !slices.Contains(fipsSTSEndpoints, stsHost) { + if cfg.fips && !slices.Contains(aws.GetSTSFipsEndpoints(), stsHost) { return trace.AccessDenied("node selected non-FIPS STS endpoint (%s) for the IAM join method", stsHost) } @@ -467,7 +467,7 @@ func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, stsClient := sts.New(sess) - if slices.Contains(globalSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) { + if slices.Contains(aws.GetSTSGlobalEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) { // If the caller wants to use the regional endpoint but it was not resolved // from the environment, attempt to find the region from the EC2 IMDS if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint { @@ -484,7 +484,7 @@ func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS, } if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled && - !slices.Contains(validSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) { + !slices.Contains(aws.GetValidSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) { // The AWS SDK will generate invalid endpoints when attempting to // resolve the FIPS endpoint for a region which does not have one. // In this case, try to use the FIPS endpoint in us-east-1. This should diff --git a/lib/auth/sts_endpoints.go b/lib/auth/sts_endpoints.go deleted file mode 100644 index 5a6b86c8c3d79..0000000000000 --- a/lib/auth/sts_endpoints.go +++ /dev/null @@ -1,81 +0,0 @@ -/* -Copyright 2022 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package auth - -var ( - // validSTSEndpoints holds a sorted list of all known valid public endpoints for - // the AWS STS service. You can generate this list by running - // $ go run github.com/nklaassen/sts-endpoints@latest --go-list - // Update aws-sdk-go in that package to learn about new endpoints. - validSTSEndpoints = []string{ - "sts-fips.us-east-1.amazonaws.com", - "sts-fips.us-east-2.amazonaws.com", - "sts-fips.us-west-1.amazonaws.com", - "sts-fips.us-west-2.amazonaws.com", - "sts.af-south-1.amazonaws.com", - "sts.amazonaws.com", - "sts.ap-east-1.amazonaws.com", - "sts.ap-northeast-1.amazonaws.com", - "sts.ap-northeast-2.amazonaws.com", - "sts.ap-northeast-3.amazonaws.com", - "sts.ap-south-1.amazonaws.com", - "sts.ap-south-2.amazonaws.com", - "sts.ap-southeast-1.amazonaws.com", - "sts.ap-southeast-2.amazonaws.com", - "sts.ap-southeast-3.amazonaws.com", - "sts.ap-southeast-4.amazonaws.com", - "sts.ca-central-1.amazonaws.com", - "sts.cn-north-1.amazonaws.com.cn", - "sts.cn-northwest-1.amazonaws.com.cn", - "sts.eu-central-1.amazonaws.com", - "sts.eu-central-2.amazonaws.com", - "sts.eu-north-1.amazonaws.com", - "sts.eu-south-1.amazonaws.com", - "sts.eu-south-2.amazonaws.com", - "sts.eu-west-1.amazonaws.com", - "sts.eu-west-2.amazonaws.com", - "sts.eu-west-3.amazonaws.com", - "sts.me-central-1.amazonaws.com", - "sts.me-south-1.amazonaws.com", - "sts.sa-east-1.amazonaws.com", - "sts.us-east-1.amazonaws.com", - "sts.us-east-2.amazonaws.com", - "sts.us-gov-east-1.amazonaws.com", - "sts.us-gov-west-1.amazonaws.com", - "sts.us-iso-east-1.c2s.ic.gov", - "sts.us-iso-west-1.c2s.ic.gov", - "sts.us-isob-east-1.sc2s.sgov.gov", - "sts.us-west-1.amazonaws.com", - "sts.us-west-2.amazonaws.com", - } - - globalSTSEndpoints = []string{ - "sts.amazonaws.com", - // This is not a real endpoint, but the SDK will select it if - // AWS_USE_FIPS_ENDPOINT is set and a region is not. - "sts-fips.aws-global.amazonaws.com", - } - - fipsSTSEndpoints = []string{ - "sts-fips.us-east-1.amazonaws.com", - "sts-fips.us-east-2.amazonaws.com", - "sts-fips.us-west-1.amazonaws.com", - "sts-fips.us-west-2.amazonaws.com", - "sts.us-gov-east-1.amazonaws.com", - "sts.us-gov-west-1.amazonaws.com", - } -) diff --git a/lib/utils/aws/sts_endpoints.go b/lib/utils/aws/sts_endpoints.go new file mode 100644 index 0000000000000..9abfbff58e9b8 --- /dev/null +++ b/lib/utils/aws/sts_endpoints.go @@ -0,0 +1,144 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "strings" + "sync" + + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/service/sts" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +var ( + globalSTSEndpoints = []string{ + "sts.amazonaws.com", + // This is not a real endpoint, but the SDK will select it if + // AWS_USE_FIPS_ENDPOINT is set and a region is not. + "sts-fips.aws-global.amazonaws.com", + } + + fipsSTSEndpoints []string + validSTSEndpoints []string + initSTSEndpointsOnce sync.Once +) + +// GetSTSGlobalEndpoints returns a list of global STS endpoints. +func GetSTSGlobalEndpoints() []string { + return globalSTSEndpoints +} + +// GetSTSFipsEndpoints returns a list of STS fips endpoints. +func GetSTSFipsEndpoints() []string { + initSTSEndpoints() + return fipsSTSEndpoints +} + +// GetValidSTSEndpoints returns a list of all valid STS endpoints. +func GetValidSTSEndpoints() []string { + initSTSEndpoints() + return validSTSEndpoints +} + +func initSTSEndpoints() { + initSTSEndpointsOnce.Do(func() { + nullOption := func(*endpoints.Options) {} + + fipsSTSEndpoints = genSTSEndpoints(genSTSEndointConfig{ + requiredOptions: []func(*endpoints.Options){ + endpoints.StrictMatchingOption, + endpoints.UseFIPSEndpointOption, + }, + multiplyOptions: []func(*endpoints.Options){ + nullOption, + endpoints.STSRegionalEndpointOption, + endpoints.UseDualStackEndpointOption, + }, + }) + + validSTSEndpoints = genSTSEndpoints(genSTSEndointConfig{ + requiredOptions: []func(*endpoints.Options){ + endpoints.StrictMatchingOption, + }, + multiplyOptions: []func(*endpoints.Options){ + nullOption, + endpoints.STSRegionalEndpointOption, + endpoints.UseFIPSEndpointOption, + endpoints.UseDualStackEndpointOption, + }, + }) + }) +} + +// combinations returns all combinations for a given array. This is essentially +// a powerset of the given set except that the empty set is disregarded. +// +// Reference: https://github.com/mxschmitt/golang-combinations +func combinations[T any](set []T) (subsets [][]T) { + length := uint(len(set)) + + // Go through all possible combinations of objects + // from 1 (only first object in subset) to 2^length (all objects in subset) + for subsetBits := 1; subsetBits < (1 << length); subsetBits++ { + var subset []T + + for object := uint(0); object < length; object++ { + // checks if object is contained in subset + // by checking if bit 'object' is set in subsetBits + if (subsetBits>>object)&1 == 1 { + // add object to subset + subset = append(subset, set[object]) + } + } + // add subset to subsets + subsets = append(subsets, subset) + } + return subsets +} + +type genSTSEndointConfig struct { + // requiredOptions are endpoints options that must be set for each + // EndpointFor call. + requiredOptions []func(*endpoints.Options) + // multiplyOptions is a list of endpoints options where all their + // combinations must be iterated to create the endpoints. + multiplyOptions []func(*endpoints.Options) +} + +func genSTSEndpoints(cfg genSTSEndointConfig) []string { + optCombinations := combinations(cfg.multiplyOptions) + endpointsSet := make(map[string]struct{}) + for _, partition := range endpoints.DefaultPartitions() { + for region := range partition.Regions() { + for _, opts := range optCombinations { + endpoint, err := partition.EndpointFor(sts.ServiceName, region, append(cfg.requiredOptions, opts...)...) + if err != nil { + // Skip if no endpoint found for this opts combo. + continue + } + + endpointsSet[strings.TrimPrefix(endpoint.URL, "https://")] = struct{}{} + } + } + } + + endpointsSlice := maps.Keys(endpointsSet) + slices.Sort(endpointsSlice) + return endpointsSlice +} diff --git a/lib/utils/aws/sts_endpoints_test.go b/lib/utils/aws/sts_endpoints_test.go new file mode 100644 index 0000000000000..821c986843cb4 --- /dev/null +++ b/lib/utils/aws/sts_endpoints_test.go @@ -0,0 +1,85 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetValidSTSEndpoints(t *testing.T) { + // The number of endpoints may grow over time when SDK is updated. So here + // just tests if GetValidSTSEndpoints contains entries from this selective + // list. + wantEndoints := append( + []string{ + "sts.af-south-1.amazonaws.com", + "sts.amazonaws.com", + "sts.ap-east-1.amazonaws.com", + "sts.cn-north-1.amazonaws.com.cn", + "sts.cn-northwest-1.amazonaws.com.cn", + "sts.il-central-1.amazonaws.com", + "sts.us-gov-west-1.amazonaws.com", + "sts.us-iso-east-1.c2s.ic.gov", + "sts.us-west-1.amazonaws.com", + }, + GetSTSFipsEndpoints()..., + ) + for _, wantEndpoint := range wantEndoints { + require.Contains(t, GetValidSTSEndpoints(), wantEndpoint) + } +} + +func TestGetSTSFipsEndpoints(t *testing.T) { + wantEndoints := []string{ + "sts-fips.us-east-1.amazonaws.com", + "sts-fips.us-east-2.amazonaws.com", + "sts-fips.us-west-1.amazonaws.com", + "sts-fips.us-west-2.amazonaws.com", + "sts.us-gov-east-1.amazonaws.com", + "sts.us-gov-west-1.amazonaws.com", + } + for _, wantEndpoint := range wantEndoints { + require.Contains(t, GetSTSFipsEndpoints(), wantEndpoint) + } +} + +func Test_combinations(t *testing.T) { + require.Nil( + t, + combinations[string](nil), + ) + + require.Equal( + t, + [][]string{ + {"a"}, + }, + combinations([]string{"a"}), + ) + + require.Equal( + t, + [][]string{ + {"a"}, {"b"}, {"a", "b"}, + {"c"}, {"a", "c"}, {"b", "c"}, + {"a", "b", "c"}, + }, + combinations([]string{"a", "b", "c"}), + ) +}