Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions lib/services/unified_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,19 @@ func (c *UnifiedResourceCache) put(ctx context.Context, resource resource) error
c.mu.Lock()
defer c.mu.Unlock()
key := resourceKey(resource)
c.resources[key] = resource
sortKey := makeResourceSortKey(resource)
oldResource, exists := c.resources[key]
if exists {
// If the resource has changed in such a way that the sort keys
// for the nameTree or typeTree change, remove the old entries
// from those trees before adding a new one. This can happen
// when a node's hostname changes
oldSortKey := makeResourceSortKey(oldResource)
if string(oldSortKey.byName) != string(sortKey.byName) {
c.deleteSortKey(oldSortKey)
}
}
c.resources[key] = resource
c.nameTree.ReplaceOrInsert(&item{Key: sortKey.byName, Value: key})
c.typeTree.ReplaceOrInsert(&item{Key: sortKey.byType, Value: key})
return nil
Expand All @@ -144,6 +155,16 @@ func putResources[T resource](cache *UnifiedResourceCache, resources []T) {
}
}

func (c *UnifiedResourceCache) deleteSortKey(sortKey resourceSortKey) error {
if _, ok := c.nameTree.Delete(&item{Key: sortKey.byName}); !ok {
return trace.NotFound("key %q is not found in unified cache name sort tree", string(sortKey.byName))
}
if _, ok := c.typeTree.Delete(&item{Key: sortKey.byType}); !ok {
return trace.NotFound("key %q is not found in unified cache type sort tree", string(sortKey.byType))
}
return nil
}

// delete removes the item by key, returns NotFound error
// if item does not exist
func (c *UnifiedResourceCache) delete(ctx context.Context, res types.Resource) error {
Expand All @@ -159,12 +180,7 @@ func (c *UnifiedResourceCache) delete(ctx context.Context, res types.Resource) e
sortKey := makeResourceSortKey(resource)

return c.read(ctx, func(cache *UnifiedResourceCache) error {
if _, ok := cache.nameTree.Delete(&item{Key: sortKey.byName}); !ok {
return trace.NotFound("key %q is not found in unified cache name sort tree", string(sortKey.byName))
}
if _, ok := cache.typeTree.Delete(&item{Key: sortKey.byType}); !ok {
return trace.NotFound("key %q is not found in unified cache type sort tree", string(sortKey.byType))
}
cache.deleteSortKey(sortKey)
// delete from resource map
delete(c.resources, key)
return nil
Expand Down
62 changes: 59 additions & 3 deletions lib/services/unified_resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestUnifiedResourceWatcher(t *testing.T) {
Events: local.NewEventsService(bk),
}
// Add node to the backend.
node := newNodeServer(t, "node1", "127.0.0.1:22", false /*tunnel*/)
node := newNodeServer(t, "node1", "hostname1", "127.0.0.1:22", false /*tunnel*/)
_, err = clt.UpsertNode(ctx, node)
require.NoError(t, err)

Expand Down Expand Up @@ -152,7 +152,7 @@ func TestUnifiedResourceWatcher(t *testing.T) {
))

// // Update and remove some resources.
nodeUpdated := newNodeServer(t, "node1", "192.168.0.1:22", false /*tunnel*/)
nodeUpdated := newNodeServer(t, "node1", "hostname1", "192.168.0.1:22", false /*tunnel*/)
_, err = clt.UpsertNode(ctx, nodeUpdated)
require.NoError(t, err)
err = clt.DeleteApplicationServer(ctx, defaults.Namespace, "app1-host-id", "app1")
Expand Down Expand Up @@ -180,6 +180,62 @@ func TestUnifiedResourceWatcher(t *testing.T) {
))
}

