diff --git a/docs/pages/management/guides/joining-nodes-aws-iam.mdx b/docs/pages/management/guides/joining-nodes-aws-iam.mdx index 8a97280be5608..7b891945aefe9 100644 --- a/docs/pages/management/guides/joining-nodes-aws-iam.mdx +++ b/docs/pages/management/guides/joining-nodes-aws-iam.mdx @@ -58,15 +58,6 @@ The IAM join method will not work if TLS is terminated at a load balancer in front of your Teleport Proxy Service unless the Node using this method is connecting directly to the Auth Service. -The IAM join method is currently not supported in the AWS China or GovCloud -partitions. - - - - -The IAM join method is currently not supported in the AWS China or GovCloud -partitions. - ## Prerequisites @@ -169,4 +160,4 @@ proxy_service: ## Step 4/4. Launch your Teleport Node Start Teleport on the Node and confirm that it is able to connect to and join -your cluster. You're all set! \ No newline at end of file +your cluster. You're all set! diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index b1923f9ac412c..7b64cc84fa424 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -33,9 +33,12 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils" + apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/aws" + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/gravitational/trace" @@ -47,19 +50,45 @@ const ( // update our AWS SDK dependency. Since Auth should always be upgraded // before nodes, we will have a chance to update the check on Auth if we // ever have a need to allow a newer API version. - expectedStsIdentityRequestBody = "Action=GetCallerIdentity&Version=2011-06-15" + expectedSTSIdentityRequestBody = "Action=GetCallerIdentity&Version=2011-06-15" - // Only allowing the global sts endpoint here, Teleport nodes will only send - // requests for this endpoint. If we want to start using regional endpoints - // we can update this check before updating the nodes. - stsHost = "sts.amazonaws.com" + // Used to check if we were unable to resolve the regional STS endpoint. + globalSTSEndpoint = "https://sts.amazonaws.com" // AWS SignedHeaders will always be lowercase // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html#sigv4-auth-header-overview challengeHeaderKey = "x-teleport-challenge" ) -// validateStsIdentityRequest checks that a received sts:GetCallerIdentity +// validateSTSHost returns an error if the given stsHost is not a valid regional +// endpoint for the AWS STS service, or nil if it is valid. +// +// This is a security-critical check: we are allowing the client to tell us +// which URL we should use to validate their identity. If the client could pass +// off an attacker-controlled URL as the STS endpoint, the entire security +// mechanism of the IAM join method would be compromised. +// +// To keep this validation simple and secure, we check the given endpoint +// against a static list of known valid endpoints. We will need to update this +// list as AWS adds new regions. +func validateSTSHost(stsHost string) error { + valid := apiutils.SliceContainsStr(validSTSEndpoints, stsHost) + if valid { + return nil + } + + return trace.AccessDenied("IAM join request uses unknown STS host %q. "+ + "This could mean that the Teleport Node attempting to join the cluster is "+ + "running in a new AWS region which is unknown to this Teleport auth server. "+ + "Alternatively, if this URL looks suspicious, an attacker may be attempting to "+ + "join your Teleport cluster. "+ + "Following is the list of valid STS endpoints known to this auth server. "+ + "If a legitimate STS endpoint is not included, please file an issue at "+ + "https://github.com/gravitational/teleport. %v", + stsHost, validSTSEndpoints) +} + +// validateSTSIdentityRequest checks that a received sts:GetCallerIdentity // request is valid and includes the challenge as a signed header. An example // valid request looks like: // ``` @@ -76,9 +105,18 @@ const ( // // Action=GetCallerIdentity&Version=2011-06-15 // ``` -func validateStsIdentityRequest(req *http.Request, challenge string) error { - if req.Host != stsHost { - return trace.AccessDenied("sts identity request is for unknown host %q", req.Host) +func validateSTSIdentityRequest(req *http.Request, challenge string) (err error) { + defer func() { + // Always log a warning on the Auth server if the function detects an + // invalid sts:GetCallerIdentity request, it's either going to be caused + // by a node in a unknown region or an attacker. + if err != nil { + log.WithError(err).Warn("Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method.") + } + }() + + if err := validateSTSHost(req.Host); err != nil { + return trace.Wrap(err) } if req.Method != http.MethodPost { @@ -95,7 +133,7 @@ func validateStsIdentityRequest(req *http.Request, challenge string) error { if err != nil { return trace.Wrap(err) } - if !utils.SliceContainsStr(sigV4.SignedHeaders, challengeHeaderKey) { + if !apiutils.SliceContainsStr(sigV4.SignedHeaders, challengeHeaderKey) { return trace.AccessDenied("sts identity request auth header %q does not include "+ challengeHeaderKey+" as a signed header", authHeader) } @@ -104,8 +142,8 @@ func validateStsIdentityRequest(req *http.Request, challenge string) error { if err != nil { return trace.Wrap(err) } - if !bytes.Equal([]byte(expectedStsIdentityRequestBody), body) { - return trace.BadParameter("sts request body %q does not equal expected %q", string(body), expectedStsIdentityRequestBody) + if !bytes.Equal([]byte(expectedSTSIdentityRequestBody), body) { + return trace.BadParameter("sts request body %q does not equal expected %q", string(body), expectedSTSIdentityRequestBody) } return nil @@ -125,7 +163,7 @@ func parseSTSRequest(req []byte) (*http.Request, error) { httpReq.RequestURI = "" httpReq.URL = &url.URL{ Scheme: "https", - Host: stsHost, + Host: httpReq.Host, } return httpReq, nil } @@ -161,9 +199,9 @@ func stsClientFromContext(ctx context.Context) stsClient { return http.DefaultClient } -// executeStsIdentityRequest sends the sts:GetCallerIdentity HTTP request to the +// executeSTSIdentityRequest sends the sts:GetCallerIdentity HTTP request to the // AWS API, parses the response, and returns the awsIdentity -func executeStsIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) { +func executeSTSIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) { client := stsClientFromContext(ctx) // set the http request context so it can be cancelled @@ -260,13 +298,13 @@ func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *pro // validate that the host, method, and headers are correct and the expected // challenge is included in the signed portion of the request - if err := validateStsIdentityRequest(identityRequest, challenge); err != nil { + if err := validateSTSIdentityRequest(identityRequest, challenge); err != nil { return trace.Wrap(err) } // send the signed request to the public AWS API and get the node identity // from the response - identity, err := executeStsIdentityRequest(ctx, identityRequest) + identity, err := executeSTSIdentityRequest(ctx, identityRequest) if err != nil { return trace.Wrap(err) } @@ -333,18 +371,15 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c return certs, trace.Wrap(err) } -// createSignedStsIdentityRequest is called on the client side and returns an +// createSignedSTSIdentityRequest is called on the client side and returns an // sts:GetCallerIdentity request signed with the local AWS credentials -func createSignedStsIdentityRequest(challenge string) ([]byte, error) { - // use the aws sdk to generate the request - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - }) +func createSignedSTSIdentityRequest(ctx context.Context, endpointOption stsEndpointOption, challenge string) ([]byte, error) { + stsClient, err := endpointOption(ctx) if err != nil { return nil, trace.Wrap(err) } - stsService := sts.New(sess) - req, _ := stsService.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) + + req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) // set challenge header req.HTTPRequest.Header.Set(challengeHeaderKey, challenge) // request json for simpler parsing @@ -360,3 +395,90 @@ func createSignedStsIdentityRequest(challenge string) ([]byte, error) { } return signedRequest.Bytes(), nil } + +type stsEndpointOption func(context.Context) (*sts.STS, error) + +var ( + stsEndpointOptionGlobal = newGlobalSTSClient + stsEndpointOptionRegional = newRegionalSTSClient +) + +// newRegionalSTSClient returns an STS client will resolve the "global" endpoint +// for the STS service. +func newGlobalSTSClient(ctx context.Context) (*sts.STS, error) { + // sess will be used as a ConfigProvider to be passed to sts.New. It will + // load AWS configuration options from the environment, which means that AWS + // credentials may come from environment variables, files in ~/.aws/, or + // from the attached role on an EC2 instance. + sess, err := session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return sts.New(sess), nil +} + +// newRegionalSTSClient returns an STS client which attempts to resolve the local +// regional endpoint for the STS service, rather than the "global" endpoint +// which is not supported in non-default AWS partitions. +func newRegionalSTSClient(ctx context.Context) (*sts.STS, error) { + // sess will be used as a ConfigProvider to be passed to sts.New. It will + // load AWS configuration options from the environment, which means that AWS + // credentials may come from environment variables, files in ~/.aws/, or + // from the attached role on an EC2 instance. The regional STS endpoint will + // be used instead of the global endopint if the local (or preferred) region + // can be resolved from the environment. + sess, err := session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + Config: *awssdk.NewConfig().WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // will set the local region on extraConfigOptions if we can find it from + // the environment or IMDS + extraConfigOptions := awssdk.NewConfig() + + // If the region was not resolved from the environment the client will try to + // use the global STS endpoint, which will not be supported if the AWS identity + // being used is for a non-default AWS partition (such as China or + // GovCloud.) This is the default behavior on EC2, so let's try to find the + // region from the IMDS. + if clientConfig := sess.ClientConfig(sts.ServiceName); clientConfig.Endpoint == globalSTSEndpoint { + region, err := getEC2LocalRegion(ctx) + if trace.IsNotFound(err) { + // Unfortunately we could not find the region from the IMDS, go with + // the default global endpoint and hope it works. + log.Info("Unable to find the local AWS region from the environment or IMDSv2. " + + "Attempting to use the global STS endpoint for the IAM join method. " + + "This will probably fail in non-default AWS partitions such as China or GovCloud. " + + "Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2.") + } else if err != nil { + // Return the unexpected error. + return nil, trace.Wrap(err) + } else { + // Found the region, set it on the config. + extraConfigOptions.Region = ®ion + } + } + + return sts.New(sess, extraConfigOptions), nil +} + +// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or +// a NotFound error if the EC2 IMDS is unavailable. +func getEC2LocalRegion(ctx context.Context) (string, error) { + imdsClient, err := utils.NewInstanceMetadataClient(ctx) + if err != nil { + return "", trace.Wrap(err) + } + + if !imdsClient.IsAvailable(ctx) { + return "", trace.NotFound("IMDS is unavailable") + } + + region, err := imdsClient.GetRegion(ctx) + return region, trace.Wrap(err) +} diff --git a/lib/auth/join_iam_test.go b/lib/auth/join_iam_test.go index c054fdd85b552..3ad6ef3ec7a3a 100644 --- a/lib/auth/join_iam_test.go +++ b/lib/auth/join_iam_test.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/trace" + "github.com/stretchr/testify/require" ) diff --git a/lib/auth/register.go b/lib/auth/register.go index a052eb1fc9309..44315dfa1554d 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -464,32 +464,58 @@ type joinServiceClient interface { func registerUsingIAMMethod(joinServiceClient joinServiceClient, token string, params RegisterParams) (*proto.Certs, error) { ctx := context.Background() - // call RegisterUsingIAMMethod with a callback to respond to the challenge - // with the join request - certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { - // create the signed sts:GetCallerIdentity request and include the challenge - signedRequest, err := createSignedStsIdentityRequest(challenge) + // Attempt to use the regional STS endpoint, fall back to using the global + // endpoint. The regional endpoint may fail if Auth is on an older version + // which does not support regional endpoints, the STS service is not + // enabled in the current region, or an unknown AWS region is configured. + var errs []error + for _, s := range []struct { + desc string + opt stsEndpointOption + }{ + { + desc: "regional", + opt: stsEndpointOptionRegional, + }, + { + desc: "global", + opt: stsEndpointOptionGlobal, + }, + } { + log.Infof("Attempting to register %s with IAM method using %s STS endpoint", params.ID.Role, s.desc) + // Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request. + certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { + // create the signed sts:GetCallerIdentity request and include the challenge + signedRequest, err := createSignedSTSIdentityRequest(ctx, s.opt, challenge) + if err != nil { + return nil, trace.Wrap(err) + } + + // send the register request including the challenge response + return &proto.RegisterUsingIAMMethodRequest{ + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ + Token: token, + HostID: params.ID.HostUUID, + NodeName: params.ID.NodeName, + Role: params.ID.Role, + AdditionalPrincipals: params.AdditionalPrincipals, + DNSNames: params.DNSNames, + PublicTLSKey: params.PublicTLSKey, + PublicSSHKey: params.PublicSSHKey, + }, + StsIdentityRequest: signedRequest, + }, nil + }) if err != nil { - return nil, trace.Wrap(err) + log.WithError(err).Infof("Failed to register %s using %s STS endpoint", params.ID.Role, s.desc) + errs = append(errs, err) + } else { + log.Infof("Successfully registered %s with IAM method using %s STS endpoint", params.ID.Role, s.desc) + return certs, nil } + } - // send the register request including the challenge response - return &proto.RegisterUsingIAMMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: token, - HostID: params.ID.HostUUID, - NodeName: params.ID.NodeName, - Role: params.ID.Role, - AdditionalPrincipals: params.AdditionalPrincipals, - DNSNames: params.DNSNames, - PublicTLSKey: params.PublicTLSKey, - PublicSSHKey: params.PublicSSHKey, - }, - StsIdentityRequest: signedRequest, - }, nil - }) - - return certs, trace.Wrap(err) + return nil, trace.NewAggregate(errs...) } // ReRegisterParams specifies parameters for re-registering diff --git a/lib/auth/valid_sts_endpoints.go b/lib/auth/valid_sts_endpoints.go new file mode 100644 index 0000000000000..48f89c2ca9e33 --- /dev/null +++ b/lib/auth/valid_sts_endpoints.go @@ -0,0 +1,58 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +// validSTSEndpoints holds a sorted list of all known valid public endpoints for +// the AWS STS service. You can generate this list by running +// $ go run github.com/nklaassen/sts-endpoints@latest --go-list +// Update aws-sdk-go in that package to learn about new endpoints. +var validSTSEndpoints = []string{ + "sts-fips.us-east-1.amazonaws.com", + "sts-fips.us-east-2.amazonaws.com", + "sts-fips.us-west-1.amazonaws.com", + "sts-fips.us-west-2.amazonaws.com", + "sts.af-south-1.amazonaws.com", + "sts.amazonaws.com", + "sts.ap-east-1.amazonaws.com", + "sts.ap-northeast-1.amazonaws.com", + "sts.ap-northeast-2.amazonaws.com", + "sts.ap-northeast-3.amazonaws.com", + "sts.ap-south-1.amazonaws.com", + "sts.ap-southeast-1.amazonaws.com", + "sts.ap-southeast-2.amazonaws.com", + "sts.ap-southeast-3.amazonaws.com", + "sts.ca-central-1.amazonaws.com", + "sts.cn-north-1.amazonaws.com.cn", + "sts.cn-northwest-1.amazonaws.com.cn", + "sts.eu-central-1.amazonaws.com", + "sts.eu-north-1.amazonaws.com", + "sts.eu-south-1.amazonaws.com", + "sts.eu-west-1.amazonaws.com", + "sts.eu-west-2.amazonaws.com", + "sts.eu-west-3.amazonaws.com", + "sts.me-south-1.amazonaws.com", + "sts.sa-east-1.amazonaws.com", + "sts.us-east-1.amazonaws.com", + "sts.us-east-2.amazonaws.com", + "sts.us-gov-east-1.amazonaws.com", + "sts.us-gov-west-1.amazonaws.com", + "sts.us-iso-east-1.c2s.ic.gov", + "sts.us-iso-west-1.c2s.ic.gov", + "sts.us-isob-east-1.sc2s.sgov.gov", + "sts.us-west-1.amazonaws.com", + "sts.us-west-2.amazonaws.com", +} diff --git a/lib/utils/ec2.go b/lib/utils/ec2.go index c94ebc6d12eec..bdf84a0fae70b 100644 --- a/lib/utils/ec2.go +++ b/lib/utils/ec2.go @@ -178,3 +178,11 @@ func (client *InstanceMetadataClient) GetTagValue(ctx context.Context, key strin } return body, nil } + +func (client *InstanceMetadataClient) GetRegion(ctx context.Context) (string, error) { + getRegionOutput, err := client.c.GetRegion(ctx, nil) + if err != nil { + return "", trace.Wrap(err) + } + return getRegionOutput.Region, nil +}