diff --git a/docs/pages/setup/guides/joining-nodes-aws.mdx b/docs/pages/setup/guides/joining-nodes-aws.mdx index 96c0a62af6d6d..2e33364957794 100644 --- a/docs/pages/setup/guides/joining-nodes-aws.mdx +++ b/docs/pages/setup/guides/joining-nodes-aws.mdx @@ -35,8 +35,9 @@ policies is sufficient. No IAM credentials at all are required on the Teleport Auth server. -The IAM join method is not compatible with the `--insecure-no-tls` flag, and is -currently not supported in the AWS China or GovCloud partitions. +The IAM join method will not work if TLS is terminated at a load balancer in +front of your Teleport Proxy Service unless the Node using this method is +connecting directly to the Auth Service. ## Prerequisites diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index e1c1894dcf667..e904a9b6a9341 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -33,9 +33,12 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils" + apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/aws" + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sts" "github.com/gravitational/trace" @@ -47,19 +50,45 @@ const ( // update our AWS SDK dependency. Since Auth should always be upgraded // before nodes, we will have a chance to update the check on Auth if we // ever have a need to allow a newer API version. - expectedStsIdentityRequestBody = "Action=GetCallerIdentity&Version=2011-06-15" + expectedSTSIdentityRequestBody = "Action=GetCallerIdentity&Version=2011-06-15" - // Only allowing the global sts endpoint here, Teleport nodes will only send - // requests for this endpoint. If we want to start using regional endpoints - // we can update this check before updating the nodes. - stsHost = "sts.amazonaws.com" + // Used to check if we were unable to resolve the regional STS endpoint. + globalSTSEndpoint = "https://sts.amazonaws.com" // AWS SignedHeaders will always be lowercase // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html#sigv4-auth-header-overview challengeHeaderKey = "x-teleport-challenge" ) -// validateStsIdentityRequest checks that a received sts:GetCallerIdentity +// validateSTSHost returns an error if the given stsHost is not a valid regional +// endpoint for the AWS STS service, or nil if it is valid. +// +// This is a security-critical check: we are allowing the client to tell us +// which URL we should use to validate their identity. If the client could pass +// off an attacker-controlled URL as the STS endpoint, the entire security +// mechanism of the IAM join method would be compromised. +// +// To keep this validation simple and secure, we check the given endpoint +// against a static list of known valid endpoints. We will need to update this +// list as AWS adds new regions. +func validateSTSHost(stsHost string) error { + valid := apiutils.SliceContainsStr(validSTSEndpoints, stsHost) + if valid { + return nil + } + + return trace.AccessDenied("IAM join request uses unknown STS host %q. "+ + "This could mean that the Teleport Node attempting to join the cluster is "+ + "running in a new AWS region which is unknown to this Teleport auth server. "+ + "Alternatively, if this URL looks suspicious, an attacker may be attempting to "+ + "join your Teleport cluster. "+ + "Following is the list of valid STS endpoints known to this auth server. "+ + "If a legitimate STS endpoint is not included, please file an issue at "+ + "https://github.com/gravitational/teleport. %v", + stsHost, validSTSEndpoints) +} + +// validateSTSIdentityRequest checks that a received sts:GetCallerIdentity // request is valid and includes the challenge as a signed header. An example // valid request looks like: // ``` @@ -76,9 +105,18 @@ const ( // // Action=GetCallerIdentity&Version=2011-06-15 // ``` -func validateStsIdentityRequest(req *http.Request, challenge string) error { - if req.Host != stsHost { - return trace.AccessDenied("sts identity request is for unknown host %q", req.Host) +func validateSTSIdentityRequest(req *http.Request, challenge string) (err error) { + defer func() { + // Always log a warning on the Auth server if the function detects an + // invalid sts:GetCallerIdentity request, it's either going to be caused + // by a node in a unknown region or an attacker. + if err != nil { + log.WithError(err).Warn("Detected an invalid sts:GetCallerIdentity used by a client attempting to use the IAM join method.") + } + }() + + if err := validateSTSHost(req.Host); err != nil { + return trace.Wrap(err) } if req.Method != http.MethodPost { @@ -95,7 +133,7 @@ func validateStsIdentityRequest(req *http.Request, challenge string) error { if err != nil { return trace.Wrap(err) } - if !utils.SliceContainsStr(sigV4.SignedHeaders, challengeHeaderKey) { + if !apiutils.SliceContainsStr(sigV4.SignedHeaders, challengeHeaderKey) { return trace.AccessDenied("sts identity request auth header %q does not include "+ challengeHeaderKey+" as a signed header", authHeader) } @@ -104,8 +142,8 @@ func validateStsIdentityRequest(req *http.Request, challenge string) error { if err != nil { return trace.Wrap(err) } - if !bytes.Equal([]byte(expectedStsIdentityRequestBody), body) { - return trace.BadParameter("sts request body %q does not equal expected %q", string(body), expectedStsIdentityRequestBody) + if !bytes.Equal([]byte(expectedSTSIdentityRequestBody), body) { + return trace.BadParameter("sts request body %q does not equal expected %q", string(body), expectedSTSIdentityRequestBody) } return nil @@ -125,7 +163,7 @@ func parseSTSRequest(req []byte) (*http.Request, error) { httpReq.RequestURI = "" httpReq.URL = &url.URL{ Scheme: "https", - Host: stsHost, + Host: httpReq.Host, } return httpReq, nil } @@ -161,9 +199,9 @@ func stsClientFromContext(ctx context.Context) stsClient { return http.DefaultClient } -// executeStsIdentityRequest sends the sts:GetCallerIdentity HTTP request to the +// executeSTSIdentityRequest sends the sts:GetCallerIdentity HTTP request to the // AWS API, parses the response, and returns the awsIdentity -func executeStsIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) { +func executeSTSIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) { client := stsClientFromContext(ctx) // set the http request context so it can be cancelled @@ -260,13 +298,13 @@ func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *pro // validate that the host, method, and headers are correct and the expected // challenge is included in the signed portion of the request - if err := validateStsIdentityRequest(identityRequest, challenge); err != nil { + if err := validateSTSIdentityRequest(identityRequest, challenge); err != nil { return trace.Wrap(err) } // send the signed request to the public AWS API and get the node identity // from the response - identity, err := executeStsIdentityRequest(ctx, identityRequest) + identity, err := executeSTSIdentityRequest(ctx, identityRequest) if err != nil { return trace.Wrap(err) } @@ -349,18 +387,15 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c return certs, nil } -// createSignedStsIdentityRequest is called on the client side and returns an +// createSignedSTSIdentityRequest is called on the client side and returns an // sts:GetCallerIdentity request signed with the local AWS credentials -func createSignedStsIdentityRequest(challenge string) ([]byte, error) { - // use the aws sdk to generate the request - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - }) +func createSignedSTSIdentityRequest(ctx context.Context, endpointOption stsEndpointOption, challenge string) ([]byte, error) { + stsClient, err := endpointOption(ctx) if err != nil { return nil, trace.Wrap(err) } - stsService := sts.New(sess) - req, _ := stsService.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) + + req, _ := stsClient.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) // set challenge header req.HTTPRequest.Header.Set(challengeHeaderKey, challenge) // request json for simpler parsing @@ -376,3 +411,90 @@ func createSignedStsIdentityRequest(challenge string) ([]byte, error) { } return signedRequest.Bytes(), nil } + +type stsEndpointOption func(context.Context) (*sts.STS, error) + +var ( + stsEndpointOptionGlobal = newGlobalSTSClient + stsEndpointOptionRegional = newRegionalSTSClient +) + +// newRegionalSTSClient returns an STS client will resolve the "global" endpoint +// for the STS service. +func newGlobalSTSClient(ctx context.Context) (*sts.STS, error) { + // sess will be used as a ConfigProvider to be passed to sts.New. It will + // load AWS configuration options from the environment, which means that AWS + // credentials may come from environment variables, files in ~/.aws/, or + // from the attached role on an EC2 instance. + sess, err := session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return sts.New(sess), nil +} + +// newRegionalSTSClient returns an STS client which attempts to resolve the local +// regional endpoint for the STS service, rather than the "global" endpoint +// which is not supported in non-default AWS partitions. +func newRegionalSTSClient(ctx context.Context) (*sts.STS, error) { + // sess will be used as a ConfigProvider to be passed to sts.New. It will + // load AWS configuration options from the environment, which means that AWS + // credentials may come from environment variables, files in ~/.aws/, or + // from the attached role on an EC2 instance. The regional STS endpoint will + // be used instead of the global endopint if the local (or preferred) region + // can be resolved from the environment. + sess, err := session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + Config: *awssdk.NewConfig().WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // will set the local region on extraConfigOptions if we can find it from + // the environment or IMDS + extraConfigOptions := awssdk.NewConfig() + + // If the region was not resolved from the environment the client will try to + // use the global STS endpoint, which will not be supported if the AWS identity + // being used is for a non-default AWS partition (such as China or + // GovCloud.) This is the default behavior on EC2, so let's try to find the + // region from the IMDS. + if clientConfig := sess.ClientConfig(sts.ServiceName); clientConfig.Endpoint == globalSTSEndpoint { + region, err := getEC2LocalRegion(ctx) + if trace.IsNotFound(err) { + // Unfortunately we could not find the region from the IMDS, go with + // the default global endpoint and hope it works. + log.Info("Unable to find the local AWS region from the environment or IMDSv2. " + + "Attempting to use the global STS endpoint for the IAM join method. " + + "This will probably fail in non-default AWS partitions such as China or GovCloud. " + + "Consider setting the AWS_REGION environment variable, setting the region in ~/.aws/config, or enabling the IMDSv2.") + } else if err != nil { + // Return the unexpected error. + return nil, trace.Wrap(err) + } else { + // Found the region, set it on the config. + extraConfigOptions.Region = ®ion + } + } + + return sts.New(sess, extraConfigOptions), nil +} + +// getEC2LocalRegion returns the AWS region this EC2 instance is running in, or +// a NotFound error if the EC2 IMDS is unavailable. +func getEC2LocalRegion(ctx context.Context) (string, error) { + imdsClient, err := utils.NewInstanceMetadataClient(ctx) + if err != nil { + return "", trace.Wrap(err) + } + + if !imdsClient.IsAvailable(ctx) { + return "", trace.NotFound("IMDS is unavailable") + } + + region, err := imdsClient.GetRegion(ctx) + return region, trace.Wrap(err) +} diff --git a/lib/auth/join_iam_test.go b/lib/auth/join_iam_test.go index f870a8b4113a5..68b6e03e56109 100644 --- a/lib/auth/join_iam_test.go +++ b/lib/auth/join_iam_test.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/trace" + "github.com/stretchr/testify/require" ) diff --git a/lib/auth/register.go b/lib/auth/register.go index fc7619a9f0f47..d938fa7789404 100644 --- a/lib/auth/register.go +++ b/lib/auth/register.go @@ -448,32 +448,58 @@ type joinServiceClient interface { func registerUsingIAMMethod(joinServiceClient joinServiceClient, token string, params RegisterParams) (*proto.Certs, error) { ctx := context.Background() - // call RegisterUsingIAMMethod with a callback to respond to the challenge - // with the join request - certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { - // create the signed sts:GetCallerIdentity request and include the challenge - signedRequest, err := createSignedStsIdentityRequest(challenge) + // Attempt to use the regional STS endpoint, fall back to using the global + // endpoint. The regional endpoint may fail if Auth is on an older version + // which does not support regional endpoints, the STS service is not + // enabled in the current region, or an unknown AWS region is configured. + var errs []error + for _, s := range []struct { + desc string + opt stsEndpointOption + }{ + { + desc: "regional", + opt: stsEndpointOptionRegional, + }, + { + desc: "global", + opt: stsEndpointOptionGlobal, + }, + } { + log.Infof("Attempting to register %s with IAM method using %s STS endpoint", params.ID.Role, s.desc) + // Call RegisterUsingIAMMethod and pass a callback to respond to the challenge with a signed join request. + certs, err := joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { + // create the signed sts:GetCallerIdentity request and include the challenge + signedRequest, err := createSignedSTSIdentityRequest(ctx, s.opt, challenge) + if err != nil { + return nil, trace.Wrap(err) + } + + // send the register request including the challenge response + return &proto.RegisterUsingIAMMethodRequest{ + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ + Token: token, + HostID: params.ID.HostUUID, + NodeName: params.ID.NodeName, + Role: params.ID.Role, + AdditionalPrincipals: params.AdditionalPrincipals, + DNSNames: params.DNSNames, + PublicTLSKey: params.PublicTLSKey, + PublicSSHKey: params.PublicSSHKey, + }, + StsIdentityRequest: signedRequest, + }, nil + }) if err != nil { - return nil, trace.Wrap(err) + log.WithError(err).Infof("Failed to register %s using %s STS endpoint", params.ID.Role, s.desc) + errs = append(errs, err) + } else { + log.Infof("Successfully registered %s with IAM method using %s STS endpoint", params.ID.Role, s.desc) + return certs, nil } + } - // send the register request including the challenge response - return &proto.RegisterUsingIAMMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: token, - HostID: params.ID.HostUUID, - NodeName: params.ID.NodeName, - Role: params.ID.Role, - AdditionalPrincipals: params.AdditionalPrincipals, - DNSNames: params.DNSNames, - PublicTLSKey: params.PublicTLSKey, - PublicSSHKey: params.PublicSSHKey, - }, - StsIdentityRequest: signedRequest, - }, nil - }) - - return certs, trace.Wrap(err) + return nil, trace.NewAggregate(errs...) } // ReRegisterParams specifies parameters for re-registering diff --git a/lib/auth/valid_sts_endpoints.go b/lib/auth/valid_sts_endpoints.go new file mode 100644 index 0000000000000..48f89c2ca9e33 --- /dev/null +++ b/lib/auth/valid_sts_endpoints.go @@ -0,0 +1,58 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +// validSTSEndpoints holds a sorted list of all known valid public endpoints for +// the AWS STS service. You can generate this list by running +// $ go run github.com/nklaassen/sts-endpoints@latest --go-list +// Update aws-sdk-go in that package to learn about new endpoints. +var validSTSEndpoints = []string{ + "sts-fips.us-east-1.amazonaws.com", + "sts-fips.us-east-2.amazonaws.com", + "sts-fips.us-west-1.amazonaws.com", + "sts-fips.us-west-2.amazonaws.com", + "sts.af-south-1.amazonaws.com", + "sts.amazonaws.com", + "sts.ap-east-1.amazonaws.com", + "sts.ap-northeast-1.amazonaws.com", + "sts.ap-northeast-2.amazonaws.com", + "sts.ap-northeast-3.amazonaws.com", + "sts.ap-south-1.amazonaws.com", + "sts.ap-southeast-1.amazonaws.com", + "sts.ap-southeast-2.amazonaws.com", + "sts.ap-southeast-3.amazonaws.com", + "sts.ca-central-1.amazonaws.com", + "sts.cn-north-1.amazonaws.com.cn", + "sts.cn-northwest-1.amazonaws.com.cn", + "sts.eu-central-1.amazonaws.com", + "sts.eu-north-1.amazonaws.com", + "sts.eu-south-1.amazonaws.com", + "sts.eu-west-1.amazonaws.com", + "sts.eu-west-2.amazonaws.com", + "sts.eu-west-3.amazonaws.com", + "sts.me-south-1.amazonaws.com", + "sts.sa-east-1.amazonaws.com", + "sts.us-east-1.amazonaws.com", + "sts.us-east-2.amazonaws.com", + "sts.us-gov-east-1.amazonaws.com", + "sts.us-gov-west-1.amazonaws.com", + "sts.us-iso-east-1.c2s.ic.gov", + "sts.us-iso-west-1.c2s.ic.gov", + "sts.us-isob-east-1.sc2s.sgov.gov", + "sts.us-west-1.amazonaws.com", + "sts.us-west-2.amazonaws.com", +} diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go new file mode 100644 index 0000000000000..766579ec74f64 --- /dev/null +++ b/lib/cloud/aws/errors.go @@ -0,0 +1,81 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "errors" + "net/http" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/gravitational/trace" +) + +// ConvertRequestFailureError converts `error` into AWS RequestFailure errors +// to trace errors. If the provided error is not an `RequestFailure` it returns +// the error without modifying it. +func ConvertRequestFailureError(err error) error { + requestErr, ok := err.(awserr.RequestFailure) + if !ok { + return err + } + + switch requestErr.StatusCode() { + case http.StatusForbidden: + return trace.AccessDenied(requestErr.Error()) + case http.StatusConflict: + return trace.AlreadyExists(requestErr.Error()) + case http.StatusNotFound: + return trace.NotFound(requestErr.Error()) + } + + return err // Return unmodified. +} + +// ConvertIAMError converts common errors from IAM clients to trace errors. +func ConvertIAMError(err error) error { + // By error code. + if awsErr, ok := err.(awserr.Error); ok { + switch awsErr.Code() { + case iam.ErrCodeUnmodifiableEntityException: + return trace.AccessDenied(awsErr.Error()) + + case iam.ErrCodeNoSuchEntityException: + return trace.NotFound(awsErr.Error()) + + case iam.ErrCodeMalformedPolicyDocumentException, + iam.ErrCodeInvalidInputException, + iam.ErrCodeDeleteConflictException: + return trace.BadParameter(awsErr.Error()) + + case iam.ErrCodeLimitExceededException: + return trace.LimitExceeded(awsErr.Error()) + } + } + + // By status code. + return ConvertRequestFailureError(err) +} + +// ParseMetadataClientError converts a failed instance metadata service call to a trace error. +func ParseMetadataClientError(err error) error { + var httpError interface{ HTTPStatusCode() int } + if errors.As(err, &httpError) { + return trace.ReadError(httpError.HTTPStatusCode(), nil) + } + return trace.Wrap(err) +} diff --git a/lib/utils/ec2.go b/lib/utils/ec2.go index 666344b6e9041..bdf84a0fae70b 100644 --- a/lib/utils/ec2.go +++ b/lib/utils/ec2.go @@ -18,14 +18,22 @@ package utils import ( "context" + "fmt" "io" "regexp" + "strings" + "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/cloud/aws" ) +// metadataReadLimit is the largest number of bytes that will be read from imds responses. +const metadataReadLimit = 1_000_000 + // GetEC2IdentityDocument fetches the PKCS7 RSA2048 InstanceIdentityDocument // from the IMDS for this EC2 instance. func GetEC2IdentityDocument() ([]byte, error) { @@ -74,7 +82,11 @@ func GetEC2NodeID() (string, error) { // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/resource-ids.html var ec2NodeIDRE = regexp.MustCompile("^[0-9]{12}-i-[0-9a-f]{8,}$") -// IsEC2NodeID returns true if the given ID looks like an EC2 node ID +// IsEC2NodeID returns true if the given ID looks like an EC2 node ID. Uses a +// simple regex to check. Node IDs are almost always UUIDs when set +// automatically, but can be manually overridden by admins. If someone manually +// sets a host ID that looks like one of our generated EC2 node IDs, they may be +// able to trick this function, so don't use it for any critical purpose. func IsEC2NodeID(id string) bool { return ec2NodeIDRE.MatchString(id) } @@ -84,3 +96,93 @@ func IsEC2NodeID(id string) bool { func NodeIDFromIID(iid *imds.InstanceIdentityDocument) string { return iid.AccountID + "-" + iid.InstanceID } + +// InstanceMetadataClient is a wrapper for an imds.Client. +type InstanceMetadataClient struct { + c *imds.Client +} + +// InstanceMetadataClientOption allows setting options as functional arguments to an InstanceMetadataClient. +type InstanceMetadataClientOption func(client *InstanceMetadataClient) error + +// WithIMDSClient adds a custom internal imds.Client to an InstanceMetadataClient. +func WithIMDSClient(client *imds.Client) InstanceMetadataClientOption { + return func(clt *InstanceMetadataClient) error { + clt.c = client + return nil + } +} + +// NewInstanceMetadataClient creates a new instance metadata client. +func NewInstanceMetadataClient(ctx context.Context, opts ...InstanceMetadataClientOption) (*InstanceMetadataClient, error) { + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + clt := &InstanceMetadataClient{ + c: imds.NewFromConfig(cfg), + } + + for _, opt := range opts { + if err := opt(clt); err != nil { + return nil, trace.Wrap(err) + } + } + + return clt, nil +} + +// EC2 resource ID is i-{8 or 17 hex digits}, see +// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/resource-ids.html +var ec2ResourceIDRE = regexp.MustCompile("^i-[0-9a-f]{8,}$") + +// IsAvailable checks if instance metadata is available. +func (client *InstanceMetadataClient) IsAvailable(ctx context.Context) bool { + ctx, cancel := context.WithTimeout(ctx, 250*time.Millisecond) + defer cancel() + + // try to retrieve the instance id of our EC2 instance + id, err := client.getMetadata(ctx, "instance-id") + return err == nil && ec2ResourceIDRE.MatchString(id) +} + +// getMetadata gets the raw metadata from a specified path. +func (client *InstanceMetadataClient) getMetadata(ctx context.Context, path string) (string, error) { + output, err := client.c.GetMetadata(ctx, &imds.GetMetadataInput{Path: path}) + if err != nil { + return "", trace.Wrap(aws.ParseMetadataClientError(err)) + } + defer output.Content.Close() + body, err := ReadAtMost(output.Content, metadataReadLimit) + if err != nil { + return "", trace.Wrap(err) + } + return string(body), nil +} + +// GetTagKeys gets all of the EC2 tag keys. +func (client *InstanceMetadataClient) GetTagKeys(ctx context.Context) ([]string, error) { + body, err := client.getMetadata(ctx, "tags/instance") + if err != nil { + return nil, trace.Wrap(err) + } + return strings.Split(body, "\n"), nil +} + +// GetTagValue gets the value for a specified tag key. +func (client *InstanceMetadataClient) GetTagValue(ctx context.Context, key string) (string, error) { + body, err := client.getMetadata(ctx, fmt.Sprintf("tags/instance/%s", key)) + if err != nil { + return "", trace.Wrap(err) + } + return body, nil +} + +func (client *InstanceMetadataClient) GetRegion(ctx context.Context) (string, error) { + getRegionOutput, err := client.c.GetRegion(ctx, nil) + if err != nil { + return "", trace.Wrap(err) + } + return getRegionOutput.Region, nil +}