Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/pages/setup/guides/joining-nodes-aws.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ policies is sufficient.
No IAM credentials at all are required on the Teleport Auth server.

<Admonition type="warning">
The IAM join method is not compatible with the `--insecure-no-tls` flag, and is
currently not supported in the AWS China or GovCloud partitions.
The IAM join method will not work if TLS is terminated at a load balancer in
front of your Teleport Proxy Service unless the Node using this method is
connecting directly to the Auth Service.
</Admonition>

## Prerequisites
Expand Down
174 changes: 148 additions & 26 deletions lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ import (
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/aws"

awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gravitational/trace"
Expand All @@ -47,19 +50,45 @@ const (
// update our AWS SDK dependency. Since Auth should always be upgraded
// before nodes, we will have a chance to update the check on Auth if we
// ever have a need to allow a newer API version.
expectedStsIdentityRequestBody = "Action=GetCallerIdentity&Version=2011-06-15"
expectedSTSIdentityRequestBody = "Action=GetCallerIdentity&Version=2011-06-15"

// Only allowing the global sts endpoint here, Teleport nodes will only send
// requests for this endpoint. If we want to start using regional endpoints
// we can update this check before updating the nodes.
stsHost = "sts.amazonaws.com"
// Used to check if we were unable to resolve the regional STS endpoint.
globalSTSEndpoint = "https://sts.amazonaws.com"

// AWS SignedHeaders will always be lowercase
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html#sigv4-auth-header-overview
challengeHeaderKey = "x-teleport-challenge"
)

// validateStsIdentityRequest checks that a received sts:GetCallerIdentity
// validateSTSHost returns an error if the given stsHost is not a valid regional
// endpoint for the AWS STS service, or nil if it is valid.
//
// This is a security-critical check: we are allowing the client to tell us
// which URL we should use to validate their identity. If the client could pass
// off an attacker-controlled URL as the STS endpoint, the entire security
// mechanism of the IAM join method would be compromised.
//
// To keep this validation simple and secure, we check the given endpoint
// against a static list of known valid endpoints. We will need to update this
// list as AWS adds new regions.
func validateSTSHost(stsHost string) error {
valid := apiutils.SliceContainsStr(validSTSEndpoints, stsHost)
if valid {
return nil
}

return trace.AccessDenied("IAM join request uses unknown STS host %q. "+
"This could mean that the Teleport Node attempting to join the cluster is "+
"running in a new AWS region which is unknown to this Teleport auth server. "+
"Alternatively, if this URL looks suspicious, an attacker may be attempting to "+
"join your Teleport cluster. "+
"Following is the list of valid STS endpoints known to this auth server. "+
"If a legitimate STS endpoint is not included, please file an issue at "+
"https://github.com/gravitational/teleport. %v",
stsHost, validSTSEndpoints)
}

// validateSTSIdentityRequest checks that a received sts:GetCallerIdentity
// request is valid and includes the challenge as a signed header. An example
// valid request looks like:
// ```
Expand All @@ -76,9 +105,18 @@ const (
//
// Action=GetCallerIdentity&Version=2011-06-15
// ```
func validateStsIdentityRequest(req *http.Request, challenge string) error {
if req.Host != stsHost {
return trace.AccessDenied("sts identity request is for unknown host %q", req.Host)
func validateSTSIdentityRequest(req *http.Request, challenge string) (err error) {
defer func() {
// Always log a warning on the Auth server if the function detects an
// invalid sts:GetCallerIdentity request, it's either going to be caused
// by a node in a unknown region or an attacker.
if err != nil {
log.WithError(err).Warn("Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method.")
}
}()

if err := validateSTSHost(req.Host); err != nil {
return trace.Wrap(err)
}

if req.Method != http.MethodPost {
Expand All @@ -95,7 +133,7 @@ func validateStsIdentityRequest(req *http.Request, challenge string) error {
if err != nil {
return trace.Wrap(err)
}
if !utils.SliceContainsStr(sigV4.SignedHeaders, challengeHeaderKey) {
if !apiutils.SliceContainsStr(sigV4.SignedHeaders, challengeHeaderKey) {
return trace.AccessDenied("sts identity request auth header %q does not include "+
challengeHeaderKey+" as a signed header", authHeader)
}
Expand All @@ -104,8 +142,8 @@ func validateStsIdentityRequest(req *http.Request, challenge string) error {
if err != nil {
return trace.Wrap(err)
}
if !bytes.Equal([]byte(expectedStsIdentityRequestBody), body) {
return trace.BadParameter("sts request body %q does not equal expected %q", string(body), expectedStsIdentityRequestBody)
if !bytes.Equal([]byte(expectedSTSIdentityRequestBody), body) {
return trace.BadParameter("sts request body %q does not equal expected %q", string(body), expectedSTSIdentityRequestBody)
}

return nil
Expand All @@ -125,7 +163,7 @@ func parseSTSRequest(req []byte) (*http.Request, error) {
httpReq.RequestURI = ""
httpReq.URL = &url.URL{
Scheme: "https",
Host: stsHost,
Host: httpReq.Host,
}
return httpReq, nil
}
Expand Down Expand Up @@ -161,9 +199,9 @@ func stsClientFromContext(ctx context.Context) stsClient {
return http.DefaultClient
}

// executeStsIdentityRequest sends the sts:GetCallerIdentity HTTP request to the
// executeSTSIdentityRequest sends the sts:GetCallerIdentity HTTP request to the
// AWS API, parses the response, and returns the awsIdentity
func executeStsIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) {
func executeSTSIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) {
client := stsClientFromContext(ctx)

// set the http request context so it can be cancelled
Expand Down Expand Up @@ -260,13 +298,13 @@ func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *pro

// validate that the host, method, and headers are correct and the expected
// challenge is included in the signed portion of the request
if err := validateStsIdentityRequest(identityRequest, challenge); err != nil {
if err := validateSTSIdentityRequest(identityRequest, challenge); err != nil {
return trace.Wrap(err)
}

// send the signed request to the public AWS API and get the node identity
// from the response
identity, err := executeStsIdentityRequest(ctx, identityRequest)
identity, err := executeSTSIdentityRequest(ctx, identityRequest)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -349,18 +387,15 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c
return certs, nil
}

// createSignedStsIdentityRequest is called on the client side and returns an
// createSignedSTSIdentityRequest is called on the client side and returns an
// sts:GetCallerIdentity request signed with the local AWS credentials
func createSignedStsIdentityRequest(challenge string) ([]byte, error) {
// use the aws sdk to generate the request
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
func createSignedSTSIdentityRequest(ctx context.Context, endpointOption stsEndpointOption, challenge string) ([]byte, error) {
stsClient, err := endpointOption(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
stsService := sts.New(sess)
req, _ := stsService.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})

req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
// set challenge header
req.HTTPRequest.Header.Set(challengeHeaderKey, challenge)
// request json for simpler parsing
Expand All @@ -376,3 +411,90 @@ func createSignedStsIdentityRequest(challenge string) ([]byte, error) {
}
return signedRequest.Bytes(), nil
}

type stsEndpointOption func(context.Context) (*sts.STS, error)

var (
stsEndpointOptionGlobal = newGlobalSTSClient
stsEndpointOptionRegional = newRegionalSTSClient
)

// newRegionalSTSClient returns an STS client will resolve the "global" endpoint
// for the STS service.
func newGlobalSTSClient(ctx context.Context) (*sts.STS, error) {
// sess will be used as a ConfigProvider to be passed to sts.New. It will
// load AWS configuration options from the environment, which means that AWS
// credentials may come from environment variables, files in ~/.aws/, or
// from the attached role on an EC2 instance.
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return nil, trace.Wrap(err)
}
return sts.New(sess), nil
}

// newRegionalSTSClient returns an STS client which attempts to resolve the local
// regional endpoint for the STS service, rather than the "global" endpoint
// which is not supported in non-default AWS partitions.
func newRegionalSTSClient(ctx context.Context) (*sts.STS, error) {
// sess will be used as a ConfigProvider to be passed to sts.New. It will
// load AWS configuration options from the environment, which means that AWS
// credentials may come from environment variables, files in ~/.aws/, or
// from the attached role on an EC2 instance. The regional STS endpoint will
// be used instead of the global endopint if the local (or preferred) region
// can be resolved from the environment.
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
Config: *awssdk.NewConfig().WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint),
})
if err != nil {
return nil, trace.Wrap(err)
}

// will set the local region on extraConfigOptions if we can find it from
// the environment or IMDS
extraConfigOptions := awssdk.NewConfig()

// If the region was not resolved from the environment the client will try to
// use the global STS endpoint, which will not be supported if the AWS identity
// being used is for a non-default AWS partition (such as China or
// GovCloud.) This is the default behavior on EC2, so let's try to find the
// region from the IMDS.
if clientConfig := sess.ClientConfig(sts.ServiceName); clientConfig.Endpoint == globalSTSEndpoint {
region, err := getEC2LocalRegion(ctx)
if trace.IsNotFound(err) {
// Unfortunately we could not find the region from the IMDS, go with
// the default global endpoint and hope it works.
log.Info("Unable to find the local AWS region from the environment or IMDSv2. " +
"Attempting to use the global STS endpoint for the IAM join method. " +
"This will probably fail in non-default AWS partitions such as China or GovCloud. " +
"Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2.")
} else if err != nil {
// Return the unexpected error.
return nil, trace.Wrap(err)
} else {
// Found the region, set it on the config.
extraConfigOptions.Region = &region
}
}

return sts.New(sess, extraConfigOptions), nil
}

// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or
// a NotFound error if the EC2 IMDS is unavailable.
func getEC2LocalRegion(ctx context.Context) (string, error) {
imdsClient, err := utils.NewInstanceMetadataClient(ctx)
if err != nil {
return "", trace.Wrap(err)
}

if !imdsClient.IsAvailable(ctx) {
return "", trace.NotFound("IMDS is unavailable")
}

region, err := imdsClient.GetRegion(ctx)
return region, trace.Wrap(err)
}
1 change: 1 addition & 0 deletions lib/auth/join_iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/trace"

"github.com/stretchr/testify/require"
)

Expand Down
72 changes: 49 additions & 23 deletions lib/auth/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,32 +448,58 @@ type joinServiceClient interface {
func registerUsingIAMMethod(joinServiceClient joinServiceClient, token string, params RegisterParams) (*proto.Certs, error) {
ctx := context.Background()

// call RegisterUsingIAMMethod with a callback to respond to the challenge
// with the join request
certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
// create the signed sts:GetCallerIdentity request and include the challenge
signedRequest, err := createSignedStsIdentityRequest(challenge)
// Attempt to use the regional STS endpoint, fall back to using the global
// endpoint. The regional endpoint may fail if Auth is on an older version
// which does not support regional endpoints, the STS service is not
// enabled in the current region, or an unknown AWS region is configured.
var errs []error
for _, s := range []struct {
desc string
opt stsEndpointOption
}{
{
desc: "regional",
opt: stsEndpointOptionRegional,
},
{
desc: "global",
opt: stsEndpointOptionGlobal,
},
} {
log.Infof("Attempting to register %s with IAM method using %s STS endpoint", params.ID.Role, s.desc)
// Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request.
certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
// create the signed sts:GetCallerIdentity request and include the challenge
signedRequest, err := createSignedSTSIdentityRequest(ctx, s.opt, challenge)
if err != nil {
return nil, trace.Wrap(err)
}

// send the register request including the challenge response
return &proto.RegisterUsingIAMMethodRequest{
RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
Token: token,
HostID: params.ID.HostUUID,
NodeName: params.ID.NodeName,
Role: params.ID.Role,
AdditionalPrincipals: params.AdditionalPrincipals,
DNSNames: params.DNSNames,
PublicTLSKey: params.PublicTLSKey,
PublicSSHKey: params.PublicSSHKey,
},
StsIdentityRequest: signedRequest,
}, nil
})
if err != nil {
return nil, trace.Wrap(err)
log.WithError(err).Infof("Failed to register %s using %s STS endpoint", params.ID.Role, s.desc)
errs = append(errs, err)
} else {
log.Infof("Successfully registered %s with IAM method using %s STS endpoint", params.ID.Role, s.desc)
return certs, nil
}
}

// send the register request including the challenge response
return &proto.RegisterUsingIAMMethodRequest{
RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
Token: token,
HostID: params.ID.HostUUID,
NodeName: params.ID.NodeName,
Role: params.ID.Role,
AdditionalPrincipals: params.AdditionalPrincipals,
DNSNames: params.DNSNames,
PublicTLSKey: params.PublicTLSKey,
PublicSSHKey: params.PublicSSHKey,
},
StsIdentityRequest: signedRequest,
}, nil
})

return certs, trace.Wrap(err)
return nil, trace.NewAggregate(errs...)
}

// ReRegisterParams specifies parameters for re-registering
Expand Down
Loading