func TestUnifiedResourceWatcher_PreventDuplicates(t *testing.T) {
t.Parallel()

ctx := context.Background()

bk, err := memory.New(memory.Config{})
require.NoError(t, err)

type client struct {
services.Presence
services.WindowsDesktops
services.SAMLIdPServiceProviders
types.Events
}

samlService, err := local.NewSAMLIdPServiceProviderService(bk)
require.NoError(t, err)

clt := &client{
Presence: local.NewPresenceService(bk),
WindowsDesktops: local.NewWindowsDesktopService(bk),
SAMLIdPServiceProviders: samlService,
Events: local.NewEventsService(bk),
}
w, err := services.NewUnifiedResourceCache(ctx, services.UnifiedResourceCacheConfig{
ResourceWatcherConfig: services.ResourceWatcherConfig{
Component: teleport.ComponentUnifiedResource,
Client: clt,
},
ResourceGetter: clt,
})
require.NoError(t, err)

// add a node
node := newNodeServer(t, "node1", "hostname1", "127.0.0.1:22", false /*tunnel*/)
_, err = clt.UpsertNode(ctx, node)
require.NoError(t, err)

assert.Eventually(t, func() bool {
res, _ := w.GetUnifiedResources(ctx)
return len(res) == 1
}, 5*time.Second, 10*time.Millisecond, "Timed out waiting for unified resources to be added")

// update a node
updatedNode := newNodeServer(t, "node1", "hostname2", "127.0.0.1:22", false /*tunnel*/)
_, err = clt.UpsertNode(ctx, updatedNode)
require.NoError(t, err)

// only one resource should still exists with the name "node1" (with hostname updated)
assert.Eventually(t, func() bool {
res, _ := w.GetUnifiedResources(ctx)
return len(res) == 1
}, 5*time.Second, 10*time.Millisecond, "Timed out waiting for unified resources to be added")

}

func TestUnifiedResourceWatcher_DeleteEvent(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -214,7 +270,7 @@ func TestUnifiedResourceWatcher_DeleteEvent(t *testing.T) {
require.NoError(t, err)

// add a node
node := newNodeServer(t, "node1", "127.0.0.1:22", false /*tunnel*/)
node := newNodeServer(t, "node1", "hostname1", "127.0.0.1:22", false /*tunnel*/)
_, err = clt.UpsertNode(ctx, node)
require.NoError(t, err)

Expand Down
7 changes: 4 additions & 3 deletions lib/services/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ func TestNodeWatcherFallback(t *testing.T) {
// Add some servers.
nodes := make([]types.Server, 0, 5)
for i := 0; i < 5; i++ {
node := newNodeServer(t, fmt.Sprintf("node%d", i), "127.0.0.1:2023", i%2 == 0)
node := newNodeServer(t, fmt.Sprintf("node%d", i), fmt.Sprintf("hostname%d", i), "127.0.0.1:2023", i%2 == 0)
_, err = presence.UpsertNode(ctx, node)
require.NoError(t, err)
nodes = append(nodes, node)
Expand Down Expand Up @@ -962,7 +962,7 @@ func TestNodeWatcher(t *testing.T) {
// Add some node servers.
nodes := make([]types.Server, 0, 5)
for i := 0; i < 5; i++ {
node := newNodeServer(t, fmt.Sprintf("node%d", i), "127.0.0.1:2023", i%2 == 0)
node := newNodeServer(t, fmt.Sprintf("node%d", i), fmt.Sprintf("hostname%d", i), "127.0.0.1:2023", i%2 == 0)
_, err = presence.UpsertNode(ctx, node)
require.NoError(t, err)
nodes = append(nodes, node)
Expand All @@ -989,10 +989,11 @@ func TestNodeWatcher(t *testing.T) {
require.Empty(t, w.GetNodes(ctx, func(n services.Node) bool { return n.GetName() == nodes[0].GetName() }))
}

func newNodeServer(t *testing.T, name, addr string, tunnel bool) types.Server {
func newNodeServer(t *testing.T, name, hostname, addr string, tunnel bool) types.Server {
s, err := types.NewServer(name, types.KindNode, types.ServerSpecV2{
Addr: addr,
UseTunnel: tunnel,
Hostname: hostname,
})
require.NoError(t, err)
return s
Expand Down