Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ var (
// against a static list of known valid endpoints. We will need to update this
// list as AWS adds new regions.
func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error {
valid := slices.Contains(validSTSEndpoints, stsHost)
valid := slices.Contains(aws.GetValidSTSEndpoints(), stsHost)
if !valid {
return trace.AccessDenied("IAM join request uses unknown STS host %q. "+
"This could mean that the Teleport Node attempting to join the cluster is "+
Expand All @@ -86,10 +86,10 @@ func validateSTSHost(stsHost string, cfg *iamRegisterConfig) error {
"Following is the list of valid STS endpoints known to this auth server. "+
"If a legitimate STS endpoint is not included, please file an issue at "+
"https://github.com/gravitational/teleport. %v",
stsHost, validSTSEndpoints)
stsHost, aws.GetValidSTSEndpoints())
}

if cfg.fips && !slices.Contains(fipsSTSEndpoints, stsHost) {
if cfg.fips && !slices.Contains(aws.GetSTSFipsEndpoints(), stsHost) {
return trace.AccessDenied("node selected non-FIPS STS endpoint (%s) for the IAM join method", stsHost)
}

Expand Down Expand Up @@ -467,7 +467,7 @@ func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS,

stsClient := sts.New(sess)

if slices.Contains(globalSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) {
if slices.Contains(aws.GetSTSGlobalEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) {
// If the caller wants to use the regional endpoint but it was not resolved
// from the environment, attempt to find the region from the EC2 IMDS
if cfg.regionalEndpointOption == endpoints.RegionalSTSEndpoint {
Expand All @@ -484,7 +484,7 @@ func newSTSClient(ctx context.Context, cfg *stsIdentityRequestConfig) (*sts.STS,
}

if cfg.fipsEndpointOption == endpoints.FIPSEndpointStateEnabled &&
!slices.Contains(validSTSEndpoints, strings.TrimPrefix(stsClient.Endpoint, "https://")) {
!slices.Contains(aws.GetValidSTSEndpoints(), strings.TrimPrefix(stsClient.Endpoint, "https://")) {
// The AWS SDK will generate invalid endpoints when attempting to
// resolve the FIPS endpoint for a region which does not have one.
// In this case, try to use the FIPS endpoint in us-east-1. This should
Expand Down
81 changes: 0 additions & 81 deletions lib/auth/sts_endpoints.go

This file was deleted.

144 changes: 144 additions & 0 deletions lib/utils/aws/sts_endpoints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import (
"strings"
"sync"

"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/service/sts"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

var (
globalSTSEndpoints = []string{
"sts.amazonaws.com",
// This is not a real endpoint, but the SDK will select it if
// AWS_USE_FIPS_ENDPOINT is set and a region is not.
"sts-fips.aws-global.amazonaws.com",
}

fipsSTSEndpoints []string
validSTSEndpoints []string
initSTSEndpointsOnce sync.Once
)

// GetSTSGlobalEndpoints returns a list of global STS endpoints.
func GetSTSGlobalEndpoints() []string {
return globalSTSEndpoints
}

// GetSTSFipsEndpoints returns a list of STS fips endpoints.
func GetSTSFipsEndpoints() []string {
initSTSEndpoints()
return fipsSTSEndpoints
}

// GetValidSTSEndpoints returns a list of all valid STS endpoints.
func GetValidSTSEndpoints() []string {
initSTSEndpoints()
return validSTSEndpoints
}

func initSTSEndpoints() {
initSTSEndpointsOnce.Do(func() {
nullOption := func(*endpoints.Options) {}

fipsSTSEndpoints = genSTSEndpoints(genSTSEndointConfig{
requiredOptions: []func(*endpoints.Options){
endpoints.StrictMatchingOption,
endpoints.UseFIPSEndpointOption,
},
multiplyOptions: []func(*endpoints.Options){
nullOption,
endpoints.STSRegionalEndpointOption,
endpoints.UseDualStackEndpointOption,
},
})

validSTSEndpoints = genSTSEndpoints(genSTSEndointConfig{
requiredOptions: []func(*endpoints.Options){
endpoints.StrictMatchingOption,
},
multiplyOptions: []func(*endpoints.Options){
nullOption,
endpoints.STSRegionalEndpointOption,
endpoints.UseFIPSEndpointOption,
endpoints.UseDualStackEndpointOption,
},
})
})
}

// combinations returns all combinations for a given array. This is essentially
// a powerset of the given set except that the empty set is disregarded.
//
// Reference: https://github.com/mxschmitt/golang-combinations
func combinations[T any](set []T) (subsets [][]T) {
length := uint(len(set))

// Go through all possible combinations of objects
// from 1 (only first object in subset) to 2^length (all objects in subset)
for subsetBits := 1; subsetBits < (1 << length); subsetBits++ {
var subset []T

for object := uint(0); object < length; object++ {
// checks if object is contained in subset
// by checking if bit 'object' is set in subsetBits
if (subsetBits>>object)&1 == 1 {
// add object to subset
subset = append(subset, set[object])
}
}
// add subset to subsets
subsets = append(subsets, subset)
}
return subsets
}

type genSTSEndointConfig struct {
// requiredOptions are endpoints options that must be set for each
// EndpointFor call.
requiredOptions []func(*endpoints.Options)
// multiplyOptions is a list of endpoints options where all their
// combinations must be iterated to create the endpoints.
multiplyOptions []func(*endpoints.Options)
}

func genSTSEndpoints(cfg genSTSEndointConfig) []string {
optCombinations := combinations(cfg.multiplyOptions)
endpointsSet := make(map[string]struct{})
for _, partition := range endpoints.DefaultPartitions() {
for region := range partition.Regions() {
for _, opts := range optCombinations {
endpoint, err := partition.EndpointFor(sts.ServiceName, region, append(cfg.requiredOptions, opts...)...)
if err != nil {
// Skip if no endpoint found for this opts combo.
continue
}

endpointsSet[strings.TrimPrefix(endpoint.URL, "https://")] = struct{}{}
}
}
}

endpointsSlice := maps.Keys(endpointsSet)
slices.Sort(endpointsSlice)
return endpointsSlice
}
85 changes: 85 additions & 0 deletions lib/utils/aws/sts_endpoints_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package aws

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestGetValidSTSEndpoints(t *testing.T) {
// The number of endpoints may grow over time when SDK is updated. So here
// just tests if GetValidSTSEndpoints contains entries from this selective
// list.
wantEndoints := append(
[]string{
"sts.af-south-1.amazonaws.com",
"sts.amazonaws.com",
"sts.ap-east-1.amazonaws.com",
"sts.cn-north-1.amazonaws.com.cn",
"sts.cn-northwest-1.amazonaws.com.cn",
"sts.il-central-1.amazonaws.com",
"sts.us-gov-west-1.amazonaws.com",
"sts.us-iso-east-1.c2s.ic.gov",
"sts.us-west-1.amazonaws.com",
},
GetSTSFipsEndpoints()...,
)
for _, wantEndpoint := range wantEndoints {
require.Contains(t, GetValidSTSEndpoints(), wantEndpoint)
}
}

func TestGetSTSFipsEndpoints(t *testing.T) {
wantEndoints := []string{
"sts-fips.us-east-1.amazonaws.com",
"sts-fips.us-east-2.amazonaws.com",
"sts-fips.us-west-1.amazonaws.com",
"sts-fips.us-west-2.amazonaws.com",
"sts.us-gov-east-1.amazonaws.com",
"sts.us-gov-west-1.amazonaws.com",
}
for _, wantEndpoint := range wantEndoints {
require.Contains(t, GetSTSFipsEndpoints(), wantEndpoint)
}
}

func Test_combinations(t *testing.T) {
require.Nil(
t,
combinations[string](nil),
)

require.Equal(
t,
[][]string{
{"a"},
},
combinations([]string{"a"}),
)

require.Equal(
t,
[][]string{
{"a"}, {"b"}, {"a", "b"},
{"c"}, {"a", "c"}, {"b", "c"},
{"a", "b", "c"},
},
combinations([]string{"a", "b", "c"}),
)
}