diff --git a/lib/ai/embedding.go b/lib/ai/embedding.go index 922ae31bf9136..4e2d6d5e37938 100644 --- a/lib/ai/embedding.go +++ b/lib/ai/embedding.go @@ -30,7 +30,7 @@ const ( maxOpenAIEmbeddingsPerRequest = 1000 // EmbeddingPeriod is the time between two embedding routines. // A seventh jitter is applied on the period. - EmbeddingPeriod = 15 * time.Minute + EmbeddingPeriod = time.Hour ) // EmbeddingHash is the hash function that should be used to compute embedding diff --git a/lib/ai/embeddings.go b/lib/ai/embeddings.go new file mode 100644 index 0000000000000..d05ee1e0b4ffd --- /dev/null +++ b/lib/ai/embeddings.go @@ -0,0 +1,302 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai + +import ( + "context" + "strings" + "time" + + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + "google.golang.org/protobuf/proto" + "gopkg.in/yaml.v3" + + "github.com/gravitational/teleport/api/defaults" + embeddingpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1" + "github.com/gravitational/teleport/api/internalutils/stream" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/retryutils" + streamutils "github.com/gravitational/teleport/lib/utils/stream" +) + +// Embeddings implements the minimal interface used by the Embedding processor. +type Embeddings interface { + // GetEmbeddings returns all embeddings for a given kind. 
+	GetEmbeddings(ctx context.Context, kind string) stream.Stream[*Embedding] +	// UpsertEmbedding creates or updates a single ai.Embedding in the backend. +	UpsertEmbedding(ctx context.Context, embedding *Embedding) (*Embedding, error) +} + +// NodesStreamGetter is a service that gets nodes. +type NodesStreamGetter interface { +	// GetNodeStream returns a stream of registered servers. +	GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server] +} + +// MarshalEmbedding marshals the ai.Embedding resource to binary ProtoBuf. +func MarshalEmbedding(embedding *Embedding) ([]byte, error) { +	data, err := proto.Marshal((*embeddingpb.Embedding)(embedding)) +	if err != nil { +		return nil, trace.Wrap(err) +	} +	return data, nil +} + +// UnmarshalEmbedding unmarshals binary ProtoBuf into an ai.Embedding resource. +func UnmarshalEmbedding(bytes []byte) (*Embedding, error) { +	if len(bytes) == 0 { +		return nil, trace.BadParameter("missing embedding data") +	} +	var embedding embeddingpb.Embedding +	err := proto.Unmarshal(bytes, &embedding) +	if err != nil { +		return nil, trace.Wrap(err) +	} + +	return (*Embedding)(&embedding), nil +} + +// EmbeddingHashMatches returns true if the hash of the embedding matches the +// given hash. +func EmbeddingHashMatches(embedding *Embedding, hash Sha256Hash) bool { +	if len(embedding.EmbeddedHash) != 32 { +		return false +	} + +	return *(*Sha256Hash)(embedding.EmbeddedHash) == hash +} + +// serializeNode converts a types.Server into text ready to be fed to an +// embedding model. The YAML serialization function was chosen over JSON and +// CSV as it provided better results. 
+func serializeNode(node types.Server) ([]byte, error) { + a := struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + SubKind string `yaml:"subkind"` + Labels map[string]string `yaml:"labels"` + }{ + Name: node.GetName(), + Kind: types.KindNode, + SubKind: node.GetSubKind(), + Labels: node.GetAllLabels(), + } + text, err := yaml.Marshal(&a) + return text, trace.Wrap(err) +} + +// BatchReducer is a helper that processes data in batches. +type BatchReducer[T, V any] struct { + data []T + batchSize int + processFn func(ctx context.Context, data []T) (V, error) +} + +// NewBatchReducer is a BatchReducer constructor. +func NewBatchReducer[T, V any](processFn func(ctx context.Context, data []T) (V, error), batchSize int) *BatchReducer[T, V] { + return &BatchReducer[T, V]{ + data: make([]T, 0), + batchSize: batchSize, + processFn: processFn, + } +} + +// Add adds a new item to the batch. If the batch is full, it will be processed +// and the result will be returned. Otherwise, a zero value will be returned. +// Finalize must be called to process the remaining data in the batch. +func (b *BatchReducer[T, V]) Add(ctx context.Context, data T) (V, error) { + b.data = append(b.data, data) + if len(b.data) >= b.batchSize { + val, err := b.processFn(ctx, b.data) + b.data = b.data[:0] + return val, trace.Wrap(err) + } + + var def V + return def, nil +} + +// Finalize processes the remaining data in the batch and returns the result. +func (b *BatchReducer[T, V]) Finalize(ctx context.Context) (V, error) { + if len(b.data) > 0 { + val, err := b.processFn(ctx, b.data) + b.data = b.data[:0] + return val, trace.Wrap(err) + } + + var def V + return def, nil +} + +// EmbeddingProcessorConfig is the configuration for EmbeddingProcessor. 
+type EmbeddingProcessorConfig struct { +	AIClient     Embedder +	EmbeddingSrv Embeddings +	NodeSrv      NodesStreamGetter +	Log          logrus.FieldLogger +	Jitter       retryutils.Jitter +} + +// EmbeddingProcessor is responsible for processing nodes, generating embeddings +// and storing the embeddings in the backend. +type EmbeddingProcessor struct { +	aiClient     Embedder +	embeddingSrv Embeddings +	nodeSrv      NodesStreamGetter +	log          logrus.FieldLogger +	jitter       retryutils.Jitter +} + +// NewEmbeddingProcessor returns a new EmbeddingProcessor. +func NewEmbeddingProcessor(cfg *EmbeddingProcessorConfig) *EmbeddingProcessor { +	return &EmbeddingProcessor{ +		aiClient:     cfg.AIClient, +		embeddingSrv: cfg.EmbeddingSrv, +		nodeSrv:      cfg.NodeSrv, +		log:          cfg.Log, +		jitter:       cfg.Jitter, +	} +} + +// nodeStringPair is a helper struct that pairs a node with a data string. +type nodeStringPair struct { +	node types.Server +	data string +} + +// mapProcessFn is a helper function that maps a slice of nodeStringPair, +// computes embeddings and returns them as a slice of ai.Embedding. +func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*nodeStringPair) ([]*Embedding, error) { +	dataBatch := make([]string, 0, len(data)) +	for _, pair := range data { +		dataBatch = append(dataBatch, pair.data) +	} + +	embeddings, err := e.aiClient.ComputeEmbeddings(ctx, dataBatch) +	if err != nil { +		return nil, trace.Wrap(err) +	} + +	results := make([]*Embedding, 0, len(embeddings)) +	for i, embedding := range embeddings { +		emb := NewEmbedding(types.KindNode, +			data[i].node.GetName(), embedding, +			EmbeddingHash([]byte(data[i].data)), +		) +		results = append(results, emb) +	} + +	return results, nil +} + +// Run runs the EmbeddingProcessor. 
+func (e *EmbeddingProcessor) Run(ctx context.Context, period time.Duration) error { +	for { +		select { +		case <-ctx.Done(): +			return ctx.Err() +		case <-time.After(e.jitter(period)): +			e.process(ctx) +		} +	} +} + +func (e *EmbeddingProcessor) process(ctx context.Context) { +	batch := NewBatchReducer[*nodeStringPair, []*Embedding](e.mapProcessFn, +		1000, // Max batch size allowed by OpenAI API. +	) + +	embeddingsStream := e.embeddingSrv.GetEmbeddings(ctx, types.KindNode) +	nodesStream := e.nodeSrv.GetNodeStream(ctx, defaults.Namespace) + +	s := streamutils.NewZipStreams( +		nodesStream, +		embeddingsStream, +		// On new node callback. Add the node to the batch. +		func(node types.Server) error { +			nodeData, err := serializeNode(node) +			if err != nil { +				return trace.Wrap(err) +			} +			vectors, err := batch.Add(ctx, &nodeStringPair{node, string(nodeData)}) +			if err != nil { +				return trace.Wrap(err) +			} +			if err := e.upsertEmbeddings(ctx, vectors); err != nil { +				return trace.Wrap(err) +			} + +			return nil +		}, +		// On equal node callback. Check if the node's embedding hash matches +		// the one in the backend. If not, add the node to the batch. +		func(node types.Server, embedding *Embedding) error { +			nodeData, err := serializeNode(node) +			if err != nil { +				return trace.Wrap(err) +			} +			nodeHash := EmbeddingHash(nodeData) + +			if !EmbeddingHashMatches(embedding, nodeHash) { +				vectors, err := batch.Add(ctx, &nodeStringPair{node, string(nodeData)}) +				if err != nil { +					return trace.Wrap(err) +				} +				if err := e.upsertEmbeddings(ctx, vectors); err != nil { +					return trace.Wrap(err) +				} +			} +			return nil +		}, +		// On compare keys callback. Compare the keys for iteration. 
+ func(node types.Server, embeddings *Embedding) int { + if node.GetName() == embeddings.GetName() { + return 0 + } + + return strings.Compare(node.GetName(), embeddings.GetName()) + }, + ) + + if err := s.Process(); err != nil { + e.log.Warnf("Failed to generate nodes embedding: %v", err) + } + + // Process the remaining nodes in the batch + vectors, err := batch.Finalize(ctx) + if err != nil { + e.log.Warnf("Failed to add node to batch: %v", err) + return + } + + if err := e.upsertEmbeddings(ctx, vectors); err != nil { + e.log.Warnf("Failed to upsert embeddings: %v", err) + } +} + +// upsertEmbeddings is a helper function that upserts the embeddings into the backend. +func (e *EmbeddingProcessor) upsertEmbeddings(ctx context.Context, rawEmbeddings []*Embedding) error { + // Store the new embeddings into the backend + for _, embedding := range rawEmbeddings { + _, err := e.embeddingSrv.UpsertEmbedding(ctx, embedding) + if err != nil { + return trace.Wrap(err) + } + } + return nil +} diff --git a/lib/ai/embeddings_test.go b/lib/ai/embeddings_test.go new file mode 100644 index 0000000000000..2bd1e4b12a0af --- /dev/null +++ b/lib/ai/embeddings_test.go @@ -0,0 +1,257 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai_test + +import ( +	"context" +	"crypto/sha256" +	"errors" +	"fmt" +	"testing" +	"time" + +	"github.com/jonboulle/clockwork" +	"github.com/stretchr/testify/assert" +	"github.com/stretchr/testify/require" + +	"github.com/gravitational/teleport/api/defaults" +	"github.com/gravitational/teleport/api/internalutils/stream" +	"github.com/gravitational/teleport/api/types" +	"github.com/gravitational/teleport/api/utils/retryutils" +	"github.com/gravitational/teleport/lib/ai" +	"github.com/gravitational/teleport/lib/backend/memory" +	"github.com/gravitational/teleport/lib/services/local" +	"github.com/gravitational/teleport/lib/utils" +) + +// MockEmbedder returns embeddings based on the sha256 hash function. Those +// embeddings have no semantic meaning but ensure different embedded content +// provides different embeddings. +type MockEmbedder struct{} + +func (m MockEmbedder) ComputeEmbeddings(_ context.Context, input []string) ([]ai.Vector64, error) { +	result := make([]ai.Vector64, len(input)) +	for i, text := range input { +		hash := sha256.Sum256([]byte(text)) +		vector := make(ai.Vector64, len(hash)) +		for j, x := range hash { +			vector[j] = 1 / float64(int(x)+1) +		} +		result[i] = vector +	} +	return result, nil +} + +func TestNodeEmbeddingGeneration(t *testing.T) { +	t.Parallel() + +	ctx, cancel := context.WithCancel(context.Background()) +	t.Cleanup(cancel) + +	clock := clockwork.NewFakeClock() + +	// Test setup: create a backend, presence service, the node watcher and +	// the embeddings service. +	bk, err := memory.New(memory.Config{ +		Context: ctx, +		Clock:   clock, +	}) +	require.NoError(t, err) + +	embedder := MockEmbedder{} +	presence := local.NewPresenceService(bk) +	embeddings := local.NewEmbeddingsService(bk) + +	processor := ai.NewEmbeddingProcessor(&ai.EmbeddingProcessorConfig{ +		AIClient:     &embedder, +		EmbeddingSrv: embeddings, +		NodeSrv:      presence, +		Log:          utils.NewLoggerForTests(), +		Jitter:       retryutils.NewSeventhJitter(), +	}) + +	done := 
make(chan struct{}) + go func() { + err := processor.Run(ctx, 100*time.Millisecond) + assert.ErrorIs(t, context.Canceled, err) + close(done) + }() + + // Add some node servers. + const numNodes = 5 + nodes := make([]types.Server, 0, numNodes) + for i := 0; i < numNodes; i++ { + node, _ := types.NewServer(fmt.Sprintf("node%d", i), types.KindNode, types.ServerSpecV2{ + Addr: "127.0.0.1:1234", + Hostname: fmt.Sprintf("node%d", i), + CmdLabels: map[string]types.CommandLabelV2{ + "version": {Result: "v8"}, + "hostname": {Result: fmt.Sprintf("node%d.example.com", i)}, + }, + }) + _, err = presence.UpsertNode(ctx, node) + require.NoError(t, err) + nodes = append(nodes, node) + } + + require.Eventually(t, func() bool { + items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) + assert.NoError(t, err) + return (len(items) == numNodes) && (len(nodes) == numNodes) + }, 7*time.Second, 200*time.Millisecond) + + cancel() + + waitForDone(t, done, "timed out waiting for processor to stop") + + validateEmbeddings(t, + presence.GetNodeStream(ctx, defaults.Namespace), + embeddings.GetEmbeddings(ctx, types.KindNode)) +} + +func TestMarshallUnmarshallEmbedding(t *testing.T) { + // We test that float precision is above six digits + initial := ai.NewEmbedding(types.KindNode, "foo", ai.Vector64{0.1234567, 1, 1}, sha256.Sum256([]byte("test"))) + + marshaled, err := ai.MarshalEmbedding(initial) + require.NoError(t, err) + + final, err := ai.UnmarshalEmbedding(marshaled) + require.NoError(t, err) + + require.Equal(t, initial.EmbeddedId, final.EmbeddedId) + require.Equal(t, initial.EmbeddedKind, final.EmbeddedKind) + require.Equal(t, initial.EmbeddedHash, final.EmbeddedHash) + require.Equal(t, initial.Vector, final.Vector) +} + +func waitForDone(t *testing.T, done chan struct{}, errMsg string) { + t.Helper() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal(errMsg) + } +} + +func validateEmbeddings(t *testing.T, nodesStream 
stream.Stream[types.Server], embeddingsStream stream.Stream[*ai.Embedding]) { + t.Helper() + + nodes, err := stream.Collect(nodesStream) + require.NoError(t, err) + + embeddings, err := stream.Collect(embeddingsStream) + require.NoError(t, err) + + require.Equal(t, len(nodes), len(embeddings), "Number of nodes and embeddings should be equal") + + for i, node := range nodes { + emb := embeddings[i] + + require.Equal(t, node.GetName(), emb.GetEmbeddedID(), "Node ID and embedding ID should be equal") + require.Equal(t, types.KindNode, emb.GetEmbeddedKind(), "Node kind and embedding kind should be equal") + } +} + +func Test_batchReducer_Add(t *testing.T) { + t.Parallel() + + // Sum process function - used for simplicity + sumFn := func(ctx context.Context, data []int) (int, error) { + sum := 0 + for _, d := range data { + sum += d + } + return sum, nil + } + + type testCase struct { + // Test case name + name string + // Process batch size + batchSize int + // Input data + data []int + // Function to process batch + processFn func(ctx context.Context, data []int) (int, error) + // Expected result on Add + want []int + // Expected result on Finalize + finalizeResult int + // Expected error + wantErr assert.ErrorAssertionFunc + } + + tests := []testCase{ + { + name: "empty", + batchSize: 100, + data: []int{}, + want: []int{}, + finalizeResult: 0, + processFn: sumFn, + wantErr: assert.NoError, + }, + { + name: "one element", + batchSize: 100, + data: []int{1}, + want: []int{0}, + finalizeResult: 1, + processFn: sumFn, + wantErr: assert.NoError, + }, + { + name: "many elements", + batchSize: 3, + data: []int{1, 1, 1, 1}, + want: []int{0, 0, 3, 0}, + finalizeResult: 1, + processFn: sumFn, + wantErr: assert.NoError, + }, + { + name: "propagate error", + batchSize: 2, + data: []int{0}, + want: []int{0}, + processFn: func(ctx context.Context, data []int) (int, error) { + return 0, errors.New("error") + }, + wantErr: assert.Error, + }, + } + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + br := ai.NewBatchReducer[int, int](tt.processFn, tt.batchSize) + + for i, d := range tt.data { + got, err := br.Add(ctx, d) + require.NoError(t, err) + assert.Equalf(t, tt.want[i], got, "Add(%v)", tt.data) + } + + got, err := br.Finalize(ctx) + if !tt.wantErr(t, err, fmt.Sprintf("Finalize(%v)", tt.data)) { + return + } + assert.Equalf(t, tt.finalizeResult, got, "Finalize(%v)", tt.data) + }) + } +} diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 10e714200fa18..ce8e691fb7e2c 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -621,11 +621,6 @@ type Server struct { // httpClientForAWSSTS overwrites the default HTTP client used for making // STS requests. httpClientForAWSSTS utils.HTTPDoClient - - // nodeEmbeddingWatcher listens for nodes and emebeds them. - // This is used only when Assist is enabled and allows to build an index - // to perform semantic node search. - nodeEmbeddingWatcher *services.NodeEmbeddingWatcher } // SetSAMLService registers svc as the SAMLService that provides the SAML @@ -743,15 +738,6 @@ func (a *Server) SetHeadlessAuthenticationWatcher(headlessAuthenticationWatcher a.headlessAuthenticationWatcher = headlessAuthenticationWatcher } -// SetNodeEmbeddingsWatcher stores a reference to the nodeEmbeddingWatcher into -// the auth. This should be called only when Assist is enabled and the -// auth runs the Embeddings service. -func (a *Server) SetNodeEmbeddingsWatcher(watcher *services.NodeEmbeddingWatcher) { - a.lock.Lock() - defer a.lock.Unlock() - a.nodeEmbeddingWatcher = watcher -} - // syncUpgradeWindowStartHour attempts to load the cloud UpgradeWindowStartHour value and set // the ClusterMaintenanceConfig resource's AgentUpgrade.UTCStartHour field to match it. 
func (a *Server) syncUpgradeWindowStartHour(ctx context.Context) error { diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 5581a54573109..3ea047d60fbb4 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -1136,6 +1136,15 @@ func (a *ServerWithRoles) GetInstances(ctx context.Context, filter types.Instanc return a.authServer.GetInstances(ctx, filter) } +// GetNodeStream returns a stream of nodes. +func (a *ServerWithRoles) GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server] { + if err := a.action(namespace, types.KindNode, types.VerbList, types.VerbRead); err != nil { + return stream.Fail[types.Server](trace.Wrap(err)) + } + + return a.authServer.GetNodeStream(ctx, namespace) +} + func (a *ServerWithRoles) GetClusterAlerts(ctx context.Context, query types.GetClusterAlertsRequest) ([]types.ClusterAlert, error) { // unauthenticated clients can never check for alerts. we don't normally explicitly // check for this kind of thing, but since alerts use an unusual access-control diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 8eba35b95dfbc..487151fa5e811 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -32,6 +32,7 @@ import ( loginrulepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/loginrule/v1" pluginspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" samlidppb "github.com/gravitational/teleport/api/gen/proto/go/teleport/samlidp/v1" + "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" @@ -282,6 +283,11 @@ func (c *Client) CompareAndSwapUser(ctx context.Context, new, expected types.Use return trace.NotImplemented(notImplementedMessage) } +// GetNodeStream not implemented: can only be called locally +func (c *Client) GetNodeStream(_ context.Context, _ string) 
stream.Stream[types.Server] { + return stream.Fail[types.Server](trace.NotImplemented(notImplementedMessage)) +} + // StreamSessionEvents streams all events from a given session recording. An error is returned on the first // channel if one is encountered. Otherwise the event channel is closed when the stream ends. // The event channel is not closed on error to prevent race conditions in downstream select statements. diff --git a/lib/service/service.go b/lib/service/service.go index 3b1470265be6a..54163d1cbc2cd 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -66,6 +66,7 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/agentless" "github.com/gravitational/teleport/lib/ai" @@ -1698,24 +1699,17 @@ func (process *TeleportProcess) initAuthService() error { if cfg.Auth.AssistAPIKey != "" { log.Debugf("Starting embedding watcher") openAIClient := ai.NewClient(cfg.Auth.AssistAPIKey) - nodeEmbeddingWatcher, err := services.NewNodeEmbeddingWatcher(process.ExitContext(), services.NodeEmbeddingWatcherConfig{ - NodeWatcherConfig: services.NodeWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentAssist, - Log: log, - Client: authServer.Services, - }, - NodesGetter: authServer.Services, - }, - Embeddings: authServer, - Embedder: openAIClient, + embeddingProcessor := ai.NewEmbeddingProcessor(&ai.EmbeddingProcessorConfig{ + AIClient: openAIClient, + EmbeddingSrv: authServer, + NodeSrv: authServer, + Log: log, + Jitter: retryutils.NewFullJitter(), }) - if err != nil { - return trace.Wrap(err) - } - authServer.SetNodeEmbeddingsWatcher(nodeEmbeddingWatcher) - go nodeEmbeddingWatcher.RunPeriodicEmbedding(process.ExitContext(), ai.EmbeddingPeriod) + 
process.RegisterFunc("ai.embedding-processor", func() error { + return embeddingProcessor.Run(process.ExitContext(), ai.EmbeddingPeriod) + }) } headlessAuthenticationWatcher, err := local.NewHeadlessAuthenticationWatcher(process.ExitContext(), local.HeadlessAuthenticationWatcherConfig{ diff --git a/lib/services/embeddings.go b/lib/services/embeddings.go index 3f4cdafcd82a6..2143b4195e5db 100644 --- a/lib/services/embeddings.go +++ b/lib/services/embeddings.go @@ -1,36 +1,26 @@ -// Copyright 2023 Gravitational, Inc -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package services import ( "context" - "sync" - "sync/atomic" - "time" - "github.com/gravitational/trace" - "google.golang.org/protobuf/proto" - "gopkg.in/yaml.v3" - - apidefaults "github.com/gravitational/teleport/api/defaults" - embeddingpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/embedding/v1" "github.com/gravitational/teleport/api/internalutils/stream" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/ai" - "github.com/gravitational/teleport/lib/utils/interval" ) // Embeddings service is responsible for storing and retrieving embeddings in @@ -44,413 +34,3 @@ type Embeddings interface { // UpsertEmbedding creates or update a single ai.Embedding in the backend. UpsertEmbedding(ctx context.Context, embedding *ai.Embedding) (*ai.Embedding, error) } - -// MarshalEmbedding marshals the ai.Embedding resource to binary ProtoBuf. -func MarshalEmbedding(embedding *ai.Embedding) ([]byte, error) { - data, err := proto.Marshal((*embeddingpb.Embedding)(embedding)) - if err != nil { - return nil, trace.Wrap(err) - } - return data, nil -} - -// UnmarshalEmbedding unmarshals binary ProtoBuf into an ai.Embedding resource. -func UnmarshalEmbedding(bytes []byte) (*ai.Embedding, error) { - if len(bytes) == 0 { - return nil, trace.BadParameter("missing embedding data") - } - var embedding embeddingpb.Embedding - err := proto.Unmarshal(bytes, &embedding) - if err != nil { - return nil, trace.Wrap(err) - } - - return (*ai.Embedding)(&embedding), nil -} - -// NodeEmbeddingWatcher listen for Node events and asynchronously compute -// embeddings for known nodes. -type NodeEmbeddingWatcher struct { - *resourceWatcher - *nodeEmbeddingCollector -} - -// NewNodeEmbeddingWatcher returns a new instance of NodeEmbeddingWatcher. 
-func NewNodeEmbeddingWatcher(ctx context.Context, cfg NodeEmbeddingWatcherConfig) (*NodeEmbeddingWatcher, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - collector := &nodeEmbeddingCollector{ - NodeEmbeddingWatcherConfig: cfg, - initializationC: make(chan struct{}), - currentNodes: make(map[string]*embeddedNode), - } - // start the collector as staled. - collector.stale.Store(true) - - watcher, err := newResourceWatcher(ctx, collector, cfg.ResourceWatcherConfig) - if err != nil { - return nil, trace.Wrap(err) - } - - return &NodeEmbeddingWatcher{resourceWatcher: watcher, nodeEmbeddingCollector: collector}, nil -} - -// NodeEmbeddingWatcherConfig is the configuration of the NodeEmbeddingWatcher. -// It extends the NodeWatcherConfig with a reference to the services.Embeddings. -// This way, the watcher and collector have access to the embeddings in the -// backend and can create new embeddings via the ai.Embedder. -type NodeEmbeddingWatcherConfig struct { - NodeWatcherConfig - Embeddings - Embedder ai.Embedder -} - -func (cfg *NodeEmbeddingWatcherConfig) CheckAndSetDefaults() error { - if err := cfg.NodeWatcherConfig.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) - } - if cfg.Embedder == nil { - return trace.BadParameter("embedder is not set") - } - return nil -} - -// nodeEmbeddingCollector accompanies resourceWatcher when monitoring currentNodes. -// It keeps tracks of which node has been embedded and which node requires embedding. -// The embedding happens asynchronously as calling the openAI API by batch is -// much quicker, stable and efficient. 
-type nodeEmbeddingCollector struct { - NodeEmbeddingWatcherConfig - - // initializationC is used to check whether the initial sync has completed - // This is required for implementing the collector interface - initializationC chan struct{} - // once keeps track if the initialization message has already been sent - once sync.Once - - // currentNodes holds the knwown nodes and their embedding state. - // The map is consumed in 3 cases: - // - during the initial or un-stale full-sync - // - when an event comes in and changes a node's state - // - during the embedding process - // During the embedding process, currentNodes might be updated between the - // read and the write operations: - // - Additions can be ignored as they'll get picked up by the next embedding routine - // - If an element gets deleted, the embedding routine must not add it back - // - If an element gets updated, the timestamp will have changed and the - // indexation must not mark the element as embedded. It will be picked up - // by the next embedding routine. - currentNodes map[string]*embeddedNode - - // mutex must be acquired before reading or writing to currentNodes - mutex sync.Mutex - stale atomic.Bool -} - -type embeddedNode struct { - node types.Server - needsEmbedding bool - // Last update allows to avoid most race conditions - // Before updating or deleting the node, the caller must - // check if it has been edited in the meantime. - lastUpdate time.Time -} - -func (e *embeddedNode) gotEmbedded() { - e.needsEmbedding = false -} - -// resourceKind specifies the resource kind to watch. -func (n *nodeEmbeddingCollector) resourceKind() string { - return types.KindNode -} - -// getResourcesAndUpdateCurrent is called when the resources should be -// (re-)fetched directly. -func (n *nodeEmbeddingCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - // We start the full sync by locking. 
As we are computing diff between the - // nodes in the backend and our tracked nodes, we don't want currentNodes - // to change until we're finish sync-ing. - n.mutex.Lock() - defer n.mutex.Unlock() - timestamp := time.Now() - - allNodes, err := n.getNodes(ctx) - if err != nil { - return trace.Wrap(err) - } - - toRemove := make([]string, 0) - // If we knew a node which is not in the full node list anymore we drop - // it from the index - for _, knownNode := range n.currentNodes { - if _, ok := allNodes[knownNode.node.GetName()]; !ok { - toRemove = append(toRemove, knownNode.node.GetName()) - } - } - - n.addNodes(allNodes, timestamp) - n.removeNodes(toRemove, timestamp) - - n.defineCollectorAsInitialized() - n.stale.Store(false) - return nil -} - -// addNodes takes a map of new or updated nodes, stores them and flags them for -// embedding. mutex must be acquired before calling this function. If a node's -// timestamp is newer than the provided timestamp, it will be ignored. -func (n *nodeEmbeddingCollector) addNodes(nodes map[string]types.Server, timestamp time.Time) { - for nodeName, node := range nodes { - // If the node is already known and has been edited in the meantime we - // don't want to override as we don't have the latest version - if currentNode, ok := n.currentNodes[nodeName]; ok { - if currentNode.lastUpdate.After(timestamp) { - continue - } - } - n.currentNodes[nodeName] = &embeddedNode{ - node: node, - needsEmbedding: true, - lastUpdate: timestamp, - } - } -} - -// removeNodes takes a list of node names, removes them from the vector index -// and from the collector tracking. mutex must be acquired before calling this -// function. If a node's timestamp is newer than the provided timestamp, it -// will be ignored. 
-func (n *nodeEmbeddingCollector) removeNodes(nodeNames []string, timestamp time.Time) { - for _, nodeName := range nodeNames { - if n.currentNodes[nodeName].lastUpdate.Before(timestamp) { - delete(n.currentNodes, nodeName) - } - } -} - -// RunIndexation walks through all collector-tracked nodes and runs a batch -// embedding on all nodes needing embeddings. The embeddings are then inserted -// into the vector index. This process is ran asynchronously to reduce the load -// and leverage OpenAI's batch embedding API. -func (n *nodeEmbeddingCollector) RunIndexation(ctx context.Context) error { - n.Log.Debug("running embedding") - // If data is stale, we attempt to refresh it, else we continue and embed - // the stale data - if n.stale.Load() { - _ = n.getResourcesAndUpdateCurrent(ctx) - } - - needsEmbedding := make(map[string][]byte) - n.mutex.Lock() - timestamp := time.Now() - for nodeName, node := range n.currentNodes { - if node.needsEmbedding { - text, err := serializeNode(node.node) - if err != nil { - n.Log.Warningf("failed to serialize node %s, the node won't be embedded", node.node.GetName()) - continue - } - needsEmbedding[nodeName] = text - } - } - n.mutex.Unlock() - - embeddings, err := n.embed(ctx, types.KindNode, needsEmbedding) - if err != nil { - return trace.Wrap(err) - } - n.mutex.Lock() - defer n.mutex.Unlock() - for _, embedding := range embeddings { - if node, ok := n.currentNodes[embedding.GetEmbeddedID()]; ok && node.lastUpdate.Before(timestamp) { - node.gotEmbedded() - } - } - - n.Log.Debugf("Embedded %d nodes", len(embeddings)) - - // TODO(hugoShaka): when vector index is here, delete then insert nodes in it. 
- return nil -} - -func (n *nodeEmbeddingCollector) getNodes(ctx context.Context) (map[string]types.Server, error) { - nodes, err := n.NodesGetter.GetNodes(ctx, apidefaults.Namespace) - if err != nil { - return nil, trace.Wrap(err) - } - - if len(nodes) == 0 { - return map[string]types.Server{}, nil - } - - current := make(map[string]types.Server, len(nodes)) - for _, node := range nodes { - current[node.GetName()] = node - } - - return current, nil -} - -func (n *nodeEmbeddingCollector) defineCollectorAsInitialized() { - n.once.Do(func() { - // mark watcher as initialized. - close(n.initializationC) - }) -} - -// processEventAndUpdateCurrent is called when a watcher event is received. -func (n *nodeEmbeddingCollector) processEventAndUpdateCurrent(_ context.Context, event types.Event) { - if event.Resource == nil || event.Resource.GetKind() != types.KindNode { - n.Log.Warningf("Unexpected event: %v.", event) - return - } - - n.mutex.Lock() - timestamp := time.Now() - defer n.mutex.Unlock() - switch event.Type { - case types.OpDelete: - n.removeNodes([]string{event.Resource.GetName()}, timestamp) - case types.OpPut: - server, ok := event.Resource.(types.Server) - if !ok { - n.Log.Warningf("Unexpected type %T.", event.Resource) - return - } - n.addNodes(map[string]types.Server{server.GetName(): server}, timestamp) - default: - n.Log.Warningf("Skipping unsupported event type %s.", event.Type) - } -} - -func (n *nodeEmbeddingCollector) initializationChan() <-chan struct{} { - return n.initializationC -} - -func (n *nodeEmbeddingCollector) notifyStale() { - n.stale.Store(true) -} - -// NodeCount returns the number of nodes being tracked by the collector which -// have not been embedded. This function is mainly here for testing purposes. 
-func (n *nodeEmbeddingCollector) NodeCount(needsEmbedding bool) int { - count := 0 - n.mutex.Lock() - defer n.mutex.Unlock() - for _, node := range n.currentNodes { - if node.needsEmbedding == needsEmbedding { - count += 1 - } - } - return count -} - -// embed takes a resource textual representation, checks if the resource -// already has an up-to-date embedding stored in the backend, and computes -// a new embedding otherwise. The newly computed embedding is stored in -// the backend. -func (n *nodeEmbeddingCollector) embed(ctx context.Context, kind string, resources map[string][]byte) ([]*ai.Embedding, error) { - - // Lookup if there are embeddings in the backend for this node - // and the hash matches - embeddingsFromCache := make([]*ai.Embedding, 0) - toEmbed := make(map[string][]byte) - for name, data := range resources { - existingEmbedding, err := n.GetEmbedding(ctx, kind, name) - if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - if err == nil { - if embeddingHashMatches(existingEmbedding, ai.EmbeddingHash(data)) { - embeddingsFromCache = append(embeddingsFromCache, existingEmbedding) - continue - } - } - toEmbed[name] = data - } - - // Convert to a list but keep track of the order so that we know which - // input maps to which resource. 
- keys := make([]string, 0, len(toEmbed)) - input := make([]string, len(toEmbed)) - - for key := range toEmbed { - keys = append(keys, key) - } - - for i, key := range keys { - input[i] = string(toEmbed[key]) - } - - response, err := n.Embedder.ComputeEmbeddings(ctx, input) - if err != nil { - return nil, trace.Wrap(err) - } - - newEmbeddings := make([]*ai.Embedding, 0, len(response)) - for i, vector := range response { - newEmbeddings = append(newEmbeddings, ai.NewEmbedding(kind, keys[i], vector, ai.EmbeddingHash(resources[keys[i]]))) - } - - // Store the new embeddings into the backend - for _, embedding := range newEmbeddings { - _, err := n.UpsertEmbedding(ctx, embedding) - if err != nil { - return nil, trace.Wrap(err) - } - } - - return append(embeddingsFromCache, newEmbeddings...), nil -} - -func embeddingHashMatches(embedding *ai.Embedding, hash ai.Sha256Hash) bool { - if len(embedding.EmbeddedHash) != 32 { - return false - } - - return *(*ai.Sha256Hash)(embedding.EmbeddedHash) == hash -} - -// serializeNode converts a type.Server into text ready to be fed to an -// embedding model. The YAML serialization function was chosen over JSON and -// CSV as it provided better results. 
-func serializeNode(node types.Server) ([]byte, error) { - a := struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - SubKind string `yaml:"subkind"` - Labels map[string]string `yaml:"labels"` - }{ - Name: node.GetName(), - Kind: types.KindNode, - SubKind: node.GetSubKind(), - Labels: node.GetAllLabels(), - } - text, err := yaml.Marshal(&a) - return text, trace.Wrap(err) -} - -func (n *NodeEmbeddingWatcher) RunPeriodicEmbedding(ctx context.Context, period time.Duration) { - ticker := interval.New(interval.Config{ - Duration: period, - Jitter: retryutils.NewSeventhJitter(), - FirstDuration: time.Minute, - }) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.Next(): - err := n.RunIndexation(ctx) - if err != nil { - n.Log.Error(err) - } - } - } - -} diff --git a/lib/services/embeddings_test.go b/lib/services/embeddings_test.go deleted file mode 100644 index daea16826f1f4..0000000000000 --- a/lib/services/embeddings_test.go +++ /dev/null @@ -1,354 +0,0 @@ -// Copyright 2023 Gravitational, Inc -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package services_test - -import ( - "context" - "crypto/sha256" - "fmt" - "testing" - "time" - - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/api/internalutils/stream" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/ai" - "github.com/gravitational/teleport/lib/backend/memory" - "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/services/local" -) - -// MockEmbedder returns embeddings based on the sha256 hash function. Those -// embeddings have no semantic meaning but ensure different embedded content -// provides different embeddings. -type MockEmbedder struct { -} - -func (m MockEmbedder) ComputeEmbeddings(_ context.Context, input []string) ([]ai.Vector64, error) { - result := make([]ai.Vector64, len(input)) - for i, text := range input { - hash := sha256.Sum256([]byte(text)) - vector := make(ai.Vector64, len(hash)) - for j, x := range hash { - vector[j] = 1 / float64(int(x)+1) - } - result[i] = vector - } - return result, nil -} - -func TestNodeEmbeddingWatcherCreate(t *testing.T) { - t.Parallel() - - ctx := context.Background() - clock := clockwork.NewFakeClock() - - // Test setup: crate a backend, presence service, the node watcher and - // the embeddings service - bk, err := memory.New(memory.Config{ - Context: ctx, - Clock: clock, - }) - require.NoError(t, err) - - type client struct { - services.Presence - services.Embeddings - types.Events - } - - embedder := MockEmbedder{} - presence := local.NewPresenceService(bk) - embeddings := local.NewEmbeddingsService(bk) - - cfg := services.NodeEmbeddingWatcherConfig{ - NodeWatcherConfig: services.NodeWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "test", - Client: &client{ - Presence: presence, - Embeddings: embeddings, - Events: local.NewEventsService(bk), - }, - MaxStaleness: time.Minute, - }, - }, - Embeddings: embeddings, - 
Embedder: embedder, - } - watcher, err := services.NewNodeEmbeddingWatcher(ctx, cfg) - require.NoError(t, err) - t.Cleanup(watcher.Close) - - // Test start - // Add some node servers. - nodes := make([]types.Server, 0, 5) - for i := 0; i < 5; i++ { - node, _ := types.NewServer(fmt.Sprintf("node%d", i), types.KindNode, types.ServerSpecV2{ - Addr: "127.0.0.1:1234", - Hostname: fmt.Sprintf("node%d", i), - CmdLabels: map[string]types.CommandLabelV2{ - "version": {Result: "v8"}, - "hostname": {Result: fmt.Sprintf("node%d.example.com", i)}, - }, - }) - _, err = presence.UpsertNode(ctx, node) - require.NoError(t, err) - nodes = append(nodes, node) - } - - // Validate the nodes are eventually tracked by the embedding collector - require.Eventually(t, func() bool { - return watcher.NodeCount(true) == len(nodes) - }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive currentNodes.") - require.Zero(t, watcher.NodeCount(false)) - - // Trigger the embedding routine - err = watcher.RunIndexation(ctx) - require.NoError(t, err) - - // Validate that all nodes were embedded and snapshot the backend content - require.Equal(t, watcher.NodeCount(false), len(nodes)) - require.Zero(t, watcher.NodeCount(true)) - items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) - require.NoError(t, err) - require.Equal(t, len(items), len(nodes)) -} - -func TestNodeEmbeddingWatcherIdempotency(t *testing.T) { - t.Parallel() - - ctx := context.Background() - clock := clockwork.NewFakeClock() - - // Test setup: crate a backend, presence service, the node watcher and - // the embeddings service - bk, err := memory.New(memory.Config{ - Context: ctx, - Clock: clock, - }) - require.NoError(t, err) - - type client struct { - services.Presence - services.Embeddings - types.Events - } - - embedder := MockEmbedder{} - presence := local.NewPresenceService(bk) - embeddings := local.NewEmbeddingsService(bk) - - cfg := services.NodeEmbeddingWatcherConfig{ - NodeWatcherConfig: 
services.NodeWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "test", - Client: &client{ - Presence: presence, - Embeddings: embeddings, - Events: local.NewEventsService(bk), - }, - MaxStaleness: time.Minute, - }, - }, - Embeddings: embeddings, - Embedder: embedder, - } - watcher, err := services.NewNodeEmbeddingWatcher(ctx, cfg) - require.NoError(t, err) - t.Cleanup(watcher.Close) - - // Test start - // Add some node servers. - nodes := make([]types.Server, 0, 5) - for i := 0; i < 5; i++ { - node, _ := types.NewServer(fmt.Sprintf("node%d", i), types.KindNode, types.ServerSpecV2{ - Addr: "127.0.0.1:1234", - Hostname: fmt.Sprintf("node%d", i), - CmdLabels: map[string]types.CommandLabelV2{ - "version": {Result: "v8"}, - "hostname": {Result: fmt.Sprintf("node%d.example.com", i)}, - }, - }) - _, err = presence.UpsertNode(ctx, node) - require.NoError(t, err) - nodes = append(nodes, node) - } - - // Validate the nodes are eventually tracked by the embedding collector - require.Eventually(t, func() bool { - return watcher.NodeCount(true) == len(nodes) - }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive currentNodes.") - require.Zero(t, watcher.NodeCount(false)) - - // Trigger the embedding routine - err = watcher.RunIndexation(ctx) - require.NoError(t, err) - - // Validate that all nodes were embedded and snapshot the backend content - require.Equal(t, watcher.NodeCount(false), len(nodes)) - require.Zero(t, watcher.NodeCount(true)) - items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) - require.NoError(t, err) - require.Equal(t, len(items), len(nodes)) - - // Trigger the embedding routine again - err = watcher.RunIndexation(ctx) - require.NoError(t, err) - - // Validate no nodes are needing embedding and that the items in the backend - // have been updated - require.Zero(t, watcher.NodeCount(true)) - newItems, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) - 
require.NoError(t, err) - require.Equal(t, len(items), len(newItems)) - - for _, oldEmbedding := range items { - newEmbedding, err := embeddings.GetEmbedding(ctx, types.KindNode, oldEmbedding.GetEmbeddedID()) - require.NoError(t, err) - require.Equal(t, oldEmbedding.GetVector(), newEmbedding.GetVector()) - } -} - -func TestNodeEmbeddingWatcherUpdate(t *testing.T) { - t.Parallel() - - ctx := context.Background() - clock := clockwork.NewFakeClock() - - // Test setup: crate a backend, presence service, the node watcher and - // the embeddings service - bk, err := memory.New(memory.Config{ - Context: ctx, - Clock: clock, - }) - require.NoError(t, err) - - type client struct { - services.Presence - services.Embeddings - types.Events - } - - embedder := MockEmbedder{} - presence := local.NewPresenceService(bk) - embeddings := local.NewEmbeddingsService(bk) - - cfg := services.NodeEmbeddingWatcherConfig{ - NodeWatcherConfig: services.NodeWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: "test", - Client: &client{ - Presence: presence, - Embeddings: embeddings, - Events: local.NewEventsService(bk), - }, - MaxStaleness: time.Minute, - }, - }, - Embeddings: embeddings, - Embedder: embedder, - } - watcher, err := services.NewNodeEmbeddingWatcher(ctx, cfg) - require.NoError(t, err) - t.Cleanup(watcher.Close) - - // Test setup: Add some node servers. 
- nodes := make([]types.Server, 0, 5) - for i := 0; i < 5; i++ { - node, _ := types.NewServer(fmt.Sprintf("node%d", i), types.KindNode, types.ServerSpecV2{ - Addr: "127.0.0.1:1234", - Hostname: fmt.Sprintf("node%d", i), - CmdLabels: map[string]types.CommandLabelV2{ - "version": {Result: "v8"}, - "hostname": {Result: fmt.Sprintf("node%d.example.com", i)}, - }, - }) - _, err = presence.UpsertNode(ctx, node) - require.NoError(t, err) - nodes = append(nodes, node) - } - - // Validate the nodes are eventually tracked by the embedding collector - require.Eventually(t, func() bool { - return watcher.NodeCount(true) == len(nodes) - }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive currentNodes.") - require.Zero(t, watcher.NodeCount(false)) - - // Trigger the embedding routine - err = watcher.RunIndexation(ctx) - require.NoError(t, err) - - // Validate that all nodes were embedded and snapshot the backend content - require.Equal(t, watcher.NodeCount(false), len(nodes)) - require.Zero(t, watcher.NodeCount(true)) - items, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) - require.NoError(t, err) - require.Equal(t, len(items), len(nodes)) - - // Test start - // Edit the node server labels - for i := 0; i < 5; i++ { - nodes[i].SetCmdLabels( - map[string]types.CommandLabel{ - "version": &types.CommandLabelV2{Result: "v9"}, - "hostname": &types.CommandLabelV2{Result: fmt.Sprintf("node%d.example.com", i)}, - }) - _, err = presence.UpsertNode(ctx, nodes[i]) - require.NoError(t, err) - } - - // Validate the node updates have been tracked by the watcher and that the - // nodes are embedding candidates - require.Eventually(t, func() bool { - return watcher.NodeCount(true) == len(nodes) - }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive currentNodes.") - require.Zero(t, watcher.NodeCount(false)) - - // Trigger the embedding routine again - err = watcher.RunIndexation(ctx) - require.NoError(t, err) - - // Validate no 
nodes are needing embedding and that the items in the backend - // have been updated - require.Zero(t, watcher.NodeCount(true)) - newItems, err := stream.Collect(embeddings.GetEmbeddings(ctx, types.KindNode)) - require.NoError(t, err) - require.Equal(t, len(items), len(newItems)) - - for _, oldEmbedding := range items { - newEmbedding, err := embeddings.GetEmbedding(ctx, types.KindNode, oldEmbedding.GetEmbeddedID()) - require.NoError(t, err) - require.NotEqual(t, oldEmbedding.GetVector(), newEmbedding.GetVector()) - } -} - -func TestMarshallUnmarshallEmbedding(t *testing.T) { - // We test that float precision is above six digits - initial := ai.NewEmbedding(types.KindNode, "foo", ai.Vector64{0.1234567, 1, 1}, sha256.Sum256([]byte("test"))) - - marshaled, err := services.MarshalEmbedding(initial) - require.NoError(t, err) - - final, err := services.UnmarshalEmbedding(marshaled) - require.NoError(t, err) - - require.Equal(t, initial.EmbeddedId, final.EmbeddedId) - require.Equal(t, initial.EmbeddedKind, final.EmbeddedKind) - require.Equal(t, initial.EmbeddedHash, final.EmbeddedHash) - require.Equal(t, initial.Vector, final.Vector) -} diff --git a/lib/services/local/embeddings.go b/lib/services/local/embeddings.go index ef94ec18f8ed2..d782e1448b573 100644 --- a/lib/services/local/embeddings.go +++ b/lib/services/local/embeddings.go @@ -26,7 +26,6 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/ai" "github.com/gravitational/teleport/lib/backend" - "github.com/gravitational/teleport/lib/services" ) // EmbeddingsService implements the services.Embeddings interface. @@ -48,7 +47,7 @@ func (e EmbeddingsService) GetEmbedding(ctx context.Context, kind, resourceID st if err != nil { return nil, trace.Wrap(err) } - return services.UnmarshalEmbedding(result.Value) + return ai.UnmarshalEmbedding(result.Value) } // GetEmbeddings returns a stream of embeddings for a given kind. 
@@ -56,7 +55,7 @@ func (e EmbeddingsService) GetEmbeddings(ctx context.Context, kind string) strea startKey := backend.ExactKey(embeddingsPrefix, kind) items := backend.StreamRange(ctx, e, startKey, backend.RangeEnd(startKey), 50) return stream.FilterMap(items, func(item backend.Item) (*ai.Embedding, bool) { - embedding, err := services.UnmarshalEmbedding(item.Value) + embedding, err := ai.UnmarshalEmbedding(item.Value) if err != nil { e.log.Warnf("Skipping embedding at %s, failed to unmarshal: %v", item.Key, err) return nil, false @@ -67,7 +66,7 @@ func (e EmbeddingsService) GetEmbeddings(ctx context.Context, kind string) strea // UpsertEmbedding creates or update a single ai.Embedding in the backend. func (e EmbeddingsService) UpsertEmbedding(ctx context.Context, embedding *ai.Embedding) (*ai.Embedding, error) { - value, err := services.MarshalEmbedding(embedding) + value, err := ai.MarshalEmbedding(embedding) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index 93bf7f7c557c7..03470687fac29 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -317,6 +317,20 @@ func (s *PresenceService) GetNodes(ctx context.Context, namespace string) ([]typ return servers, nil } +// GetNodeStream returns a stream of nodes in a namespace. 
+func (s *PresenceService) GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server] { + startKey := backend.ExactKey(nodesPrefix, namespace) + items := backend.StreamRange(ctx, s, startKey, backend.RangeEnd(startKey), 50) + return stream.FilterMap(items, func(item backend.Item) (types.Server, bool) { + embedding, err := services.UnmarshalServer(item.Value, types.KindNode) + if err != nil { + s.log.Warnf("Skipping node at %s, failed to unmarshal: %v", item.Key, err) + return nil, false + } + return embedding, true + }) +} + // UpsertNode registers node presence, permanently if TTL is 0 or for the // specified duration with second resolution if it's >= 1 second. func (s *PresenceService) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) { diff --git a/lib/services/presence.go b/lib/services/presence.go index 691524653a394..148f6d4ea0c21 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -36,6 +36,12 @@ type NodesGetter interface { GetNodes(ctx context.Context, namespace string) ([]types.Server, error) } +// NodesStreamGetter is a service that gets nodes. +type NodesStreamGetter interface { + // GetNodeStream returns a list of registered servers. + GetNodeStream(ctx context.Context, namespace string) stream.Stream[types.Server] +} + // Presence records and reports the presence of all components // of the cluster - Nodes, Proxies and SSH nodes type Presence interface { @@ -52,6 +58,9 @@ type Presence interface { // NodesGetter gets nodes NodesGetter + // NodesStreamGetter gets nodes as a stream + NodesStreamGetter + // DeleteAllNodes deletes all nodes in a namespace. DeleteAllNodes(ctx context.Context, namespace string) error diff --git a/lib/utils/stream/zip.go b/lib/utils/stream/zip.go new file mode 100644 index 0000000000000..74d0f6ab20f7e --- /dev/null +++ b/lib/utils/stream/zip.go @@ -0,0 +1,121 @@ +/* + * Copyright 2023 Gravitational, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stream + +import ( + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/internalutils/stream" +) + +// ZipStreams is a helper for iterate two streams and process elements in the +// leader stream only if they don't already exists in the follower stream. +// The streams must be sorted and comparable. +type ZipStreams[T, V any] struct { + // leader is the stream that will be leading the iteration. + leader stream.Stream[T] + // follower is the stream that will be following the iteration. + follower stream.Stream[V] + // onMissing is the function that will be called when the leader element is + // missing in the follower stream. + onMissing func(elem T) error + // onEqualKeys is the function that will be called when the leader element + // has the same key as the follower element. It allows additional processing + // of the element. + onEqualKeys func(leader T, follower V) error + // compareKeys is the function that will be used to compare the keys of the + // leader and follower elements. + // It should return 0 if leader == follower, -1 if leader < follower, and +1 if leader > follower. + compareKeys func(leader T, follower V) int +} + +// NewZipStreams returns a new instance of ZipStreams. 
+func NewZipStreams[T, V any](leader stream.Stream[T], follower stream.Stream[V], + onMissing func(elem T) error, + onEqualKeys func(leader T, follower V) error, + compare func(leader T, follower V) int, +) *ZipStreams[T, V] { + return &ZipStreams[T, V]{ + leader: leader, + follower: follower, + onMissing: onMissing, + onEqualKeys: onEqualKeys, + compareKeys: compare, + } +} + +// Process consumes the streams and returns an error reported by handler functions. +// Processing will stop on the first error. +func (z *ZipStreams[T, V]) Process() error { + var leaderItem T + var followerItem V + hasLeader := z.leader.Next() + hasFollower := z.follower.Next() + + if hasLeader { + leaderItem = z.leader.Item() + } + if hasFollower { + followerItem = z.follower.Item() + } + + for hasLeader && hasFollower { + cmp := z.compareKeys(leaderItem, followerItem) + if cmp == -1 { + // leader > follower - follower is missing + if err := z.onMissing(leaderItem); err != nil { + return trace.Wrap(err) + } + + hasLeader = z.leader.Next() + if hasLeader { + leaderItem = z.leader.Item() + } + } else if cmp == 1 { + // leader < follower - advancde + hasFollower = z.follower.Next() + if hasFollower { + followerItem = z.follower.Item() + } + } else { + // leader == follower + if err := z.onEqualKeys(leaderItem, followerItem); err != nil { + return trace.Wrap(err) + } + hasLeader = z.leader.Next() + hasFollower = z.follower.Next() + if hasLeader { + leaderItem = z.leader.Item() + } + if hasFollower { + followerItem = z.follower.Item() + } + } + } + + for hasLeader { + if err := z.onMissing(leaderItem); err != nil { + return trace.Wrap(err) + } + hasLeader = z.leader.Next() + if hasLeader { + leaderItem = z.leader.Item() + } + } + + return trace.NewAggregate(z.leader.Done(), z.follower.Done()) +} diff --git a/lib/utils/stream/zip_test.go b/lib/utils/stream/zip_test.go new file mode 100644 index 0000000000000..c5c555f9f3bda --- /dev/null +++ b/lib/utils/stream/zip_test.go @@ -0,0 +1,254 @@ +/* + 
* Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stream_test + +import ( + "strings" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/internalutils/stream" + streamutils "github.com/gravitational/teleport/lib/utils/stream" +) + +func Test_zipStreams_Process(t *testing.T) { + type testCase[T any, V any] struct { + name string + validate func(t *testing.T) (*streamutils.ZipStreams[T, V], func()) + wantErr bool + } + tests := []testCase[string, string]{ + { + name: "empty", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + counter := 0 + equalCounter := 0 + return streamutils.NewZipStreams[string, string]( + stream.Empty[string](), + stream.Empty[string](), + func(s1 string) error { + counter++ + return nil + }, + func(leader string, follower string) error { + equalCounter++ + return nil + }, + strings.Compare, + ), func() { + require.Equal(t, 0, counter) + require.Equal(t, 0, equalCounter) + } + }, + wantErr: false, + }, + { + name: "one", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + counter := 0 + equalCounter := 0 + return streamutils.NewZipStreams[string, string]( + stream.Slice([]string{"foo"}), + stream.Empty[string](), + func(s1 string) error { + counter++ + return nil + }, + func(leader string, follower string) error { + equalCounter++ + return 
nil + }, + strings.Compare, + ), func() { + require.Equal(t, 1, counter) + require.Equal(t, 0, equalCounter) + } + }, + wantErr: false, + }, + { + name: "no leaders", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + counter := 0 + equalCounter := 0 + return streamutils.NewZipStreams[string, string]( + stream.Empty[string](), + stream.Slice([]string{"foo"}), + func(s1 string) error { + counter++ + return nil + }, + func(leader string, follower string) error { + equalCounter++ + return nil + }, + strings.Compare, + ), func() { + require.Equal(t, 0, counter) + require.Equal(t, 0, equalCounter) + } + }, + wantErr: false, + }, + { + name: "already in sync", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + counter := 0 + equalCounter := 0 + return streamutils.NewZipStreams[string, string]( + stream.Slice([]string{"foo"}), + stream.Slice([]string{"foo"}), + func(s1 string) error { + counter++ + return nil + }, + func(leader string, follower string) error { + equalCounter++ + return nil + }, + strings.Compare, + ), func() { + require.Equal(t, 0, counter) + require.Equal(t, 1, equalCounter) + } + }, + wantErr: false, + }, + { + name: "additional leader", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + counter := 0 + equalCounter := 0 + calledWith := make([]string, 0) + return streamutils.NewZipStreams[string, string]( + stream.Slice([]string{"bar", "foo"}), + stream.Slice([]string{"foo"}), + func(s1 string) error { + counter++ + calledWith = append(calledWith, s1) + return nil + }, + func(leader string, follower string) error { + // should be called with "foo" and "foo" + require.Equal(t, "foo", leader) + require.Equal(t, "foo", follower) + equalCounter++ + return nil + }, + strings.Compare, + ), func() { + require.Equal(t, 1, counter) + require.Equal(t, []string{"bar"}, calledWith) + require.Equal(t, 1, equalCounter) + } + }, + wantErr: false, + }, + { + name: 
"additional follower - no calls", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + counter := 0 + equalCounter := 0 + return streamutils.NewZipStreams[string, string]( + stream.Slice([]string{"foo"}), + stream.Slice([]string{"bar", "foo"}), + func(s1 string) error { + counter++ + return nil + }, + func(leader string, follower string) error { + require.Equal(t, "foo", leader) + require.Equal(t, "foo", follower) + equalCounter++ + return nil + }, + strings.Compare, + ), func() { + require.Equal(t, 0, counter) + require.Equal(t, 1, equalCounter) + } + }, + wantErr: false, + }, + { + name: "mix", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + counter := 0 + equalCount := 0 + calledWith := make([]string, 0) + sameCalledWith := make([]string, 0) + return streamutils.NewZipStreams[string, string]( + stream.Slice([]string{"1", "2", "4", "5", "8"}), + stream.Slice([]string{"2", "3", "4", "9"}), + func(s1 string) error { + counter++ + calledWith = append(calledWith, s1) + return nil + }, + func(leader string, follower string) error { + // Both fields should be the same + require.Equal(t, leader, follower) + sameCalledWith = append(sameCalledWith, leader) + equalCount++ + return nil + }, + strings.Compare, + ), func() { + require.Equal(t, 3, counter) + require.Equal(t, []string{"1", "5", "8"}, calledWith) + require.Equal(t, 2, equalCount) + require.Equal(t, []string{"2", "4"}, sameCalledWith) + } + }, + wantErr: false, + }, + { + name: "errors are propagated", + validate: func(t *testing.T) (*streamutils.ZipStreams[string, string], func()) { + return streamutils.NewZipStreams[string, string]( + stream.Slice([]string{"1", "2", "5", "8"}), + stream.Slice([]string{"2", "3", "9"}), + func(s1 string) error { + return trace.Errorf("something bad") + }, + func(leader string, follower string) error { + return nil + }, + strings.Compare, + ), func() { + } + }, + wantErr: true, + }, + } + + for _, tt := range 
tests { + t.Run(tt.name, func(t *testing.T) { + z, validate := tt.validate(t) + err := z.Process() + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + validate() + }) + } +}