diff --git a/api/types/appserver.go b/api/types/appserver.go index aa5a96b43d2d4..0921dc12470f7 100644 --- a/api/types/appserver.go +++ b/api/types/appserver.go @@ -285,6 +285,16 @@ func (s *AppServerV3) GetAllLabels() map[string]string { return CombineLabels(staticLabels, dynamicLabels) } +// GetStaticLabels returns the app server static labels. +func (s *AppServerV3) GetStaticLabels() map[string]string { + return s.Metadata.Labels +} + +// SetStaticLabels sets the app server static labels. +func (s *AppServerV3) SetStaticLabels(sl map[string]string) { + s.Metadata.Labels = sl +} + // Copy returns a copy of this app server object. func (s *AppServerV3) Copy() AppServer { return proto.Clone(s).(*AppServerV3) diff --git a/api/types/constants.go b/api/types/constants.go index 967bbaef6c866..0a07de243cddc 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -323,6 +323,9 @@ const ( OriginCloud = "cloud" ) +// EC2HostnameTag is the name of the EC2 tag used to override a node's hostname. +const EC2HostnameTag = "TeleportHostname" + // OriginValues lists all possible origin values. var OriginValues = []string{OriginDefaults, OriginConfigFile, OriginDynamic, OriginCloud} diff --git a/api/types/databaseserver.go b/api/types/databaseserver.go index db2637513a3cc..2313b9ce4bc56 100644 --- a/api/types/databaseserver.go +++ b/api/types/databaseserver.go @@ -272,6 +272,16 @@ func (s *DatabaseServerV3) GetAllLabels() map[string]string { return CombineLabels(staticLabels, s.Spec.DynamicLabels) } +// GetStaticLabels returns the database server static labels. +func (s *DatabaseServerV3) GetStaticLabels() map[string]string { + return s.Metadata.Labels +} + +// SetStaticLabels sets the database server static labels. +func (s *DatabaseServerV3) SetStaticLabels(sl map[string]string) { + s.Metadata.Labels = sl +} + // Copy returns a copy of this database server object. func (s *DatabaseServerV3) Copy() DatabaseServer { return proto.Clone(s).(*DatabaseServerV3) diff --git a/api/types/desktop.go b/api/types/desktop.go index 6fdda0b05fa75..8bc7d23cf15f2 100644 --- a/api/types/desktop.go +++ b/api/types/desktop.go @@ -99,6 +99,16 @@ func (s *WindowsDesktopServiceV3) GetAllLabels() map[string]string { return s.Metadata.Labels } +// GetStaticLabels returns the windows desktop static labels. +func (s *WindowsDesktopServiceV3) GetStaticLabels() map[string]string { + return s.Metadata.Labels +} + +// SetStaticLabels sets the windows desktop static labels. +func (s *WindowsDesktopServiceV3) SetStaticLabels(sl map[string]string) { + s.Metadata.Labels = sl +} + // GetHostname returns the windows hostname of this service. func (s *WindowsDesktopServiceV3) GetHostname() string { return s.Spec.Hostname @@ -177,6 +187,16 @@ func (d *WindowsDesktopV3) GetAllLabels() map[string]string { return CombineLabels(d.Metadata.Labels, nil) } +// GetStaticLabels returns the windows desktop static labels. +func (d *WindowsDesktopV3) GetStaticLabels() map[string]string { + return d.Metadata.Labels +} + +// SetStaticLabels sets the windows desktop static labels. +func (d *WindowsDesktopV3) SetStaticLabels(sl map[string]string) { + d.Metadata.Labels = sl +} + // LabelsString returns all desktop labels as a string. func (d *WindowsDesktopV3) LabelsString() string { return LabelsAsString(d.Metadata.Labels, nil) diff --git a/api/types/resource.go b/api/types/resource.go index 86740fb17ab4b..7f9be4eb9511a 100644 --- a/api/types/resource.go +++ b/api/types/resource.go @@ -82,6 +82,10 @@ type ResourceWithLabels interface { ResourceWithOrigin // GetAllLabels returns all resource's labels. GetAllLabels() map[string]string + // GetStaticLabels returns the resource's static labels. + GetStaticLabels() map[string]string + // SetStaticLabels sets the resource's static labels. + SetStaticLabels(sl map[string]string) // MatchSearch goes through select field values of a resource // and tries to match against the list of search values. MatchSearch(searchValues []string) bool diff --git a/api/types/server.go b/api/types/server.go index e6cd77e5dfa54..866ea49e73bb7 100644 --- a/api/types/server.go +++ b/api/types/server.go @@ -221,11 +221,25 @@ func (s *ServerV2) GetHostname() string { return s.Spec.Hostname } +// GetLabels and GetStaticLabels are the same, and that is intentional. GetLabels +// exists to preserve backwards compatibility, while GetStaticLabels exists to +// implement ResourcesWithLabels. + // GetLabels returns server's static label key pairs func (s *ServerV2) GetLabels() map[string]string { return s.Metadata.Labels } +// GetStaticLabels returns the server static labels. +func (s *ServerV2) GetStaticLabels() map[string]string { + return s.Metadata.Labels +} + +// SetStaticLabels sets the server static labels. +func (s *ServerV2) SetStaticLabels(sl map[string]string) { + s.Metadata.Labels = sl +} + // GetCmdLabels returns command labels func (s *ServerV2) GetCmdLabels() map[string]CommandLabel { if s.Spec.CmdLabels == nil { diff --git a/docs/img/aws/allow-tags.png b/docs/img/aws/allow-tags.png new file mode 100644 index 0000000000000..43cbbca289c14 Binary files /dev/null and b/docs/img/aws/allow-tags.png differ diff --git a/docs/img/aws/instance-settings.png b/docs/img/aws/instance-settings.png new file mode 100644 index 0000000000000..bf8d600d40145 Binary files /dev/null and b/docs/img/aws/instance-settings.png differ diff --git a/docs/img/aws/launch-instance-advanced-options.png b/docs/img/aws/launch-instance-advanced-options.png new file mode 100644 index 0000000000000..c9ba1f5e342d4 Binary files /dev/null and b/docs/img/aws/launch-instance-advanced-options.png differ diff --git a/docs/pages/setup/guides/ec2-tags.mdx b/docs/pages/setup/guides/ec2-tags.mdx index c5085065d5ef7..7508ddfe31779 100644 --- a/docs/pages/setup/guides/ec2-tags.mdx +++ b/docs/pages/setup/guides/ec2-tags.mdx @@ -4,131 +4,81 @@ description: How to set up Teleport Node labels based on EC2 tags h1: Sync EC2 Tags and Teleport Node Labels --- -This guide will explain how to set up Teleport Node labels based on Amazon EC2 tags. +When running on an AWS EC2 instance, Teleport will automatically detect and import EC2 tags as +Teleport labels for SSH nodes, Applications, Databases, and Kubernetes clusters. Labels created +this way will have the `aws/` prefix. + +If the tag `TeleportHostname` (case-sensitive) is present, its value will override the node's hostname. + +```bash +$ tsh ls +Node Name Address Labels +-------------------- -------------- ----------------------------------------------------------------------------------------------------------------------- +fakehost.example.com 127.0.0.1:3022 env=example,hostname=ip-172-31-53-70,aws/Name=atburke-dev,aws/TagKey=TagValue,aws/TeleportHostname=fakehost.example.com +``` + +<Notice type="note"> + For services that manage multiple resources (such as the database service), each resource will receive the + same labels from EC2. +</Notice> ## Prerequisites (!docs/pages/includes/edition-prereqs-tabs.mdx!) - - One Teleport Node running on an Amazon EC2 instance. See [Adding Nodes](../admin/adding-nodes.mdx) for how to set up a Teleport Node. -- The following software installed on your Teleport Node: `curl`, `python`, and - the `aws` CLI, which comes from the `awscli` Python package. -## Step 1/3. Deploy the script +## Enable tags in instance metadata -You’ll need a script on your EC2 instance that can query the AWS API and get the -values of your instance's tags for you. The Teleport Node will then use these -values to execute RBAC rules. +To allow Teleport to import EC2 tags, tags must be enabled in the instance metadata. This can be done +via the AWS console or the AWS CLI. See the [AWS documentation](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#allow-access-to-tags-in-IMDS) +for more details. -Here’s one script you can use: +<Admonition type="note" title="Note"> + Only instances that are running on the [Nitro system](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-types.html#ec2-nitro-instances) + will update their tags while running. All other instance types [must be restarted](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Tags.html#work-with-tags-in-IMDS) + to update tags. +</Admonition> -```code -#!/bin/bash -if [[ "$1" == "" ]]; then - echo "Usage: $(basename $0) <tag>" - exit 1 -fi -TAG_NAME=$1 - -IMDS_TOKEN=$(curl -sS -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 300") -IMDS_TOKEN_HEADER="-H \"X-aws-ec2-metadata-token: ${IMDS_TOKEN}\"" -INSTANCE_ID=$(curl -sS "${IMDS_TOKEN_HEADER}" http://169.254.169.254/latest/meta-data/instance-id) -REGION=$(curl -sS "${IMDS_TOKEN_HEADER}" http://169.254.169.254/latest/meta-data/placement/availability-zone | sed -e "s:\([0-9][0-9]*\)[a-z]*\$:\\1:") -TAG_VALUE="$(aws ec2 describe-tags --filters "Name=resource-id,Values=$INSTANCE_ID" "Name=key,Values=$TAG_NAME" --region $REGION --output=text | cut -f5)" - -if [[ "${TAG_VALUE}" == "" ]]; then - echo "<null>" -else - echo $TAG_VALUE -fi -``` +### AWS EC2 Console -Save this script to `/usr/local/bin/get-tag.sh` on your EC2 instance. -Run the command below to make it executable: +To launch a new instance with instance metadata tags enabled: +1. Open `Advanced Options` at the bottom of the page. +2. Ensure that `Metadata accessible` is not disabled. +3. Enable `Allow tags in metadata`. -```code -$ chmod +x /usr/local/bin/get-tag.sh -``` +<Figure align="left" bordered caption="Advanced Options"> +  +</Figure> -## Step 2/3. Set up an IAM role +To modify an existing instance to enable instance metadata tags: +1. From the instance summary, go to `Actions > Instance Settings > Allow tags in instance metadata`. +2. Enable `Allow`. -Grant your EC2 instance an IAM role that will allow it to query tag values for EC2 instances. +<Figure align="left" bordered caption="Instance Settings"> +  +</Figure> -Here’s an example policy which you can add to an IAM role: +<Figure align="left" bordered caption="Allow Tags"> +  +</Figure> +### AWS CLI -```json -{ - "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "ec2:DescribeTags", - "Resource": "*" - } - ] -} -``` - -Once this is done, query the value of the test tag on your EC2 instance by running the following command: +To modify the instance at launch: ```code -$ /usr/local/bin/get-tag.sh test -tagValue +$ aws ec2 run-instances \ + --image-id <image-id> \ + --instance-type <instance-type> \ + --metadata-options "InstanceMetadataTags=enabled" + ... ``` -## Step 3/3. Modify the Teleport Node config file +To modify a running instance: -Modify the Teleport config file on your node (`/etc/teleport.yaml`) as follows: - -```yaml -teleport: - ssh_service: - enabled: yes - listen_addr: 0.0.0.0:3022 - commands: - - name: aws_tag_test - command: ['/usr/local/bin/get-tag.sh', 'test'] - period: 1h -``` - -This config will add a label with the key `aws_tag_test` to your instance. Its value will be set to whatever the `test` tag is set to and it will be updated once every hour. - -Restart Teleport on the node and you should see the new label appear: - -```txt -Node Name Address Labels ------------------------------ ----------------------------------------------------------------------- ------------------------------------------------------------------------------------------- -example 172.31.26.55:3022 aws_tag_test=tagValue -``` - -Now you have a label on the instance which you can use inside a Teleport role. Here’s an example role: - -```yaml -kind: role -version: v5 -metadata: - name: test-tag-role -spec: - allow: - logins: - - ec2-user - node_labels: - 'aws_tag_test': 'tagValue' - deny: {} - options: - cert_format: standard - forward_agent: true - max_session_ttl: 2h0m0s - port_forwarding: true +```code +$ aws ec2 modify-instance-metadata-options \ + --instance-id i-123456789example \ + --instance-metadata-tags enabled ``` - -When assigned to Teleport users, this role will only allow them to log in to -Teleport Nodes which have the `aws_tag_test` label with the value of `tagValue`, -effectively gating access to infrastructure based on the value of the EC2 `test` -tag. - -By adding multiple commands to a Teleport Node, setting the values of different -tags, then adding Teleport roles that reference these tags, you can build -fine-grained RBAC systems based on your EC2 tagging structure. diff --git a/integration/ec2_test.go b/integration/ec2_test.go index 2a56ac22a208c..1810f75867202 100644 --- a/integration/ec2_test.go +++ b/integration/ec2_test.go @@ -18,6 +18,7 @@ package integration import ( "context" + "fmt" "io" "net" "os" @@ -30,9 +31,11 @@ import ( "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/labels/ec2" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" @@ -282,3 +285,180 @@ func TestIAMNodeJoin(t *testing.T) { return len(nodes) > 0 }, time.Minute, time.Second, "waiting for node to join cluster") } + +type mockIMDSClient struct { + tags map[string]string +} + +func (m *mockIMDSClient) IsAvailable(ctx context.Context) bool { + return true +} + +func (m *mockIMDSClient) GetTagKeys(ctx context.Context) ([]string, error) { + keys := make([]string, 0, len(m.tags)) + for k := range m.tags { + keys = append(keys, k) + } + return keys, nil +} + +func (m *mockIMDSClient) GetTagValue(ctx context.Context, key string) (string, error) { + if value, ok := m.tags[key]; ok { + return value, nil + } + return "", trace.NotFound("Tag %q not found", key) +} + +// TestEC2Labels is an integration test which asserts that Teleport correctly picks up +// EC2 tags when running on an EC2 instance. +func TestEC2Labels(t *testing.T) { + storageConfig := backend.Config{ + Type: lite.GetName(), + Params: backend.Params{ + "path": t.TempDir(), + "poll_stream_period": 50 * time.Millisecond, + }, + } + tconf := service.MakeDefaultConfig() + tconf.Log = newSilentLogger() + tconf.DataDir = t.TempDir() + tconf.Auth.Enabled = true + tconf.Proxy.Enabled = true + tconf.Proxy.DisableWebInterface = true + tconf.Auth.StorageConfig = storageConfig + tconf.Auth.SSHAddr.Addr = net.JoinHostPort(Host, ports.Pop()) + tconf.AuthServers = append(tconf.AuthServers, tconf.Auth.SSHAddr) + + tconf.SSH.Enabled = true + tconf.SSH.Addr.Addr = net.JoinHostPort(Host, ports.Pop()) + + appConf := service.App{ + Name: "test-app", + URI: "app.example.com", + } + + tconf.Apps.Enabled = true + tconf.Apps.Apps = []service.App{appConf} + + dbConfig := service.Database{ + Name: "test-db", + Protocol: "postgres", + URI: "postgres://somewhere.example.com", + } + tconf.Databases.Enabled = true + tconf.Databases.Databases = []service.Database{dbConfig} + + enableKubernetesService(t, tconf) + + imClient := &mockIMDSClient{ + tags: map[string]string{ + "Name": "my-instance", + }, + } + + proc, err := service.NewTeleport(tconf, service.WithIMDSClient(imClient)) + require.NoError(t, err) + require.NoError(t, proc.Start()) + t.Cleanup(func() { require.NoError(t, proc.Close()) }) + + ctx := context.Background() + authServer := proc.GetAuthServer() + + var nodes []types.Server + var apps []types.AppServer + var databases []types.DatabaseServer + var kubes []types.Server + + // Wait for everything to come online. + require.Eventually(t, func() bool { + var err error + nodes, err = authServer.GetNodes(ctx, tconf.SSH.Namespace) + require.NoError(t, err) + apps, err = authServer.GetApplicationServers(ctx, tconf.SSH.Namespace) + require.NoError(t, err) + databases, err = authServer.GetDatabaseServers(ctx, tconf.SSH.Namespace) + require.NoError(t, err) + kubes, err = authServer.GetKubeServices(ctx) + require.NoError(t, err) + return len(nodes) == 1 && len(apps) == 1 && len(databases) == 1 && len(kubes) == 1 + }, 10*time.Second, time.Second) + + tagName := fmt.Sprintf("%s/Name", ec2.AWSNamespace) + + // Check that EC2 labels were applied. + require.Eventually(t, func() bool { + node, err := authServer.GetNode(ctx, tconf.SSH.Namespace, nodes[0].GetName()) + require.NoError(t, err) + _, nodeHasLabel := node.GetAllLabels()[tagName] + apps, err := authServer.GetApplicationServers(ctx, tconf.SSH.Namespace) + require.NoError(t, err) + require.Len(t, apps, 1) + app := apps[0].GetApp() + _, appHasLabel := app.GetAllLabels()[tagName] + + databases, err := authServer.GetDatabaseServers(ctx, tconf.SSH.Namespace) + require.NoError(t, err) + require.Len(t, databases, 1) + database := databases[0].GetDatabase() + _, dbHasLabel := database.GetAllLabels()[tagName] + + kubeClusters := getKubeClusters(t, authServer) + require.Len(t, kubeClusters, 1) + kube := kubeClusters[0] + _, kubeHasLabel := kube.StaticLabels[tagName] + return nodeHasLabel && appHasLabel && dbHasLabel && kubeHasLabel + }, 10*time.Second, time.Second) +} + +// TestEC2Hostname is an integration test which asserts that Teleport sets its +// hostname if the EC2 tag `TeleportHostname` is available. This test must be +// run on an instance with tag `TeleportHostname=fakehost.example.com`. +func TestEC2Hostname(t *testing.T) { + teleportHostname := "fakehost.example.com" + + storageConfig := backend.Config{ + Type: lite.GetName(), + Params: backend.Params{ + "path": t.TempDir(), + "poll_stream_period": 50 * time.Millisecond, + }, + } + tconf := service.MakeDefaultConfig() + tconf.Log = newSilentLogger() + tconf.DataDir = t.TempDir() + tconf.Auth.Enabled = true + tconf.Proxy.Enabled = true + tconf.Proxy.DisableWebInterface = true + tconf.Auth.StorageConfig = storageConfig + tconf.Auth.SSHAddr.Addr = net.JoinHostPort(Host, ports.Pop()) + tconf.AuthServers = append(tconf.AuthServers, tconf.Auth.SSHAddr) + + tconf.SSH.Enabled = true + tconf.SSH.Addr.Addr = net.JoinHostPort(Host, ports.Pop()) + + imClient := &mockIMDSClient{ + tags: map[string]string{ + types.EC2HostnameTag: teleportHostname, + }, + } + + proc, err := service.NewTeleport(tconf, service.WithIMDSClient(imClient)) + require.NoError(t, err) + require.NoError(t, proc.Start()) + t.Cleanup(func() { require.NoError(t, proc.Close()) }) + + ctx := context.Background() + authServer := proc.GetAuthServer() + var node types.Server + require.Eventually(t, func() bool { + nodes, err := authServer.GetNodes(ctx, tconf.SSH.Namespace) + require.NoError(t, err) + if len(nodes) == 1 { + node = nodes[0] + return true + } + return false + }, 10*time.Second, time.Second) + + require.Equal(t, teleportHostname, node.GetHostname()) +} diff --git a/integration/helpers.go b/integration/helpers.go index 2e8f388e194ad..2b73d8d24b95a 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -41,6 +41,7 @@ import ( "github.com/stretchr/testify/require" + apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" @@ -1786,7 +1787,7 @@ func enableKubernetesService(t *testing.T, config *service.Config) { err = kubeconfig.Update(kubeConfigPath, kubeconfig.Values{ TeleportClusterName: "teleport-cluster", - ClusterAddr: net.JoinHostPort(Host, ports.Pop()), + ClusterAddr: "https://" + net.JoinHostPort(Host, ports.Pop()), Credentials: key, }) require.NoError(t, err) @@ -1796,6 +1797,23 @@ func enableKubernetesService(t *testing.T, config *service.Config) { config.Kube.ListenAddr = utils.MustParseAddr(net.JoinHostPort(Host, ports.Pop())) } +// getKubeClusters gets all kubernetes clusters accessible from a given auth server. +func getKubeClusters(t *testing.T, as *auth.Server) []*types.KubernetesCluster { + ctx := context.Background() + resources, err := apiclient.GetResourcesWithFilters(ctx, as, proto.ListResourcesRequest{ + ResourceType: types.KindKubeService, + }) + require.NoError(t, err) + kss, err := types.ResourcesWithLabels(resources).AsServers() + require.NoError(t, err) + + clusters := make([]*types.KubernetesCluster, 0) + for _, ks := range kss { + clusters = append(clusters, ks.GetKubernetesClusters()...) + } + return clusters +} + func genUserKey() (*client.Key, error) { caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{ CommonName: "localhost", diff --git a/integration/utmp_integration_test.go b/integration/utmp_integration_test.go index d912a589ae4a0..f679d4c4d5e28 100644 --- a/integration/utmp_integration_test.go +++ b/integration/utmp_integration_test.go @@ -277,7 +277,7 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx { Period: types.NewDuration(time.Millisecond), Command: []string{"expr", "1", "+", "3"}, }, - }, + }, nil, ), regular.SetBPF(&bpf.NOP{}), regular.SetRestrictedSessionManager(&restricted.NOP{}), diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go index 25bb84877eae6..6a26d3d478eec 100644 --- a/lib/cloud/aws/errors.go +++ b/lib/cloud/aws/errors.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "errors" "net/http" "github.com/aws/aws-sdk-go/aws/awserr" @@ -43,3 +44,12 @@ func ConvertRequestFailureError(err error) error { return err // Return unmodified. } + +// ParseMetadataClientError converts a failed instance metadata service call to a trace error. +func ParseMetadataClientError(err error) error { + var httpError interface{ HTTPStatusCode() int } + if errors.As(err, &httpError) { + return trace.ReadError(httpError.HTTPStatusCode(), nil) + } + return trace.Wrap(err) +} diff --git a/lib/cloud/aws/imds.go b/lib/cloud/aws/imds.go new file mode 100644 index 0000000000000..442003aec8ce9 --- /dev/null +++ b/lib/cloud/aws/imds.go @@ -0,0 +1,30 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import "context" + +// InstanceMetadata is an interface for fetching information from EC2 instance +// metadata. +type InstanceMetadata interface { + // IsAvailable checks if instance metadata is available. + IsAvailable(ctx context.Context) bool + // GetTagKeys gets all of the EC2 tag keys. + GetTagKeys(ctx context.Context) ([]string, error) + // GetTagValue gets the value for a specified tag key. + GetTagValue(ctx context.Context, key string) (string, error) +} diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index f5a57d7bf5168..af091d58a8eb7 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -141,6 +141,9 @@ type ForwarderConfig struct { // DynamicLabels is map of dynamic labels associated with this cluster. // Used for RBAC. DynamicLabels *labels.Dynamic + // CloudLabels is a map of labels imported from a cloud provider associated with this + // cluster. Used for RBAC. + CloudLabels labels.Importer // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher // CheckImpersonationPermissions is an optional override of the default @@ -1980,6 +1983,20 @@ func (f *Forwarder) requestCertificate(ctx authContext) (*tls.Config, error) { return tlsConfig, nil } +// getStaticLabels gets the labels that the forwarder should present as static, +// which includes EC2 labels if available. +func (f *Forwarder) getStaticLabels() map[string]string { + if f.cfg.CloudLabels == nil { + return f.cfg.StaticLabels + } + labels := f.cfg.CloudLabels.Get() + // Let static labels override ec2 labels. + for k, v := range f.cfg.StaticLabels { + labels[k] = v + } + return labels +} + func (f *Forwarder) kubeClusters() []*types.KubernetesCluster { var dynLabels map[string]types.CommandLabelV2 if f.cfg.DynamicLabels != nil { @@ -1990,7 +2007,7 @@ func (f *Forwarder) kubeClusters() []*types.KubernetesCluster { for n := range f.creds { res = append(res, &types.KubernetesCluster{ Name: n, - StaticLabels: f.cfg.StaticLabels, + StaticLabels: f.getStaticLabels(), DynamicLabels: dynLabels, }) } diff --git a/lib/labels/ec2/ec2.go b/lib/labels/ec2/ec2.go new file mode 100644 index 0000000000000..f0ecf9df0c51e --- /dev/null +++ b/lib/labels/ec2/ec2.go @@ -0,0 +1,154 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ec2 + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" +) + +const ( + // AWSNamespace is used as the namespace prefix for any labels + // imported from AWS. + AWSNamespace = "aws" + // ec2LabelUpdatePeriod is the period for updating EC2 labels. + ec2LabelUpdatePeriod = time.Hour +) + +// Config is the configuration for the EC2 label service. +type Config struct { + Client aws.InstanceMetadata + Clock clockwork.Clock + Log logrus.FieldLogger +} + +func (conf *Config) checkAndSetDefaults(ctx context.Context) error { + if conf.Client == nil { + client, err := utils.NewInstanceMetadataClient(ctx) + if err != nil { + return trace.Wrap(err) + } + conf.Client = client + } + if conf.Clock == nil { + conf.Clock = clockwork.NewRealClock() + } + if conf.Log == nil { + conf.Log = logrus.WithField(trace.Component, "ec2labels") + } + return nil +} + +// EC2 is a service that periodically imports tags from EC2 via instance +// metadata. +type EC2 struct { + c *Config + mu sync.RWMutex + labels map[string]string + + closeCh chan struct{} +} + +func New(ctx context.Context, c *Config) (*EC2, error) { + if err := c.checkAndSetDefaults(ctx); err != nil { + return nil, trace.Wrap(err) + } + return &EC2{ + c: c, + labels: make(map[string]string), + closeCh: make(chan struct{}), + }, nil +} + +// Get returns the list of updated EC2 labels. +func (l *EC2) Get() map[string]string { + l.mu.RLock() + defer l.mu.RUnlock() + return l.labels +} + +// Apply adds EC2 labels to the provided resource. +func (l *EC2) Apply(r types.ResourceWithLabels) { + labels := l.Get() + for k, v := range r.GetStaticLabels() { + labels[k] = v + } + r.SetStaticLabels(labels) +} + +// Sync will block and synchronously update EC2 labels. +func (l *EC2) Sync(ctx context.Context) error { + m := make(map[string]string) + + tags, err := l.c.Client.GetTagKeys(ctx) + if err != nil { + return trace.Wrap(err) + } + + for _, t := range tags { + value, err := l.c.Client.GetTagValue(ctx, t) + if err != nil { + return trace.Wrap(err) + } + m[t] = value + } + + l.mu.Lock() + defer l.mu.Unlock() + l.labels = toAWSLabels(m) + + return nil +} + +// Start will start a loop that continually keeps EC2 labels updated. +func (l *EC2) Start(ctx context.Context) { + go l.periodicUpdateLabels(ctx) +} + +func (l *EC2) periodicUpdateLabels(ctx context.Context) { + ticker := l.c.Clock.NewTicker(ec2LabelUpdatePeriod) + defer ticker.Stop() + + for { + if err := l.Sync(ctx); err != nil { + l.c.Log.Errorf("Error fetching EC2 tags: %v", err) + } + select { + case <-ticker.Chan(): + case <-ctx.Done(): + return + } + } +} + +// toAWSLabels formats labels coming from EC2. +func toAWSLabels(labels map[string]string) map[string]string { + m := make(map[string]string, len(labels)) + for k, v := range labels { + m[fmt.Sprintf("%s/%s", AWSNamespace, k)] = v + } + return m +} diff --git a/lib/labels/ec2/ec2_test.go b/lib/labels/ec2/ec2_test.go new file mode 100644 index 0000000000000..a86a302b11bd6 --- /dev/null +++ b/lib/labels/ec2/ec2_test.go @@ -0,0 +1,107 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package ec2 + +import ( + "context" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +type mockIMDSClient struct { + tags map[string]string +} + +func (m *mockIMDSClient) IsAvailable(ctx context.Context) bool { + return true +} + +func (m *mockIMDSClient) GetTagKeys(ctx context.Context) ([]string, error) { + keys := make([]string, 0, len(m.tags)) + for k := range m.tags { + keys = append(keys, k) + } + return keys, nil +} + +func (m *mockIMDSClient) GetTagValue(ctx context.Context, key string) (string, error) { + if value, ok := m.tags[key]; ok { + return value, nil + } + return "", trace.NotFound("Tag %q not found", key) +} + +func TestEC2LabelsSync(t *testing.T) { + ctx := context.Background() + tags := map[string]string{"a": "1", "b": "2"} + imdsClient := &mockIMDSClient{ + tags: tags, + } + ec2Labels, err := New(ctx, &Config{ + Client: imdsClient, + }) + require.NoError(t, err) + require.NoError(t, ec2Labels.Sync(ctx)) + require.Equal(t, toAWSLabels(tags), ec2Labels.Get()) +} + +func TestEC2LabelsAsync(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + imdsClient := &mockIMDSClient{} + clock := clockwork.NewFakeClock() + ec2Labels, err := New(ctx, &Config{ + Client: imdsClient, + Clock: clock, + }) + require.NoError(t, err) + + compareLabels := func(m map[string]string) func() bool { + return func() bool { + labels := ec2Labels.Get() + if len(labels) != len(m) { + return false + } + for k, v := range labels { + if m[k] != v { + return false + } + } + return true + } + } + + // Check that initial tags are read. + initialTags := map[string]string{"a": "1", "b": "2"} + imdsClient.tags = initialTags + ec2Labels.Start(ctx) + require.Eventually(t, compareLabels(toAWSLabels(initialTags)), time.Second, 100*time.Microsecond) + + // Check that tags are updated over time. + updatedTags := map[string]string{"a": "3", "c": "4"} + imdsClient.tags = updatedTags + clock.Advance(ec2LabelUpdatePeriod) + require.Eventually(t, compareLabels(toAWSLabels(updatedTags)), time.Second, 100*time.Millisecond) + + // Check that service stops updating when closed. + cancel() + imdsClient.tags = map[string]string{"x": "8", "y": "9", "z": "10"} + clock.Advance(ec2LabelUpdatePeriod) + require.Eventually(t, compareLabels(toAWSLabels(updatedTags)), time.Second, 100*time.Millisecond) +} diff --git a/lib/labels/labels.go b/lib/labels/labels.go index 3fce155e1e6cb..e93f3ac0560e7 100644 --- a/lib/labels/labels.go +++ b/lib/labels/labels.go @@ -166,3 +166,16 @@ func (l *Dynamic) setLabel(name string, value types.CommandLabel) { l.c.Labels[name] = value } + +// Importer is an interface for labels imported from an external source, +// such as a cloud provider. +type Importer interface { + // Get returns the current labels. + Get() map[string]string + // Apply adds the current labels to the provided resource's static labels. + Apply(r types.ResourceWithLabels) + // Sync blocks and synchronously updates the labels. + Sync(context.Context) error + // Start starts a loop that continually keeps the labels updated. + Start(context.Context) +} diff --git a/lib/labels/labels_test.go b/lib/labels/labels_test.go index 509c76cdb4418..e3097dd14629d 100644 --- a/lib/labels/labels_test.go +++ b/lib/labels/labels_test.go @@ -19,15 +19,14 @@ package labels import ( "context" "os" - "strings" "testing" "time" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/utils" + "github.com/stretchr/testify/require" "github.com/google/uuid" - "gopkg.in/check.v1" ) func TestMain(m *testing.M) { @@ -35,14 +34,7 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -type LabelSuite struct { -} - -var _ = check.Suite(&LabelSuite{}) - -func TestLabels(t *testing.T) { check.TestingT(t) } - -func (s *LabelSuite) TestSync(c *check.C) { +func TestSync(t *testing.T) { // Create dynamic labels and sync right away. l, err := NewDynamic(context.Background(), &DynamicConfig{ Labels: map[string]types.CommandLabel{ @@ -52,14 +44,14 @@ func (s *LabelSuite) TestSync(c *check.C) { }, }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) l.Sync() // Check that the result contains the output of the command. - c.Assert(l.Get()["foo"].GetResult(), check.Equals, "4") + require.Equal(t, "4", l.Get()["foo"].GetResult()) } -func (s *LabelSuite) TestStart(c *check.C) { +func TestStart(t *testing.T) { // Create dynamic labels and setup async update. l, err := NewDynamic(context.Background(), &DynamicConfig{ Labels: map[string]types.CommandLabel{ @@ -69,24 +61,18 @@ func (s *LabelSuite) TestStart(c *check.C) { }, }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) l.Start() - // Wait a maximum of 5 seconds for dynamic labels to be updated. - select { - case <-time.Tick(50 * time.Millisecond): + require.Eventually(t, func() bool { val, ok := l.Get()["foo"] - c.Assert(ok, check.Equals, true) - if val.GetResult() == "4" { - break - } - case <-time.After(5 * time.Second): - c.Fatalf("Timed out waiting for label to be updated.") - } + require.True(t, ok) + return val.GetResult() == "4" + }, 5*time.Second, 50*time.Millisecond) } // TestInvalidCommand makes sure that invalid commands return a error message. -func (s *LabelSuite) TestInvalidCommand(c *check.C) { +func TestInvalidCommand(t *testing.T) { // Create invalid labels and sync right away. l, err := NewDynamic(context.Background(), &DynamicConfig{ Labels: map[string]types.CommandLabel{ @@ -95,11 +81,11 @@ func (s *LabelSuite) TestInvalidCommand(c *check.C) { Command: []string{uuid.New().String()}}, }, }) - c.Assert(err, check.IsNil) + require.NoError(t, err) l.Sync() // Check that the output contains that the command was not found. val, ok := l.Get()["foo"] - c.Assert(ok, check.Equals, true) - c.Assert(strings.Contains(val.GetResult(), "output:"), check.Equals, true) + require.True(t, ok) + require.Contains(t, val.GetResult(), "output:") } diff --git a/lib/service/db.go b/lib/service/db.go index c99233c4b85e6..40ca551226142 100644 --- a/lib/service/db.go +++ b/lib/service/db.go @@ -204,6 +204,7 @@ func (process *TeleportProcess) initDatabaseService() (retErr error) { Hostname: process.Config.Hostname, HostID: process.Config.HostUUID, Databases: databases, + CloudLabels: process.cloudLabels, ResourceMatchers: process.Config.Databases.ResourceMatchers, AWSMatchers: process.Config.Databases.AWSMatchers, OnHeartbeat: process.onHeartbeat(teleport.ComponentDatabase), diff --git a/lib/service/kubernetes.go b/lib/service/kubernetes.go index 314b4325c6050..9cd2e69c7faf5 100644 --- a/lib/service/kubernetes.go +++ b/lib/service/kubernetes.go @@ -250,6 +250,7 @@ func (process *TeleportProcess) initKubernetesService(log *logrus.Entry, conn *C Component: teleport.ComponentKube, StaticLabels: cfg.Kube.StaticLabels, DynamicLabels: dynLabels, + CloudLabels: process.cloudLabels, LockWatcher: lockWatcher, CheckImpersonationPermissions: cfg.Kube.CheckImpersonationPermissions, PublicAddr: publicAddr, diff --git a/lib/service/service.go b/lib/service/service.go index 9ca42d089513a..217fa5adac444 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -60,6 +60,7 @@ import ( "github.com/gravitational/teleport/lib/backend/postgres" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/cache" + "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/dynamoevents" @@ -69,6 +70,8 @@ import ( "github.com/gravitational/teleport/lib/events/s3sessions" "github.com/gravitational/teleport/lib/joinserver" kubeproxy "github.com/gravitational/teleport/lib/kube/proxy" + "github.com/gravitational/teleport/lib/labels" + "github.com/gravitational/teleport/lib/labels/ec2" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/multiplexer" @@ -335,6 +338,9 @@ type TeleportProcess struct { // clusterFeatures contain flags for supported and unsupported features. clusterFeatures proto.Features + // cloudLabels is a set of labels imported from a cloud provider and shared between + // services. + cloudLabels labels.Importer // TracingProvider is the provider to be used for exporting traces. In the event // that tracing is disabled this will be a no-op provider that drops all spans. TracingProvider *tracing.Provider @@ -345,6 +351,20 @@ type keyPairKey struct { reason string } +// newTeleportConfig provides extra options to NewTeleport(). +type newTeleportConfig struct { + imdsClient aws.InstanceMetadata +} + +type NewTeleportOption func(*newTeleportConfig) + +// WithIMDSClient provides NewTeleport with a custom EC2 instance metadata client. +func WithIMDSClient(client aws.InstanceMetadata) NewTeleportOption { + return func(c *newTeleportConfig) { + c.imdsClient = client + } +} + // processIndex is an internal process index // to help differentiate between two different teleport processes // during in-process reload. @@ -617,7 +637,11 @@ func waitAndReload(ctx context.Context, cfg Config, srv Process, newTeleport New // NewTeleport takes the daemon configuration, instantiates all required services // and starts them under a supervisor, returning the supervisor object. -func NewTeleport(cfg *Config) (*TeleportProcess, error) { +func NewTeleport(cfg *Config, opts ...NewTeleportOption) (*TeleportProcess, error) { + newTeleportConf := &newTeleportConfig{} + for _, opt := range opts { + opt(newTeleportConf) + } var err error // Before we do anything reset the SIGINT handler back to the default. @@ -720,16 +744,6 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { cfg.AuthServers = []utils.NetAddr{cfg.Auth.SSHAddr} } - // if user did not provide auth domain name, use this host's name - if cfg.Auth.Enabled && cfg.Auth.ClusterName == nil { - cfg.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ - ClusterName: cfg.Hostname, - }) - if err != nil { - return nil, trace.Wrap(err) - } - } - processID := fmt.Sprintf("%v", nextProcessID()) supervisor := NewSupervisor(processID, cfg.Log) storage, err := auth.NewProcessStorage(supervisor.ExitContext(), filepath.Join(cfg.DataDir, teleport.ComponentProcess)) @@ -745,6 +759,52 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { cfg.PluginRegistry = plugin.NewRegistry() } + var cloudLabels labels.Importer + + // Check if we're on an EC2 instance, and if we should override the node's hostname. + imClient := newTeleportConf.imdsClient + if imClient == nil { + imClient, err = utils.NewInstanceMetadataClient(supervisor.ExitContext()) + if err != nil { + return nil, trace.Wrap(err) + } + } + + if imClient.IsAvailable(supervisor.ExitContext()) { + ec2Hostname, err := imClient.GetTagValue(supervisor.ExitContext(), types.EC2HostnameTag) + if err == nil { + if ec2Hostname != "" { + cfg.Log.Info("Found %q tag in EC2 instance. Using %q as hostname.", types.EC2HostnameTag, ec2Hostname) + cfg.Hostname = ec2Hostname + } + } else if !trace.IsNotFound(err) { + cfg.Log.Errorf("Unexpected error while looking for EC2 hostname: %v", err) + } + + ec2Labels, err := ec2.New(supervisor.ExitContext(), &ec2.Config{ + Client: imClient, + Clock: cfg.Clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } + cloudLabels = ec2Labels + } + + if cloudLabels != nil { + cloudLabels.Start(supervisor.ExitContext()) + } + + // if user did not provide auth domain name, use this host's name + if cfg.Auth.Enabled && cfg.Auth.ClusterName == nil { + cfg.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{ + ClusterName: cfg.Hostname, + }) + if err != nil { + return nil, trace.Wrap(err) + } + } + process := &TeleportProcess{ PluginRegistry: cfg.PluginRegistry, Clock: cfg.Clock, @@ -757,6 +817,7 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { id: processID, keyPairs: make(map[keyPairKey]KeyPair), appDependCh: make(chan Event, 1024), + cloudLabels: cloudLabels, TracingProvider: tracing.NoopProvider(), } @@ -1953,7 +2014,7 @@ func (process *TeleportProcess) initSSH() error { regular.SetShell(cfg.SSH.Shell), regular.SetEmitter(&events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer}), regular.SetSessionServer(conn.Client), - regular.SetLabels(cfg.SSH.Labels, cfg.SSH.CmdLabels), + regular.SetLabels(cfg.SSH.Labels, cfg.SSH.CmdLabels, process.cloudLabels), regular.SetNamespace(namespace), regular.SetPermitUserEnvironment(cfg.SSH.PermitUserEnvironment), regular.SetCiphers(cfg.Ciphers), @@ -3812,6 +3873,7 @@ func (process *TeleportProcess) initApps() { Hostname: process.Config.Hostname, GetRotation: process.getRotation, Apps: applications, + CloudLabels: process.cloudLabels, ResourceMatchers: process.Config.Apps.ResourceMatchers, OnHeartbeat: process.onHeartbeat(teleport.ComponentApp), }) diff --git a/lib/services/server.go b/lib/services/server.go index ca15aff44e016..e2d363bb38db8 100644 --- a/lib/services/server.go +++ b/lib/services/server.go @@ -91,7 +91,7 @@ func compareServers(a, b types.Server) int { if a.GetUseTunnel() != b.GetUseTunnel() { return Different } - if !utils.StringMapsEqual(a.GetLabels(), b.GetLabels()) { + if !utils.StringMapsEqual(a.GetStaticLabels(), b.GetStaticLabels()) { return Different } if !cmp.Equal(a.GetCmdLabels(), b.GetCmdLabels()) { diff --git a/lib/services/watcher.go b/lib/services/watcher.go index a8f048247ede4..5cfa40aa42c46 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -1198,8 +1198,6 @@ type Node interface { GetHostname() string // GetNamespace returns server namespace GetNamespace() string - // GetLabels returns server's static label key pairs - GetLabels() map[string]string // GetCmdLabels gets command labels GetCmdLabels() map[string]types.CommandLabel // GetPublicAddr is an optional field that returns the public address this cluster can be reached at. diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index 6b90c788c9247..78eceaf7684ec 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -87,6 +87,10 @@ type Config struct { // Apps is a list of statically registered apps this agent proxies. Apps types.Apps + // CloudLabels is a service that imports labels from a cloud provider. The labels are shared + // between all apps. + CloudLabels labels.Importer + // OnHeartbeat is called after every heartbeat. Used to update process state. OnHeartbeat func(error) @@ -400,6 +404,9 @@ func (s *Server) getServerInfo(app types.Application) (types.Resource, error) { if labels != nil { copy.SetDynamicLabels(labels.Get()) } + if s.c.CloudLabels != nil { + s.c.CloudLabels.Apply(copy) + } expires := s.c.Clock.Now().UTC().Add(apidefaults.ServerAnnounceTTL) return types.NewAppServerV3(types.Metadata{ Name: copy.GetName(), @@ -493,7 +500,6 @@ func (s *Server) Start(ctx context.Context) (err error) { if s.watcher, err = s.startResourceWatcher(ctx); err != nil { return trace.Wrap(err) } - return nil } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 3f40c9b992668..2aa5aecd5e283 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -85,6 +85,9 @@ type Config struct { AWSMatchers []services.AWSMatcher // Databases is a list of proxied databases from static configuration. Databases types.Databases + // CloudLabels is a service that imports labels from a cloud provider. The labels are shared + // between all databases. + CloudLabels labels.Importer // OnHeartbeat is called after every heartbeat. Used to update process state. OnHeartbeat func(error) // OnReconcile is called after each database resource reconciliation. @@ -529,6 +532,9 @@ func (s *Server) getServerInfo(database types.Database) (types.Resource, error) if labels != nil { copy.SetDynamicLabels(labels.Get()) } + if s.cfg.CloudLabels != nil { + s.cfg.CloudLabels.Apply(copy) + } expires := s.cfg.Clock.Now().UTC().Add(apidefaults.ServerAnnounceTTL) return types.NewDatabaseServerV3(types.Metadata{ Name: copy.GetName(), diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index ac72f99b4d7b2..f846381e20914 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -99,6 +99,9 @@ type Server struct { // dynamicLabels are the result of command execution. dynamicLabels *labels.Dynamic + // cloudLabels are the labels imported from a cloud provider. + cloudLabels labels.Importer + proxyMode bool proxyTun reversetunnel.Tunnel proxyAccessPoint auth.ReadProxyAccessPoint @@ -333,6 +336,10 @@ func (s *Server) Serve(l net.Listener) error { go s.dynamicLabels.Start() } + if s.cloudLabels != nil { + s.cloudLabels.Start(s.Context()) + } + go s.heartbeat.Run() return s.srv.Serve(l) } @@ -409,7 +416,7 @@ func SetProxyMode(tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint) Serve // SetLabels sets dynamic and static labels that server will report to the // auth servers. -func SetLabels(staticLabels map[string]string, cmdLabels services.CommandLabels) ServerOption { +func SetLabels(staticLabels map[string]string, cmdLabels services.CommandLabels, cloudLabels labels.Importer) ServerOption { return func(s *Server) error { var err error @@ -433,7 +440,7 @@ func SetLabels(staticLabels map[string]string, cmdLabels services.CommandLabels) if err != nil { return trace.Wrap(err) } - + s.cloudLabels = cloudLabels return nil } } @@ -805,6 +812,20 @@ func (s *Server) getRole() types.SystemRole { return types.RoleNode } +// getStaticLabels gets the labels that the server should present as static, +// which includes EC2 labels if available. +func (s *Server) getStaticLabels() map[string]string { + if s.cloudLabels == nil { + return s.labels + } + labels := s.cloudLabels.Get() + // Let static labels override ec2 labels if they conflict. + for k, v := range s.labels { + labels[k] = v + } + return labels +} + // getDynamicLabels returns all dynamic labels. If no dynamic labels are // defined, return an empty set. func (s *Server) getDynamicLabels() map[string]types.CommandLabelV2 { @@ -828,7 +849,7 @@ func (s *Server) GetInfo() types.Server { Metadata: types.Metadata{ Name: s.ID(), Namespace: s.getNamespace(), - Labels: s.labels, + Labels: s.getStaticLabels(), }, Spec: types.ServerSpecV2{ CmdLabels: s.getDynamicLabels(), diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 5bdc97e03546f..79f42ddbbc36e 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -172,7 +172,7 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO Period: types.NewDuration(time.Millisecond), Command: []string{"expr", "1", "+", "3"}, }, - }, + }, nil, ), SetBPF(&bpf.NOP{}), SetRestrictedSessionManager(&restricted.NOP{}), diff --git a/lib/teleterm/apiserver/handler/handler_servers.go b/lib/teleterm/apiserver/handler/handler_servers.go index 3a883c010dab6..aa8963addc5f9 100644 --- a/lib/teleterm/apiserver/handler/handler_servers.go +++ b/lib/teleterm/apiserver/handler/handler_servers.go @@ -41,7 +41,7 @@ func (s *Handler) ListServers(ctx context.Context, req *api.ListServersRequest) func newAPIServer(server clusters.Server) *api.Server { apiLabels := APILabels{} - serverLabels := server.GetLabels() + serverLabels := server.GetStaticLabels() for name, value := range serverLabels { apiLabels = append(apiLabels, &api.Label{ Name: name, diff --git a/lib/utils/ec2.go b/lib/utils/ec2.go index 666344b6e9041..5fb2af3633259 100644 --- a/lib/utils/ec2.go +++ b/lib/utils/ec2.go @@ -18,14 +18,20 @@ package utils import ( "context" + "fmt" "io" "regexp" + "strings" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/trace" ) +// metadataReadLimit is the largest number of bytes that will be read from imds responses. +const metadataReadLimit = 1_000_000 + // GetEC2IdentityDocument fetches the PKCS7 RSA2048 InstanceIdentityDocument // from the IMDS for this EC2 instance. func GetEC2IdentityDocument() ([]byte, error) { @@ -84,3 +90,57 @@ func IsEC2NodeID(id string) bool { func NodeIDFromIID(iid *imds.InstanceIdentityDocument) string { return iid.AccountID + "-" + iid.InstanceID } + +// InstanceMetadataClient is a wrapper for an imds.Client. +type InstanceMetadataClient struct { + c *imds.Client +} + +// NewInstanceMetadataClient creates a new instance metadata client. +func NewInstanceMetadataClient(ctx context.Context) (*InstanceMetadataClient, error) { + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return &InstanceMetadataClient{ + c: imds.NewFromConfig(cfg), + }, nil +} + +// IsAvailable checks if instance metadata is available. +func (client *InstanceMetadataClient) IsAvailable(ctx context.Context) bool { + _, err := client.getMetadata(ctx, "") + return err == nil +} + +// getMetadata gets the raw metadata from a specified path. +func (client *InstanceMetadataClient) getMetadata(ctx context.Context, path string) (string, error) { + output, err := client.c.GetMetadata(ctx, &imds.GetMetadataInput{Path: path}) + if err != nil { + return "", trace.Wrap(aws.ParseMetadataClientError(err)) + } + defer output.Content.Close() + body, err := ReadAtMost(output.Content, metadataReadLimit) + if err != nil { + return "", trace.Wrap(err) + } + return string(body), nil +} + +// GetTagKeys gets all of the EC2 tag keys. +func (client *InstanceMetadataClient) GetTagKeys(ctx context.Context) ([]string, error) { + body, err := client.getMetadata(ctx, "tags/instance") + if err != nil { + return nil, trace.Wrap(err) + } + return strings.Split(body, "\n"), nil +} + +// GetTagValue gets the value for a specified tag key. +func (client *InstanceMetadataClient) GetTagValue(ctx context.Context, key string) (string, error) { + body, err := client.getMetadata(ctx, fmt.Sprintf("tags/instance/%s", key)) + if err != nil { + return "", trace.Wrap(err) + } + return body, nil +} diff --git a/lib/web/ui/server.go b/lib/web/ui/server.go index 144f9c6aec2f3..ac8c714c15e71 100644 --- a/lib/web/ui/server.go +++ b/lib/web/ui/server.go @@ -70,7 +70,7 @@ func MakeServers(clusterName string, servers []types.Server) []Server { uiServers := []Server{} for _, server := range servers { uiLabels := []Label{} - serverLabels := server.GetLabels() + serverLabels := server.GetStaticLabels() for name, value := range serverLabels { uiLabels = append(uiLabels, Label{ Name: name,