diff --git a/api/types/server_info.go b/api/types/server_info.go index 475aafc309f2b..9fb2356ac0cb4 100644 --- a/api/types/server_info.go +++ b/api/types/server_info.go @@ -166,6 +166,7 @@ func (s *ServerInfoV1) SetNewLabels(labels map[string]string) { func (s *ServerInfoV1) setStaticFields() { s.Kind = KindServerInfo s.Version = V1 + s.SubKind = SubKindCloudInfo } // CheckAndSetDefaults validates the Resource and sets any empty fields to @@ -175,8 +176,14 @@ func (s *ServerInfoV1) CheckAndSetDefaults() error { return trace.Wrap(s.Metadata.CheckAndSetDefaults()) } -// GetServerInfoName gets the name of the ServerInfo generated for a discovered -// EC2 instance with this account ID and instance ID. -func (a *AWSInfo) GetServerInfoName() string { - return fmt.Sprintf("aws-%v-%v", a.AccountID, a.InstanceID) +// ServerInfoNameFromAWS gets the name of the ServerInfo that matches the node +// with the given AWS account ID and instance ID. +func ServerInfoNameFromAWS(accountID, instanceID string) string { + return fmt.Sprintf("aws-%v-%v", accountID, instanceID) +} + +// ServerInfoNameFromNodeName gets the name of the ServerInfo that matches the +// node with the given name. +func ServerInfoNameFromNodeName(name string) string { + return fmt.Sprintf("si-%v", name) } diff --git a/lib/auth/server_info.go b/lib/auth/server_info.go index fc58f291caacf..7fdfb08ff13c1 100644 --- a/lib/auth/server_info.go +++ b/lib/auth/server_info.go @@ -18,6 +18,7 @@ package auth import ( "context" + "maps" "time" "github.com/gravitational/trace" @@ -44,7 +45,7 @@ func (a *Server) ReconcileServerInfos(ctx context.Context) error { for moreNodes := true; moreNodes; { nodes, moreNodes = stream.Take(nodeStream, batchSize) - updates, err := a.setCloudLabelsOnNodes(ctx, nodes) + updates, err := a.setLabelsOnNodes(ctx, nodes) if err != nil { return trace.Wrap(err) } @@ -70,32 +71,57 @@ func (a *Server) ReconcileServerInfos(ctx context.Context) error { } } -func (a *Server) setCloudLabelsOnNodes(ctx context.Context, nodes []types.Server) (failedUpdates int, err error) { +// getServerInfoNames gets the names of ServerInfos that could exist for a +// node. The list of names returned are ordered such that later ServerInfos +// override earlier ones on conflicting labels. +func getServerInfoNames(node types.Server) []string { + var names []string + if meta := node.GetCloudMetadata(); meta != nil && meta.AWS != nil { + names = append(names, types.ServerInfoNameFromAWS(meta.AWS.AccountID, meta.AWS.InstanceID)) + } + // ServerInfos matched by node name should override any ServerInfos created + // by the discovery service. + return append(names, types.ServerInfoNameFromNodeName(node.GetName())) +} + +func (a *Server) setLabelsOnNodes(ctx context.Context, nodes []types.Server) (failedUpdates int, err error) { for _, node := range nodes { - meta := node.GetCloudMetadata() - if meta != nil && meta.AWS != nil { - si, err := a.GetServerInfo(ctx, meta.AWS.GetServerInfoName()) + // Get the server infos that match this node. + serverInfoNames := getServerInfoNames(node) + serverInfos := make([]types.ServerInfo, 0, len(serverInfoNames)) + for _, name := range serverInfoNames { + si, err := a.GetServerInfo(ctx, name) if err == nil { - err := a.updateLabelsOnNode(ctx, node, si) - // Didn't find control stream for node, save count for logging. - if trace.IsNotFound(err) { - failedUpdates++ - } else if err != nil { - return failedUpdates, trace.Wrap(err) - } + serverInfos = append(serverInfos, si) } else if !trace.IsNotFound(err) { return failedUpdates, trace.Wrap(err) } } + if len(serverInfos) == 0 { + continue + } + + // Didn't find control stream for node, save count for logging. + if err := a.updateLabelsOnNode(ctx, node, serverInfos); trace.IsNotFound(err) { + failedUpdates++ + } else if err != nil { + return failedUpdates, trace.Wrap(err) + } } return failedUpdates, nil } -func (a *Server) updateLabelsOnNode(ctx context.Context, node types.Server, si types.ServerInfo) error { +func (a *Server) updateLabelsOnNode(ctx context.Context, node types.Server, serverInfos []types.ServerInfo) error { + // Merge labels from server infos. Later label sets should override earlier + // ones if they conflict. + newLabels := make(map[string]string) + for _, si := range serverInfos { + maps.Copy(newLabels, si.GetNewLabels()) + } err := a.UpdateLabels(ctx, proto.InventoryUpdateLabelsRequest{ ServerID: node.GetName(), Kind: proto.LabelUpdateKind_SSHServerCloudLabels, - Labels: si.GetStaticLabels(), + Labels: newLabels, }) return trace.Wrap(err) } diff --git a/lib/auth/server_info_test.go b/lib/auth/server_info_test.go index 90249e8486458..61bd810fe7678 100644 --- a/lib/auth/server_info_test.go +++ b/lib/auth/server_info_test.go @@ -94,18 +94,25 @@ func TestReconcileLabels(t *testing.T) { require.NoError(t, err) // Update the server's labels. - labels := map[string]string{"a": "1", "b": "2"} - serverInfo, err := types.NewServerInfo(types.Metadata{ - Name: "aws-my-account-my-instance", - Labels: labels, - }, types.ServerInfoSpecV1{}) + awsServerInfo, err := types.NewServerInfo(types.Metadata{ + Name: types.ServerInfoNameFromAWS("my-account", "my-instance"), + }, types.ServerInfoSpecV1{ + NewLabels: map[string]string{"a": "1", "b": "2"}, + }) + require.NoError(t, err) + require.NoError(t, pack.a.UpsertServerInfo(ctx, awsServerInfo)) + + regularServerInfo, err := types.NewServerInfo(types.Metadata{ + Name: types.ServerInfoNameFromNodeName(serverName), + }, types.ServerInfoSpecV1{ + NewLabels: map[string]string{"b": "3", "c": "4"}, + }) require.NoError(t, err) - serverInfo.SetSubKind(types.SubKindCloudInfo) - require.NoError(t, pack.a.UpsertServerInfo(ctx, serverInfo)) + require.NoError(t, pack.a.UpsertServerInfo(ctx, regularServerInfo)) go pack.a.ReconcileServerInfos(ctx) // Wait until the reconciler finishes processing the serverinfo. clock.BlockUntil(1) // Check that labels were received downstream. - require.Equal(t, labels, upstream.updatedLabels) + require.Equal(t, map[string]string{"a": "1", "b": "3", "c": "4"}, upstream.updatedLabels) } diff --git a/lib/services/resource.go b/lib/services/resource.go index 4044d1a188944..aa95b04614a5b 100644 --- a/lib/services/resource.go +++ b/lib/services/resource.go @@ -218,6 +218,8 @@ func ParseShortcut(in string) (string, error) { return types.KindAuditQuery, nil case types.KindSecurityReport: return types.KindSecurityReport, nil + case types.KindServerInfo: + return types.KindServerInfo, nil } return "", trace.BadParameter("unsupported resource: %q - resources should be expressed as 'type/name', for example 'connector/github'", in) } diff --git a/lib/srv/server/ec2_watcher.go b/lib/srv/server/ec2_watcher.go index 4f76423bcef1c..e8f8c7c4061fa 100644 --- a/lib/srv/server/ec2_watcher.go +++ b/lib/srv/server/ec2_watcher.go @@ -96,14 +96,13 @@ func ToEC2Instances(insts []*ec2.Instance) []EC2Instance { func (i *EC2Instances) ServerInfos() ([]types.ServerInfo, error) { serverInfos := make([]types.ServerInfo, 0, len(i.Instances)) for _, instance := range i.Instances { - name := i.AccountID + "-" + instance.InstanceID tags := make(map[string]string, len(instance.Tags)) for k, v := range instance.Tags { tags[labels.FormatCloudLabelKey(labels.AWSLabelNamespace, k)] = v } si, err := types.NewServerInfo(types.Metadata{ - Name: name, + Name: types.ServerInfoNameFromAWS(i.AccountID, instance.InstanceID), }, types.ServerInfoSpecV1{ NewLabels: tags, }) diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go index d2ebe8e4eb0ef..24ebd6b7168fb 100644 --- a/lib/srv/server/ec2_watcher_test.go +++ b/lib/srv/server/ec2_watcher_test.go @@ -238,7 +238,7 @@ func TestEC2Watcher(t *testing.T) { func TestConvertEC2InstancesToServerInfos(t *testing.T) { t.Parallel() expected, err := types.NewServerInfo(types.Metadata{ - Name: "myaccount-myinstance", + Name: "aws-myaccount-myinstance", }, types.ServerInfoSpecV1{ NewLabels: map[string]string{"aws/foo": "bar"}, }) diff --git a/tool/tctl/common/collection.go b/tool/tctl/common/collection.go index 15c4531943c14..9eff80f37ee5b 100644 --- a/tool/tctl/common/collection.go +++ b/tool/tctl/common/collection.go @@ -1297,3 +1297,24 @@ func (c *securityReportCollection) writeText(w io.Writer, verbose bool) error { _, err := t.AsBuffer().WriteTo(w) return trace.Wrap(err) } + +type serverInfoCollection struct { + serverInfos []types.ServerInfo +} + +func (c *serverInfoCollection) resources() []types.Resource { + r := make([]types.Resource, len(c.serverInfos)) + for i, resource := range c.serverInfos { + r[i] = resource + } + return r +} + +func (c *serverInfoCollection) writeText(w io.Writer, verbose bool) error { + t := asciitable.MakeTable([]string{"Name", "Labels"}) + for _, si := range c.serverInfos { + t.AddRow([]string{si.GetName(), printMetadataLabels(si.GetNewLabels())}) + } + _, err := t.AsBuffer().WriteTo(w) + return trace.Wrap(err) +} diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go index 6439217791eb1..96f43fa33e54d 100644 --- a/tool/tctl/common/resource_command.go +++ b/tool/tctl/common/resource_command.go @@ -40,6 +40,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" loginrulepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/loginrule/v1" + "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/discoveryconfig" "github.com/gravitational/teleport/api/types/externalcloudaudit" @@ -138,6 +139,7 @@ func (rc *ResourceCommand) Initialize(app *kingpin.Application, config *servicec types.KindDiscoveryConfig: rc.createDiscoveryConfig, types.KindAuditQuery: rc.createAuditQuery, types.KindSecurityReport: rc.createSecurityReport, + types.KindServerInfo: rc.createServerInfo, } rc.UpdateHandlers = map[ResourceKind]ResourceCreateHandler{ types.KindUser: rc.updateUser, @@ -1121,6 +1123,34 @@ func (rc *ResourceCommand) createAccessList(ctx context.Context, client auth.Cli return nil } +func (rc *ResourceCommand) createServerInfo(ctx context.Context, client auth.ClientI, raw services.UnknownResource) error { + si, err := services.UnmarshalServerInfo(raw.Raw) + if err != nil { + return trace.Wrap(err) + } + + // Check if the ServerInfo already exists. + name := si.GetName() + _, err = client.GetServerInfo(ctx, name) + if err != nil && !trace.IsNotFound(err) { + return trace.Wrap(err) + } + + exists := (err == nil) + if !rc.force && exists { + return trace.AlreadyExists("server info %q already exists", name) + } + + err = client.UpsertServerInfo(ctx, si) + if err != nil { + return trace.Wrap(err) + } + fmt.Printf("Server info %q has been %s\n", + name, UpsertVerb(exists, rc.force), + ) + return nil +} + // Delete deletes resource by name func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err error) { singletonResources := []string{ @@ -1474,7 +1504,11 @@ func (rc *ResourceCommand) Delete(ctx context.Context, client auth.ClientI) (err return trace.Wrap(err) } fmt.Printf("Security report %q has been deleted\n", rc.ref.Name) - + case types.KindServerInfo: + if err := client.DeleteServerInfo(ctx, rc.ref.Name); err != nil { + return trace.Wrap(err) + } + fmt.Printf("Server info %q has been deleted\n", rc.ref.Name) default: return trace.BadParameter("deleting resources of type %q is not supported", rc.ref.Kind) } @@ -2307,6 +2341,19 @@ func (rc *ResourceCommand) getCollection(ctx context.Context, client auth.Client return nil, trace.Wrap(err) } return &securityReportCollection{items: resources}, nil + case types.KindServerInfo: + if rc.ref.Name != "" { + si, err := client.GetServerInfo(ctx, rc.ref.Name) + if err != nil { + return nil, trace.Wrap(err) + } + return &serverInfoCollection{serverInfos: []types.ServerInfo{si}}, nil + } + serverInfos, err := stream.Collect(client.GetServerInfos(ctx)) + if err != nil { + return nil, trace.Wrap(err) + } + return &serverInfoCollection{serverInfos: serverInfos}, nil } return nil, trace.BadParameter("getting %q is not supported", rc.ref.String()) } diff --git a/tool/tctl/common/resource_command_test.go b/tool/tctl/common/resource_command_test.go index 1ec89e89c474c..c2d3c45079e1b 100644 --- a/tool/tctl/common/resource_command_test.go +++ b/tool/tctl/common/resource_command_test.go @@ -1360,6 +1360,10 @@ func TestCreateResources(t *testing.T) { kind: types.KindRole, create: testCreateRole, }, + { + kind: types.KindServerInfo, + create: testCreateServerInfo, + }, } for _, test := range tests { @@ -1683,3 +1687,58 @@ version: v7 _, err = runResourceCommand(t, fc, []string{"create", "-f", roleYAMLPath}) require.NoError(t, err) } + +func testCreateServerInfo(t *testing.T, fc *config.FileConfig) { + // Ensure that our test server info does not exist + _, err := runResourceCommand(t, fc, []string{"get", types.KindServerInfo + "/test-server-info", "--format=json"}) + require.True(t, trace.IsNotFound(err), "expected test-role to not exist prior to being created") + + const serverInfoYAML = `--- +kind: server_info +sub_kind: cloud_info +version: v1 +metadata: + name: test-server-info +spec: + new_labels: + 'a': '1' + 'b': '2' +` + + // Create the server info + serverInfoYAMLPath := filepath.Join(t.TempDir(), "server-info.yaml") + err = os.WriteFile(serverInfoYAMLPath, []byte(serverInfoYAML), 0644) + require.NoError(t, err) + _, err = runResourceCommand(t, fc, []string{"create", serverInfoYAMLPath}) + require.NoError(t, err) + + // Fetch the server info + buf, err := runResourceCommand(t, fc, []string{"get", types.KindServerInfo + "/test-server-info", "--format=json"}) + require.NoError(t, err) + serverInfos := mustDecodeJSON[[]*types.ServerInfoV1](t, buf) + require.Len(t, serverInfos, 1) + + var expected types.ServerInfoV1 + err = yaml.Unmarshal([]byte(serverInfoYAML), &expected) + require.NoError(t, err) + + require.Empty(t, cmp.Diff( + []*types.ServerInfoV1{&expected}, + serverInfos, + cmpopts.IgnoreFields(types.Metadata{}, "ID", "Revision"), + )) + + // Explicitly change the revision and try creating the resource with and without + // the force flag. + expected.SetRevision(uuid.NewString()) + newRevisionServerInfo, err := services.MarshalServerInfo(&expected, services.PreserveResourceID()) + require.NoError(t, err) + err = os.WriteFile(serverInfoYAMLPath, newRevisionServerInfo, 0644) + require.NoError(t, err) + + _, err = runResourceCommand(t, fc, []string{"create", serverInfoYAMLPath}) + require.True(t, trace.IsAlreadyExists(err)) + + _, err = runResourceCommand(t, fc, []string{"create", "-f", serverInfoYAMLPath}) + require.NoError(t, err) +}