diff --git a/api/client/client.go b/api/client/client.go index 46ab883723ae2..37fd37fb969eb 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -752,6 +752,12 @@ func (c *Client) TrustClient() trustpb.TrustServiceClient { return trustpb.NewTrustServiceClient(c.conn) } +// EmbeddingClient returns an unadorned Embedding client, using the underlying +// Auth gRPC connection. +func (c *Client) EmbeddingClient() assist.AssistEmbeddingServiceClient { + return assist.NewAssistEmbeddingServiceClient(c.conn) +} + // Ping gets basic info about the auth server. func (c *Client) Ping(ctx context.Context) (proto.PingResponse, error) { rsp, err := c.grpc.Ping(ctx, &proto.PingRequest{}) @@ -3966,3 +3972,13 @@ func (c *Client) UpdateAssistantConversationInfo(ctx context.Context, in *assist } return nil } + +// GetAssistantEmbeddings retrieves the embeddings for the given request via the +// AssistEmbeddingService. gRPC errors are converted with trail.FromGRPC. +func (c *Client) GetAssistantEmbeddings(ctx context.Context, in *assist.GetAssistantEmbeddingsRequest) (*assist.GetAssistantEmbeddingsResponse, error) { + result, err := c.EmbeddingClient().GetAssistantEmbeddings(ctx, in) + if err != nil { + return nil, trail.FromGRPC(err) + } + return result, nil +} diff --git a/api/gen/proto/go/assist/v1/assist.pb.go b/api/gen/proto/go/assist/v1/assist.pb.go index cbc94ac60bcd3..a4d8f30d25dfd 100644 --- a/api/gen/proto/go/assist/v1/assist.pb.go +++ b/api/gen/proto/go/assist/v1/assist.pb.go @@ -765,6 +765,199 @@ func (x *DeleteAssistantConversationRequest) GetUsername() string { return "" } +// GetAssistantEmbeddingsRequest is a request to get embeddings. +type GetAssistantEmbeddingsRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // username is a username of the user who requested the embeddings. + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + // query is the query used for similarity search. 
+ Query string `protobuf:"bytes,2,opt,name=query,proto3" json:"query,omitempty"` + // limit is the number of embeddings to return (also known as k). + Limit uint32 `protobuf:"varint,3,opt,name=limit,proto3" json:"limit,omitempty"` + // kind is the kind of embeddings to return (ex, node). + Kind string `protobuf:"bytes,4,opt,name=kind,proto3" json:"kind,omitempty"` +} + +func (x *GetAssistantEmbeddingsRequest) Reset() { + *x = GetAssistantEmbeddingsRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_teleport_assist_v1_assist_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetAssistantEmbeddingsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetAssistantEmbeddingsRequest) ProtoMessage() {} + +func (x *GetAssistantEmbeddingsRequest) ProtoReflect() protoreflect.Message { + mi := &file_teleport_assist_v1_assist_proto_msgTypes[13] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetAssistantEmbeddingsRequest.ProtoReflect.Descriptor instead. +func (*GetAssistantEmbeddingsRequest) Descriptor() ([]byte, []int) { + return file_teleport_assist_v1_assist_proto_rawDescGZIP(), []int{13} +} + +func (x *GetAssistantEmbeddingsRequest) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *GetAssistantEmbeddingsRequest) GetQuery() string { + if x != nil { + return x.Query + } + return "" +} + +func (x *GetAssistantEmbeddingsRequest) GetLimit() uint32 { + if x != nil { + return x.Limit + } + return 0 +} + +func (x *GetAssistantEmbeddingsRequest) GetKind() string { + if x != nil { + return x.Kind + } + return "" +} + +// EmbeddedDocument is a document with an embedding. 
+type EmbeddedDocument struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // id is the id of the document. + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + // content is the content of the document. + Content string `protobuf:"bytes,2,opt,name=content,proto3" json:"content,omitempty"` + // similarityScore is the similarity score of the document. + SimilarityScore float32 `protobuf:"fixed32,3,opt,name=similarity_score,json=similarityScore,proto3" json:"similarity_score,omitempty"` +} + +func (x *EmbeddedDocument) Reset() { + *x = EmbeddedDocument{} + if protoimpl.UnsafeEnabled { + mi := &file_teleport_assist_v1_assist_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EmbeddedDocument) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbeddedDocument) ProtoMessage() {} + +func (x *EmbeddedDocument) ProtoReflect() protoreflect.Message { + mi := &file_teleport_assist_v1_assist_proto_msgTypes[14] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbeddedDocument.ProtoReflect.Descriptor instead. +func (*EmbeddedDocument) Descriptor() ([]byte, []int) { + return file_teleport_assist_v1_assist_proto_rawDescGZIP(), []int{14} +} + +func (x *EmbeddedDocument) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *EmbeddedDocument) GetContent() string { + if x != nil { + return x.Content + } + return "" +} + +func (x *EmbeddedDocument) GetSimilarityScore() float32 { + if x != nil { + return x.SimilarityScore + } + return 0 +} + +// GetAssistantEmbeddingsResponse is a response from the assistant service. 
+type GetAssistantEmbeddingsResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // embeddings is the list of embeddings. + // The list is sorted by similarity score in descending order. + Embeddings []*EmbeddedDocument `protobuf:"bytes,1,rep,name=embeddings,proto3" json:"embeddings,omitempty"` +} + +func (x *GetAssistantEmbeddingsResponse) Reset() { + *x = GetAssistantEmbeddingsResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_teleport_assist_v1_assist_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GetAssistantEmbeddingsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetAssistantEmbeddingsResponse) ProtoMessage() {} + +func (x *GetAssistantEmbeddingsResponse) ProtoReflect() protoreflect.Message { + mi := &file_teleport_assist_v1_assist_proto_msgTypes[15] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetAssistantEmbeddingsResponse.ProtoReflect.Descriptor instead. 
+func (*GetAssistantEmbeddingsResponse) Descriptor() ([]byte, []int) { + return file_teleport_assist_v1_assist_proto_rawDescGZIP(), []int{15} +} + +func (x *GetAssistantEmbeddingsResponse) GetEmbeddings() []*EmbeddedDocument { + if x != nil { + return x.Embeddings + } + return nil +} + var File_teleport_assist_v1_assist_proto protoreflect.FileDescriptor var file_teleport_assist_v1_assist_proto_rawDesc = []byte{ @@ -856,66 +1049,96 @@ var file_teleport_assist_v1_assist_proto_rawDesc = []byte{ 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, - 0x61, 0x6d, 0x65, 0x32, 0xdd, 0x06, 0x0a, 0x0d, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x53, 0x65, - 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x8e, 0x01, 0x0a, 0x1b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x36, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, - 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x37, 0x2e, - 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, - 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, - 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x88, 0x01, 0x0a, 0x19, 0x47, 0x65, 0x74, 0x41, 0x73, + 0x61, 0x6d, 0x65, 0x22, 0x7b, 0x0a, 0x1d, 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, + 0x61, 0x6e, 0x74, 0x45, 0x6d, 0x62, 
0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, + 0x12, 0x14, 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, + 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, + 0x22, 0x67, 0x0a, 0x10, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x44, 0x6f, 0x63, 0x75, + 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x02, 0x69, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x12, 0x29, + 0x0a, 0x10, 0x73, 0x69, 0x6d, 0x69, 0x6c, 0x61, 0x72, 0x69, 0x74, 0x79, 0x5f, 0x73, 0x63, 0x6f, + 0x72, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x02, 0x52, 0x0f, 0x73, 0x69, 0x6d, 0x69, 0x6c, 0x61, + 0x72, 0x69, 0x74, 0x79, 0x53, 0x63, 0x6f, 0x72, 0x65, 0x22, 0x66, 0x0a, 0x1e, 0x47, 0x65, 0x74, + 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, + 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0a, 0x65, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x24, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, + 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x44, 0x6f, 0x63, + 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, + 0x73, 0x32, 0xdd, 0x06, 0x0a, 0x0d, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 
0x53, 0x65, 0x72, 0x76, + 0x69, 0x63, 0x65, 0x12, 0x8e, 0x01, 0x0a, 0x1b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x34, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, - 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, - 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x35, 0x2e, 0x74, 0x65, 0x6c, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, - 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, - 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x6d, 0x0a, 0x1b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, - 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x36, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, - 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, + 0x69, 0x6f, 0x6e, 0x12, 0x36, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, + 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, + 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x37, 0x2e, 0x74, 0x65, + 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, + 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, + 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 
0x73, 0x65, 0x12, 0x88, 0x01, 0x0a, 0x19, 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x12, 0x79, 0x0a, 0x14, 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, - 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x12, 0x2f, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, + 0x6e, 0x73, 0x12, 0x34, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, + 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, + 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x35, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, - 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x30, 0x2e, 0x74, 0x65, 0x6c, 0x65, + 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, + 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x6d, 0x0a, 0x1b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, + 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x36, + 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, + 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, + 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x79, + 0x0a, 0x14, 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x12, 0x2f, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, + 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x41, + 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x30, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, + 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, 0x74, + 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x63, 0x0a, 0x16, 0x43, 0x72, 0x65, + 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x12, 0x31, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, + 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, + 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x75, + 0x0a, 0x1f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, + 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x6e, 0x66, + 0x6f, 0x12, 0x3a, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, + 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, + 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 
0x69, + 0x6f, 0x6e, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x6a, 0x0a, 0x0f, 0x49, 0x73, 0x41, 0x73, 0x73, 0x69, 0x73, + 0x74, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, + 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x73, + 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2b, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, + 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x73, 0x41, 0x73, 0x73, 0x69, + 0x73, 0x74, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x32, 0x99, 0x01, 0x0a, 0x16, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x45, 0x6d, 0x62, 0x65, + 0x64, 0x64, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x7f, 0x0a, 0x16, + 0x47, 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x45, 0x6d, 0x62, 0x65, + 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x12, 0x31, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, + 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x41, + 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, + 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x32, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x47, - 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x63, 0x0a, 0x16, 0x43, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 
0x12, 0x31, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x12, 0x75, 0x0a, 0x1f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, - 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, - 0x6e, 0x66, 0x6f, 0x12, 0x3a, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, - 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x41, - 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x6a, 0x0a, 0x0f, 0x49, 0x73, 0x41, 0x73, 0x73, - 0x69, 0x73, 0x74, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x2a, 0x2e, 0x74, 0x65, 0x6c, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, - 0x49, 0x73, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x2b, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, - 0x74, 0x2e, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x73, 0x41, 0x73, - 0x73, 0x69, 0x73, 0x74, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x42, 0x45, 0x5a, 0x43, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76, 0x69, 0x74, 0x61, 0x74, 0x69, 
0x6f, 0x6e, 0x61, 0x6c, 0x2f, - 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x67, 0x65, 0x6e, - 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, - 0x2f, 0x76, 0x31, 0x3b, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x65, 0x74, 0x41, 0x73, 0x73, 0x69, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x45, 0x6d, 0x62, 0x65, 0x64, + 0x64, 0x69, 0x6e, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x45, 0x5a, + 0x43, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76, + 0x69, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, + 0x72, 0x74, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2f, 0x67, 0x6f, 0x2f, 0x61, 0x73, 0x73, 0x69, 0x73, 0x74, 0x2f, 0x76, 0x31, 0x3b, 0x61, 0x73, + 0x73, 0x69, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -930,7 +1153,7 @@ func file_teleport_assist_v1_assist_proto_rawDescGZIP() []byte { return file_teleport_assist_v1_assist_proto_rawDescData } -var file_teleport_assist_v1_assist_proto_msgTypes = make([]protoimpl.MessageInfo, 13) +var file_teleport_assist_v1_assist_proto_msgTypes = make([]protoimpl.MessageInfo, 16) var file_teleport_assist_v1_assist_proto_goTypes = []interface{}{ (*GetAssistantMessagesRequest)(nil), // 0: teleport.assist.v1.GetAssistantMessagesRequest (*AssistantMessage)(nil), // 1: teleport.assist.v1.AssistantMessage @@ -945,35 +1168,41 @@ var file_teleport_assist_v1_assist_proto_goTypes = []interface{}{ (*IsAssistEnabledRequest)(nil), // 10: teleport.assist.v1.IsAssistEnabledRequest (*IsAssistEnabledResponse)(nil), // 11: teleport.assist.v1.IsAssistEnabledResponse (*DeleteAssistantConversationRequest)(nil), // 12: teleport.assist.v1.DeleteAssistantConversationRequest - (*timestamppb.Timestamp)(nil), // 13: google.protobuf.Timestamp - (*emptypb.Empty)(nil), 
// 14: google.protobuf.Empty + (*GetAssistantEmbeddingsRequest)(nil), // 13: teleport.assist.v1.GetAssistantEmbeddingsRequest + (*EmbeddedDocument)(nil), // 14: teleport.assist.v1.EmbeddedDocument + (*GetAssistantEmbeddingsResponse)(nil), // 15: teleport.assist.v1.GetAssistantEmbeddingsResponse + (*timestamppb.Timestamp)(nil), // 16: google.protobuf.Timestamp + (*emptypb.Empty)(nil), // 17: google.protobuf.Empty } var file_teleport_assist_v1_assist_proto_depIdxs = []int32{ - 13, // 0: teleport.assist.v1.AssistantMessage.created_time:type_name -> google.protobuf.Timestamp + 16, // 0: teleport.assist.v1.AssistantMessage.created_time:type_name -> google.protobuf.Timestamp 1, // 1: teleport.assist.v1.CreateAssistantMessageRequest.message:type_name -> teleport.assist.v1.AssistantMessage 1, // 2: teleport.assist.v1.GetAssistantMessagesResponse.messages:type_name -> teleport.assist.v1.AssistantMessage - 13, // 3: teleport.assist.v1.ConversationInfo.created_time:type_name -> google.protobuf.Timestamp + 16, // 3: teleport.assist.v1.ConversationInfo.created_time:type_name -> google.protobuf.Timestamp 5, // 4: teleport.assist.v1.GetAssistantConversationsResponse.conversations:type_name -> teleport.assist.v1.ConversationInfo - 13, // 5: teleport.assist.v1.CreateAssistantConversationRequest.created_time:type_name -> google.protobuf.Timestamp - 7, // 6: teleport.assist.v1.AssistService.CreateAssistantConversation:input_type -> teleport.assist.v1.CreateAssistantConversationRequest - 4, // 7: teleport.assist.v1.AssistService.GetAssistantConversations:input_type -> teleport.assist.v1.GetAssistantConversationsRequest - 12, // 8: teleport.assist.v1.AssistService.DeleteAssistantConversation:input_type -> teleport.assist.v1.DeleteAssistantConversationRequest - 0, // 9: teleport.assist.v1.AssistService.GetAssistantMessages:input_type -> teleport.assist.v1.GetAssistantMessagesRequest - 2, // 10: teleport.assist.v1.AssistService.CreateAssistantMessage:input_type -> 
teleport.assist.v1.CreateAssistantMessageRequest - 9, // 11: teleport.assist.v1.AssistService.UpdateAssistantConversationInfo:input_type -> teleport.assist.v1.UpdateAssistantConversationInfoRequest - 10, // 12: teleport.assist.v1.AssistService.IsAssistEnabled:input_type -> teleport.assist.v1.IsAssistEnabledRequest - 8, // 13: teleport.assist.v1.AssistService.CreateAssistantConversation:output_type -> teleport.assist.v1.CreateAssistantConversationResponse - 6, // 14: teleport.assist.v1.AssistService.GetAssistantConversations:output_type -> teleport.assist.v1.GetAssistantConversationsResponse - 14, // 15: teleport.assist.v1.AssistService.DeleteAssistantConversation:output_type -> google.protobuf.Empty - 3, // 16: teleport.assist.v1.AssistService.GetAssistantMessages:output_type -> teleport.assist.v1.GetAssistantMessagesResponse - 14, // 17: teleport.assist.v1.AssistService.CreateAssistantMessage:output_type -> google.protobuf.Empty - 14, // 18: teleport.assist.v1.AssistService.UpdateAssistantConversationInfo:output_type -> google.protobuf.Empty - 11, // 19: teleport.assist.v1.AssistService.IsAssistEnabled:output_type -> teleport.assist.v1.IsAssistEnabledResponse - 13, // [13:20] is the sub-list for method output_type - 6, // [6:13] is the sub-list for method input_type - 6, // [6:6] is the sub-list for extension type_name - 6, // [6:6] is the sub-list for extension extendee - 0, // [0:6] is the sub-list for field type_name + 16, // 5: teleport.assist.v1.CreateAssistantConversationRequest.created_time:type_name -> google.protobuf.Timestamp + 14, // 6: teleport.assist.v1.GetAssistantEmbeddingsResponse.embeddings:type_name -> teleport.assist.v1.EmbeddedDocument + 7, // 7: teleport.assist.v1.AssistService.CreateAssistantConversation:input_type -> teleport.assist.v1.CreateAssistantConversationRequest + 4, // 8: teleport.assist.v1.AssistService.GetAssistantConversations:input_type -> teleport.assist.v1.GetAssistantConversationsRequest + 12, // 9: 
teleport.assist.v1.AssistService.DeleteAssistantConversation:input_type -> teleport.assist.v1.DeleteAssistantConversationRequest + 0, // 10: teleport.assist.v1.AssistService.GetAssistantMessages:input_type -> teleport.assist.v1.GetAssistantMessagesRequest + 2, // 11: teleport.assist.v1.AssistService.CreateAssistantMessage:input_type -> teleport.assist.v1.CreateAssistantMessageRequest + 9, // 12: teleport.assist.v1.AssistService.UpdateAssistantConversationInfo:input_type -> teleport.assist.v1.UpdateAssistantConversationInfoRequest + 10, // 13: teleport.assist.v1.AssistService.IsAssistEnabled:input_type -> teleport.assist.v1.IsAssistEnabledRequest + 13, // 14: teleport.assist.v1.AssistEmbeddingService.GetAssistantEmbeddings:input_type -> teleport.assist.v1.GetAssistantEmbeddingsRequest + 8, // 15: teleport.assist.v1.AssistService.CreateAssistantConversation:output_type -> teleport.assist.v1.CreateAssistantConversationResponse + 6, // 16: teleport.assist.v1.AssistService.GetAssistantConversations:output_type -> teleport.assist.v1.GetAssistantConversationsResponse + 17, // 17: teleport.assist.v1.AssistService.DeleteAssistantConversation:output_type -> google.protobuf.Empty + 3, // 18: teleport.assist.v1.AssistService.GetAssistantMessages:output_type -> teleport.assist.v1.GetAssistantMessagesResponse + 17, // 19: teleport.assist.v1.AssistService.CreateAssistantMessage:output_type -> google.protobuf.Empty + 17, // 20: teleport.assist.v1.AssistService.UpdateAssistantConversationInfo:output_type -> google.protobuf.Empty + 11, // 21: teleport.assist.v1.AssistService.IsAssistEnabled:output_type -> teleport.assist.v1.IsAssistEnabledResponse + 15, // 22: teleport.assist.v1.AssistEmbeddingService.GetAssistantEmbeddings:output_type -> teleport.assist.v1.GetAssistantEmbeddingsResponse + 15, // [15:23] is the sub-list for method output_type + 7, // [7:15] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for 
extension extendee + 0, // [0:7] is the sub-list for field type_name } func init() { file_teleport_assist_v1_assist_proto_init() } @@ -1138,6 +1367,42 @@ func file_teleport_assist_v1_assist_proto_init() { return nil } } + file_teleport_assist_v1_assist_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetAssistantEmbeddingsRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_teleport_assist_v1_assist_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EmbeddedDocument); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_teleport_assist_v1_assist_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GetAssistantEmbeddingsResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -1145,9 +1410,9 @@ func file_teleport_assist_v1_assist_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_teleport_assist_v1_assist_proto_rawDesc, NumEnums: 0, - NumMessages: 13, + NumMessages: 16, NumExtensions: 0, - NumServices: 1, + NumServices: 2, }, GoTypes: file_teleport_assist_v1_assist_proto_goTypes, DependencyIndexes: file_teleport_assist_v1_assist_proto_depIdxs, diff --git a/api/gen/proto/go/assist/v1/assist_grpc.pb.go b/api/gen/proto/go/assist/v1/assist_grpc.pb.go index ccf7a2eef4df0..cb399e7316958 100644 --- a/api/gen/proto/go/assist/v1/assist_grpc.pb.go +++ b/api/gen/proto/go/assist/v1/assist_grpc.pb.go @@ -358,3 +358,96 @@ var AssistService_ServiceDesc = grpc.ServiceDesc{ Streams: []grpc.StreamDesc{}, Metadata: "teleport/assist/v1/assist.proto", } + +const ( + AssistEmbeddingService_GetAssistantEmbeddings_FullMethodName = 
"/teleport.assist.v1.AssistEmbeddingService/GetAssistantEmbeddings" +) + +// AssistEmbeddingServiceClient is the client API for AssistEmbeddingService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type AssistEmbeddingServiceClient interface { + // AssistantGetEmbeddings returns the embeddings for the given query. + GetAssistantEmbeddings(ctx context.Context, in *GetAssistantEmbeddingsRequest, opts ...grpc.CallOption) (*GetAssistantEmbeddingsResponse, error) +} + +type assistEmbeddingServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewAssistEmbeddingServiceClient(cc grpc.ClientConnInterface) AssistEmbeddingServiceClient { + return &assistEmbeddingServiceClient{cc} +} + +func (c *assistEmbeddingServiceClient) GetAssistantEmbeddings(ctx context.Context, in *GetAssistantEmbeddingsRequest, opts ...grpc.CallOption) (*GetAssistantEmbeddingsResponse, error) { + out := new(GetAssistantEmbeddingsResponse) + err := c.cc.Invoke(ctx, AssistEmbeddingService_GetAssistantEmbeddings_FullMethodName, in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// AssistEmbeddingServiceServer is the server API for AssistEmbeddingService service. +// All implementations must embed UnimplementedAssistEmbeddingServiceServer +// for forward compatibility +type AssistEmbeddingServiceServer interface { + // AssistantGetEmbeddings returns the embeddings for the given query. + GetAssistantEmbeddings(context.Context, *GetAssistantEmbeddingsRequest) (*GetAssistantEmbeddingsResponse, error) + mustEmbedUnimplementedAssistEmbeddingServiceServer() +} + +// UnimplementedAssistEmbeddingServiceServer must be embedded to have forward compatible implementations. 
+type UnimplementedAssistEmbeddingServiceServer struct { +} + +func (UnimplementedAssistEmbeddingServiceServer) GetAssistantEmbeddings(context.Context, *GetAssistantEmbeddingsRequest) (*GetAssistantEmbeddingsResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetAssistantEmbeddings not implemented") +} +func (UnimplementedAssistEmbeddingServiceServer) mustEmbedUnimplementedAssistEmbeddingServiceServer() { +} + +// UnsafeAssistEmbeddingServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to AssistEmbeddingServiceServer will +// result in compilation errors. +type UnsafeAssistEmbeddingServiceServer interface { + mustEmbedUnimplementedAssistEmbeddingServiceServer() +} + +func RegisterAssistEmbeddingServiceServer(s grpc.ServiceRegistrar, srv AssistEmbeddingServiceServer) { + s.RegisterService(&AssistEmbeddingService_ServiceDesc, srv) +} + +func _AssistEmbeddingService_GetAssistantEmbeddings_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetAssistantEmbeddingsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AssistEmbeddingServiceServer).GetAssistantEmbeddings(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AssistEmbeddingService_GetAssistantEmbeddings_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AssistEmbeddingServiceServer).GetAssistantEmbeddings(ctx, req.(*GetAssistantEmbeddingsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// AssistEmbeddingService_ServiceDesc is the grpc.ServiceDesc for AssistEmbeddingService service. 
+// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var AssistEmbeddingService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "teleport.assist.v1.AssistEmbeddingService", + HandlerType: (*AssistEmbeddingServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetAssistantEmbeddings", + Handler: _AssistEmbeddingService_GetAssistantEmbeddings_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "teleport/assist/v1/assist.proto", +} diff --git a/api/proto/teleport/assist/v1/assist.proto b/api/proto/teleport/assist/v1/assist.proto index d3065ff35214c..6157198d791c0 100644 --- a/api/proto/teleport/assist/v1/assist.proto +++ b/api/proto/teleport/assist/v1/assist.proto @@ -121,6 +121,35 @@ message DeleteAssistantConversationRequest { string username = 2; } +// GetAssistantEmbeddingsRequest is a request to get embeddings. +message GetAssistantEmbeddingsRequest { + // username is a username of the user who requested the embeddings. + string username = 1; + // query is the query used for similarity search. + string query = 2; + // limit is the number of embeddings to return (also known as k). + uint32 limit = 3; + // kind is the kind of embeddings to return (ex, node). + string kind = 4; +} + +// EmbeddingDocument is a document with an embedding. +message EmbeddedDocument { + // id is the id of the document. + string id = 1; + // content is the content of the document. + string content = 2; + // similarityScore is the similarity score of the document. + float similarity_score = 3; +} + +// GetAssistantEmbeddingsResponse is a response from the assistant service. +message GetAssistantEmbeddingsResponse { + // embeddings is the list of embeddings. + // The list is sorted by similarity score in descending order. + repeated EmbeddedDocument embeddings = 1; +} + // AssistService is a service that provides an ability to communicate with the Teleport Assist. 
service AssistService { // CreateNewConversation creates a new conversation and returns the UUID of it. @@ -144,3 +173,9 @@ service AssistService { // IsAssistEnabled returns true if the assist is enabled or not on the auth level. rpc IsAssistEnabled(IsAssistEnabledRequest) returns (IsAssistEnabledResponse); } + +// AssistEmbeddingService is a service that provides an ability to communicate with the Assist Embedding service. +service AssistEmbeddingService { + // AssistantGetEmbeddings returns the embeddings for the given query. + rpc GetAssistantEmbeddings(GetAssistantEmbeddingsRequest) returns (GetAssistantEmbeddingsResponse); +} diff --git a/lib/ai/chat.go b/lib/ai/chat.go index 617151d40bc75..055e4f406d8cf 100644 --- a/lib/ai/chat.go +++ b/lib/ai/chat.go @@ -31,6 +31,7 @@ type Chat struct { client *Client messages []openai.ChatCompletionMessage tokenizer tokenizer.Codec + agent *model.Agent } // Insert inserts a message into the conversation. Returns the index of the message. @@ -70,10 +71,15 @@ func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) { Content: userInput, } - response, err := model.AssistAgent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage) + response, err := chat.agent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage) if err != nil { return nil, trace.Wrap(err) } return response, nil } + +// Clear clears the conversation. 
+func (chat *Chat) Clear() { + chat.messages = []openai.ChatCompletionMessage{} +} diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go index f71b20b76d5e6..e571be3279b57 100644 --- a/lib/ai/chat_test.go +++ b/lib/ai/chat_test.go @@ -49,7 +49,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Hello", }, }, - want: 632, + want: 703, }, { name: "system and user messages", @@ -63,7 +63,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Hi LLM.", }, }, - want: 640, + want: 711, }, { name: "tokenize our prompt", @@ -77,7 +77,7 @@ func TestChat_PromptTokens(t *testing.T) { Content: "Show me free disk space on localhost node.", }, }, - want: 843, + want: 914, }, } @@ -96,7 +96,7 @@ func TestChat_PromptTokens(t *testing.T) { cfg.BaseURL = server.URL + "/v1" client := NewClientFromConfig(cfg) - chat := client.NewChat("Bob") + chat := client.NewChat(nil, "Bob") for _, message := range tt.messages { chat.Insert(message.Role, message.Content) @@ -128,7 +128,7 @@ func TestChat_Complete(t *testing.T) { cfg.BaseURL = server.URL + "/v1" client := NewClientFromConfig(cfg) - chat := client.NewChat("Bob") + chat := client.NewChat(nil, "Bob") t.Run("initial message", func(t *testing.T) { msgAny, err := chat.Complete(context.Background(), "Hello") diff --git a/lib/ai/client.go b/lib/ai/client.go index 6ec831f1b2cf4..aa14ab64c20ff 100644 --- a/lib/ai/client.go +++ b/lib/ai/client.go @@ -23,6 +23,7 @@ import ( "github.com/sashabaranov/go-openai" "github.com/tiktoken-go/tokenizer/codec" + "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" "github.com/gravitational/teleport/lib/ai/model" ) @@ -43,7 +44,8 @@ func NewClientFromConfig(config openai.ClientConfig) *Client { // NewChat creates a new chat. The username is set in the conversation context, // so that the AI can use it to personalize the conversation. -func (client *Client) NewChat(username string) *Chat { +// embeddingServiceClient is used to get the embeddings from the Auth Server. 
+func (client *Client) NewChat(embeddingServiceClient assist.AssistEmbeddingServiceClient, username string) *Chat { return &Chat{ client: client, messages: []openai.ChatCompletionMessage{ @@ -55,6 +57,7 @@ func (client *Client) NewChat(username string) *Chat { // Initialize a tokenizer for prompt token accounting. // Cl100k is used by GPT-3 and GPT-4. tokenizer: codec.NewCl100kBase(), + agent: model.NewAgent(embeddingServiceClient, username), } } diff --git a/lib/ai/embedding.go b/lib/ai/embedding.go index 4e2d6d5e37938..c145024256593 100644 --- a/lib/ai/embedding.go +++ b/lib/ai/embedding.go @@ -17,7 +17,6 @@ package ai import ( "context" "crypto/sha256" - "time" "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" @@ -28,9 +27,6 @@ import ( const ( maxOpenAIEmbeddingsPerRequest = 1000 - // EmbeddingPeriod is the time between two embedding routines. - // A seventh jitter is applied on the period. - EmbeddingPeriod = time.Hour ) // EmbeddingHash is the hash function that should be used to compute embedding @@ -92,11 +88,11 @@ func NewEmbedding(kind, id string, vector Vector64, hash Sha256Hash) *Embedding } // Embedder is implemented for batch text embedding. Embedding can happen in -// place (with an embedding model for example) or be done by a remote embedding +// place (with an embedding model, for example) or be done by a remote embedding // service like OpenAI. type Embedder interface { // ComputeEmbeddings computes the embeddings of multiple strings. - // The embedding list follows the input order (e.g. result[i] is the + // The embedding list follows the input order (e.g., result[i] is the // embedding of input[i]). 
ComputeEmbeddings(ctx context.Context, input []string) ([]Vector64, error) } diff --git a/lib/ai/embeddings.go b/lib/ai/embeddings.go index d05ee1e0b4ffd..a215b2b8813dd 100644 --- a/lib/ai/embeddings.go +++ b/lib/ai/embeddings.go @@ -34,6 +34,9 @@ import ( streamutils "github.com/gravitational/teleport/lib/utils/stream" ) +// maxEmbeddingAPISize is the maximum number of entities that can be embedded in a single API call. +const maxEmbeddingAPISize = 1000 + // Embeddings implements the minimal interface used by the Embedding processor. type Embeddings interface { // GetEmbeddings returns all embeddings for a given kind. @@ -81,17 +84,18 @@ func EmbeddingHashMatches(embedding *Embedding, hash Sha256Hash) bool { return *(*Sha256Hash)(embedding.EmbeddedHash) == hash } -// serializeNode converts a type.Server into text ready to be fed to an +// SerializeNode converts a types.Server into text ready to be fed to an // embedding model. The YAML serialization function was chosen over JSON and // CSV as it provided better results. -func serializeNode(node types.Server) ([]byte, error) { +func SerializeNode(node types.Server) ([]byte, error) { a := struct { Name string `yaml:"name"` Kind string `yaml:"kind"` SubKind string `yaml:"subkind"` Labels map[string]string `yaml:"labels"` }{ - Name: node.GetName(), + // Create artificial Name field for the node "name". Using node.GetName() as Name seems to confuse the model. + Name: node.GetHostname(), Kind: types.KindNode, SubKind: node.GetSubKind(), Labels: node.GetAllLabels(), @@ -145,31 +149,34 @@ func (b *BatchReducer[T, V]) Finalize(ctx context.Context) (V, error) { // EmbeddingProcessorConfig is the configuration for EmbeddingProcessor. 
type EmbeddingProcessorConfig struct { - AIClient Embedder - EmbeddingSrv Embeddings - NodeSrv NodesStreamGetter - Log logrus.FieldLogger - Jitter retryutils.Jitter + AIClient Embedder + EmbeddingSrv Embeddings + EmbeddingsRetriever *SimpleRetriever + NodeSrv NodesStreamGetter + Log logrus.FieldLogger + Jitter retryutils.Jitter } // EmbeddingProcessor is responsible for processing nodes, generating embeddings -// and storing their the embeddings in the backend. +// and storing their embeddings in the backend. type EmbeddingProcessor struct { - aiClient Embedder - embeddingSrv Embeddings - nodeSrv NodesStreamGetter - log logrus.FieldLogger - jitter retryutils.Jitter + aiClient Embedder + embeddingSrv Embeddings + embeddingsRetriever *SimpleRetriever + nodeSrv NodesStreamGetter + log logrus.FieldLogger + jitter retryutils.Jitter } // NewEmbeddingProcessor returns a new EmbeddingProcessor. func NewEmbeddingProcessor(cfg *EmbeddingProcessorConfig) *EmbeddingProcessor { return &EmbeddingProcessor{ - aiClient: cfg.AIClient, - embeddingSrv: cfg.EmbeddingSrv, - nodeSrv: cfg.NodeSrv, - log: cfg.Log, - jitter: cfg.Jitter, + aiClient: cfg.AIClient, + embeddingSrv: cfg.EmbeddingSrv, + embeddingsRetriever: cfg.EmbeddingsRetriever, + nodeSrv: cfg.NodeSrv, + log: cfg.Log, + jitter: cfg.Jitter, } } @@ -205,11 +212,16 @@ func (e *EmbeddingProcessor) mapProcessFn(ctx context.Context, data []*nodeStrin } // Run runs the EmbeddingProcessor. -func (e *EmbeddingProcessor) Run(ctx context.Context, period time.Duration) error { +func (e *EmbeddingProcessor) Run(ctx context.Context, initialDelay, period time.Duration) error { + initTimer := time.NewTimer(initialDelay) for { select { case <-ctx.Done(): return ctx.Err() + case <-initTimer.C: + // Stop the timer after the initial delay. 
+ initTimer.Stop() + e.process(ctx) case <-time.After(e.jitter(period)): e.process(ctx) } @@ -218,9 +230,12 @@ func (e *EmbeddingProcessor) Run(ctx context.Context, period time.Duration) erro func (e *EmbeddingProcessor) process(ctx context.Context) { batch := NewBatchReducer[*nodeStringPair, []*Embedding](e.mapProcessFn, - 1000, // Max batch size allowed by OpenAI API, + maxEmbeddingAPISize, // Max batch size allowed by OpenAI API, ) + e.log.Debugf("embedding processor started") + defer e.log.Debugf("embedding processor finished") + embeddingsStream := e.embeddingSrv.GetEmbeddings(ctx, types.KindNode) nodesStream := e.nodeSrv.GetNodeStream(ctx, defaults.Namespace) @@ -229,7 +244,7 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { embeddingsStream, // On new node callback. Add the node to the batch. func(node types.Server) error { - nodeData, err := serializeNode(node) + nodeData, err := SerializeNode(node) if err != nil { return trace.Wrap(err) } @@ -246,7 +261,7 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { // On equal node callback. Check if the node's embedding hash matches // the one in the backend. If not, add the node to the batch. func(node types.Server, embedding *Embedding) error { - nodeData, err := serializeNode(node) + nodeData, err := SerializeNode(node) if err != nil { return trace.Wrap(err) } @@ -263,7 +278,7 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { } return nil }, - // On compare keys callback. Compare the keys for iterration. + // On compare keys callback. Compare the keys for iteration. 
func(node types.Server, embeddings *Embedding) int { if node.GetName() == embeddings.GetName() { return 0 @@ -286,7 +301,35 @@ func (e *EmbeddingProcessor) process(ctx context.Context) { if err := e.upsertEmbeddings(ctx, vectors); err != nil { e.log.Warnf("Failed to upsert embeddings: %v", err) + } + + if err := e.updateMemIndex(ctx); err != nil { + e.log.Warnf("Failed to update memory index: %v", err) + } +} + +// updateMemIndex is a helper function that updates the in-memory index with the +// latest embeddings. The new index is created and then swapped with the old one. +func (e *EmbeddingProcessor) updateMemIndex(ctx context.Context) error { + embeddingsIndex := NewSimpleRetriever() + embeddingsStream := e.embeddingSrv.GetEmbeddings(ctx, types.KindNode) + + for embeddingsStream.Next() { + embedding := embeddingsStream.Item() + if !embeddingsIndex.Insert(embedding.GetEmbeddedID(), embedding) { + e.log.Warnf("Embeddings index is full, some resources can be missing") + break + } + } + + if err := embeddingsStream.Done(); err != nil { + return trace.Wrap(err) + } + + e.embeddingsRetriever.Swap(embeddingsIndex) + + return nil } // upsertEmbeddings is a helper function that upserts the embeddings into the backend. diff --git a/lib/ai/embeddings_test.go b/lib/ai/embeddings_test.go index 2bd1e4b12a0af..c2c369ad05c22 100644 --- a/lib/ai/embeddings_test.go +++ b/lib/ai/embeddings_test.go @@ -65,7 +65,7 @@ func TestNodeEmbeddingGeneration(t *testing.T) { clock := clockwork.NewFakeClock() // Test setup: crate a backend, presence service, the node watcher and - // the embeddings service + // the embeddings service. 
bk, err := memory.New(memory.Config{ Context: ctx, Clock: clock, @@ -77,16 +77,17 @@ func TestNodeEmbeddingGeneration(t *testing.T) { embeddings := local.NewEmbeddingsService(bk) processor := ai.NewEmbeddingProcessor(&ai.EmbeddingProcessorConfig{ - AIClient: &embedder, - EmbeddingSrv: embeddings, - NodeSrv: presence, - Log: utils.NewLoggerForTests(), - Jitter: retryutils.NewSeventhJitter(), + AIClient: &embedder, + EmbeddingSrv: embeddings, + EmbeddingsRetriever: ai.NewSimpleRetriever(), + NodeSrv: presence, + Log: utils.NewLoggerForTests(), + Jitter: retryutils.NewSeventhJitter(), }) done := make(chan struct{}) go func() { - err := processor.Run(ctx, 100*time.Millisecond) + err := processor.Run(ctx, 100*time.Millisecond, time.Second) assert.ErrorIs(t, context.Canceled, err) close(done) }() diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go index f142f7f70c14e..b766a47b5ddf8 100644 --- a/lib/ai/model/agent.go +++ b/lib/ai/model/agent.go @@ -25,6 +25,8 @@ import ( "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" ) const ( @@ -34,11 +36,17 @@ const ( maxElapsedTime = 5 * time.Minute ) -// AssistAgent is a global instance of the Assist agent which defines the model responsible for the Assist feature. -var AssistAgent = &Agent{ - tools: []Tool{ - &commandExecutionTool{}, - }, +// NewAgent creates a new agent. The Assist agent which defines the model responsible for the Assist feature. +func NewAgent(assistClient assist.AssistEmbeddingServiceClient, username string) *Agent { + return &Agent{ + tools: []Tool{ + &commandExecutionTool{}, + &embeddingRetrievalTool{ + assistClient: assistClient, + currentUser: username, + }, + }, + } } // Agent is a model storing static state which defines some properties of the chat model. 
@@ -208,7 +216,11 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState) (stepOu return stepOutput{finish: &agentFinish{output: completion}}, nil } - return stepOutput{}, trace.NotImplemented("assist does not support non command execution tools yet") + runOut, err := tool.Run(ctx, action.input) + if err != nil { + return stepOutput{}, trace.Wrap(err) + } + return stepOutput{action: action, observation: runOut}, nil } func (a *Agent) plan(ctx context.Context, state *executionState) (*agentAction, *agentFinish, error) { diff --git a/lib/ai/model/tool.go b/lib/ai/model/tool.go index ab3c444cd2edb..93649f7cb6d7d 100644 --- a/lib/ai/model/tool.go +++ b/lib/ai/model/tool.go @@ -19,8 +19,13 @@ package model import ( "context" "fmt" + "strings" "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/api/types" ) // Tool is an interface that allows the agent to interact with the outside world. @@ -61,7 +66,7 @@ The input must be a JSON object with the following schema: `, "```", "```") } -func (c *commandExecutionTool) Run(ctx context.Context, input string) (string, error) { +func (c *commandExecutionTool) Run(_ context.Context, _ string) (string, error) { // This is stubbed because commandExecutionTool is handled specially. // This is because execution of this tool breaks the loop and returns a command suggestion to the user. // It is still handled as a tool because testing has shown that the LLM behaves better when it is treated as a tool. 
@@ -94,3 +99,80 @@ func (*commandExecutionTool) parseInput(input string) (*commandExecutionToolInpu return &output, nil } + +type embeddingRetrievalTool struct { + assistClient assist.AssistEmbeddingServiceClient + currentUser string +} + +type embeddingRetrievalToolInput struct { + Question string `json:"question"` +} + +func (e *embeddingRetrievalTool) Run(ctx context.Context, input string) (string, error) { + inputCmd, outErr := e.parseInput(input) + if outErr == nil { + // If we failed to parse the input, we can still send the payload for embedding retrieval. + // In most cases, we will still get some sensible results. + // If we parsed the input successfully, we should use the parsed input instead. + input = inputCmd.Question + } + log.Tracef("embedding retrieval input: %v", input) + + resp, err := e.assistClient.GetAssistantEmbeddings(ctx, &assist.GetAssistantEmbeddingsRequest{ + Username: e.currentUser, + Kind: types.KindNode, // currently only node embeddings are supported + Limit: 10, + Query: input, + }) + if err != nil { + return "", trace.Wrap(err) + } + + sb := strings.Builder{} + for _, embedding := range resp.Embeddings { + sb.WriteString(embedding.Content) + sb.WriteString("\n") + } + + log.Tracef("embedding retrieval: %v", sb.String()) + + if sb.Len() == 0 { + // Either no nodes are connected, embedding process hasn't started yet, or + // the user doesn't have access to any resources. + return "Didn't find any nodes matching your query", nil + } + + return sb.String(), nil +} + +func (e *embeddingRetrievalTool) Name() string { + return "Nodes names and labels retrieval" +} + +func (e *embeddingRetrievalTool) Description() string { + return fmt.Sprintf(`Ask about existing remote hosts to fetch node names or/and set of labels. Use this capability instead of guessing the names and labels. 
+The input must be a JSON object with the following schema: +%vjson +{ + "question": string \\ Question about the available remote hosts +} +%v +`, "```", "```") +} + +func (*embeddingRetrievalTool) parseInput(input string) (*embeddingRetrievalToolInput, *invalidOutputError) { + output, err := parseJSONFromModel[embeddingRetrievalToolInput](input) + if err != nil { + return nil, err + } + + if len(output.Question) == 0 { + return nil, &invalidOutputError{ + coarse: "embedding retrieval: missing question", + detail: "question must be non-empty", + } + } + + return &output, nil +} diff --git a/lib/ai/simpleretriever.go b/lib/ai/simpleretriever.go new file mode 100644 index 0000000000000..16745f82579fd --- /dev/null +++ b/lib/ai/simpleretriever.go @@ -0,0 +1,121 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai + +import ( + "sort" + "sync" +) + +// SimpleRetriever is a simple implementation of embeddings retriever. +// It stores all the embeddings in memory and retrieves the k nearest neighbors +// by iterating over all the embeddings. Do not use for large datasets. +type SimpleRetriever struct { + embeddings map[string]*Embedding + maxSize int + mtx sync.Mutex +} + +func NewSimpleRetriever() *SimpleRetriever { + return &SimpleRetriever{ + embeddings: make(map[string]*Embedding), + maxSize: 1_000, // keep the number low to avoid OOM + } +} + +// Insert adds the embedding to the retriever. 
If the retriever is full, the +// embedding is not added and false is returned. +func (r *SimpleRetriever) Insert(id string, embedding *Embedding) bool { + r.mtx.Lock() + defer r.mtx.Unlock() + if len(r.embeddings) >= r.maxSize { + return false + } + r.embeddings[id] = embedding + return true +} + +// Remove removes the embedding from the retriever by ID. +func (r *SimpleRetriever) Remove(id string) { + r.mtx.Lock() + defer r.mtx.Unlock() + delete(r.embeddings, id) +} + +// Swap replaces the embeddings in the retriever with the embeddings from the +// provided retriever. +// The mutex is acquired for the receiver, but not for the provided retriever. +func (r *SimpleRetriever) Swap(s *SimpleRetriever) { + r.mtx.Lock() + defer r.mtx.Unlock() + r.embeddings = s.embeddings + r.maxSize = s.maxSize +} + +// FilterFn is a function that filters out embeddings. +// If the function returns false, the embedding is filtered out. +type FilterFn func(id string, embedding *Embedding) bool + +// GetRelevant returns the k nearest neighbors to the query embedding. +// If a filter is provided, only the embeddings that pass the filter are considered. +func (r *SimpleRetriever) GetRelevant(query *Embedding, k int, filter FilterFn) []*Document { + // Replace with priority queue if k is large. 
+ results := make([]*Document, 0, k) + + r.mtx.Lock() + defer r.mtx.Unlock() + + // Find the k nearest neighbors + for id, embedding := range r.embeddings { + // Skip if the document is filtered out + if filter != nil && !filter(id, embedding) { + continue + } + + // Calculate the similarity score + similarity, _ := calculateSimilarity(query.Vector, embedding.Vector) + // If the results slice is smaller than k, add the element to the results + if len(results) < k { + results = append(results, &Document{ + Embedding: embedding, + SimilarityScore: similarity, + }) + + // Sort to preserve the invariant - the result slice is sorted by + // similarity score + sort.Slice(results, func(i, j int) bool { + return results[i].SimilarityScore > results[j].SimilarityScore + }) + } else if similarity > results[len(results)-1].SimilarityScore { + // If the element is more relevant than the least similar element, + // add it to the result slice + results[len(results)-1] = &Document{ + Embedding: embedding, + SimilarityScore: similarity, + } + + // Sort to preserve the invariant - the result slice is sorted by + // similarity score + sort.Slice(results, func(i, j int) bool { + return results[i].SimilarityScore > results[j].SimilarityScore + }) + } + } + + // Return the results sorted by similarity score. + return results +} diff --git a/lib/ai/simpleretriever_test.go b/lib/ai/simpleretriever_test.go new file mode 100644 index 0000000000000..37030da39f40a --- /dev/null +++ b/lib/ai/simpleretriever_test.go @@ -0,0 +1,80 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai + +import ( + "fmt" + "math/rand" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" +) + +func TestSimpleRetriever_GetRelevant(t *testing.T) { + t.Parallel() + + // Generate random vector. The seed is fixed, so the results are deterministic. + randGen := rand.New(rand.NewSource(42)) + + generateVector := func() Vector64 { + const testVectorDimension = 100 + // generate random vector + // reduce the dimensionality to 100 + vec := make(Vector64, testVectorDimension) + for i := 0; i < testVectorDimension; i++ { + vec[i] = randGen.Float64() + } + // normalize vector, so the similarity between two vectors is the dot product + // between [0, 1] + return normalize(vec) + } + + const testEmbeddingsSize = 100 + points := make([]*Embedding, testEmbeddingsSize) + for i := 0; i < testEmbeddingsSize; i++ { + points[i] = NewEmbedding(types.KindNode, strconv.Itoa(i), generateVector(), [32]byte{}) + } + + // Create a query. + query := NewEmbedding(types.KindNode, "1", generateVector(), [32]byte{}) + + retriever := NewSimpleRetriever() + + for _, point := range points { + retriever.Insert(point.GetName(), point) + } + + // Get the top 10 most similar documents. 
+ docs := retriever.GetRelevant(query, 10, func(id string, embedding *Embedding) bool { + return true + }) + require.Len(t, docs, 10) + + expectedResults := []int{57, 92, 95, 49, 33, 56, 30, 99, 90, 47} + expectedSimilarities := []float64{0.80405, 0.79051, 0.78161, 0.78159, + 0.77655, 0.77374, 0.77306, 0.76688, 0.76634, 0.76458} + + for i, result := range docs { + require.Equal(t, + fmt.Sprintf("%s/%s", types.KindNode, strconv.Itoa(expectedResults[i])), + result.GetName(), "expected order is wrong") + require.InDelta(t, expectedSimilarities[i], result.SimilarityScore, 10e-6, "similarity score is wrong") + } +} diff --git a/lib/assist/assist.go b/lib/assist/assist.go index 5783cddbc2fe7..e0c34d8a007cf 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -117,9 +117,10 @@ type Chat struct { // NewChat creates a new Assist chat. func (a *Assist) NewChat(ctx context.Context, assistService MessageService, - conversationID string, username string, + embeddingServiceClient assist.AssistEmbeddingServiceClient, + conversationID, username string, ) (*Chat, error) { - aichat := a.client.NewChat(username) + aichat := a.client.NewChat(embeddingServiceClient, username) chat := &Chat{ assist: a, @@ -165,6 +166,12 @@ func (a *Assist) GenerateCommandSummary(ctx context.Context, messages []*assist. return a.client.CommandSummary(ctx, modelMessages, output) } +// reloadMessages clears the chat history and reloads the messages from the database. +func (c *Chat) reloadMessages(ctx context.Context) error { + c.chat.Clear() + return c.loadMessages(ctx) +} + // loadMessages loads the messages from the database. func (c *Chat) loadMessages(ctx context.Context) error { // existing conversation, retrieve old messages @@ -242,9 +249,7 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use // If data might have been inserted into the chat history, we want to // refresh and get the latest data before querying the model. 
if c.potentiallyStaleHistory { - c.chat = c.assist.client.NewChat(c.Username) - err := c.loadMessages(ctx) - if err != nil { + if err := c.reloadMessages(ctx); err != nil { return nil, trace.Wrap(err) } } diff --git a/lib/assist/assist_test.go b/lib/assist/assist_test.go index cc2013f1af388..2ac2141cc3fa5 100644 --- a/lib/assist/assist_test.go +++ b/lib/assist/assist_test.go @@ -71,7 +71,7 @@ func TestChatComplete(t *testing.T) { require.NoError(t, err) // When a chat is created. - chat, err := client.NewChat(ctx, authSrv.AuthServer, conversationResp.Id, testUser) + chat, err := client.NewChat(ctx, authSrv.AuthServer, nil, conversationResp.Id, testUser) require.NoError(t, err) t.Run("new conversation is new", func(t *testing.T) { diff --git a/lib/auth/assist/assistv1/service.go b/lib/auth/assist/assistv1/service.go index 4d718894297c4..a46dda271bbe1 100644 --- a/lib/auth/assist/assistv1/service.go +++ b/lib/auth/assist/assistv1/service.go @@ -22,23 +22,47 @@ import ( "context" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/emptypb" + "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/ai" + "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/services" ) // ServiceConfig holds configuration options for // the assist gRPC service. type ServiceConfig struct { - Backend services.Assistant + Backend services.Assistant + Embeddings *ai.SimpleRetriever + Embedder ai.Embedder + Authorizer authz.Authorizer + Logger *logrus.Entry + ResourceGetter ResourceGetter +} + +// ResourceGetter represents a subset of the auth.Cache interface. +// Created to avoid circular dependencies. 
+type ResourceGetter interface { + GetNode(ctx context.Context, namespace, name string) (types.Server, error) } // Service implements the teleport.assist.v1.AssistService RPC service. type Service struct { assist.UnimplementedAssistServiceServer - - backend services.Assistant + assist.UnimplementedAssistEmbeddingServiceServer + + backend services.Assistant + embeddings *ai.SimpleRetriever + // embedder is used to embed text into a vector. + // It can be nil if the OpenAI API key is not set. + embedder ai.Embedder + authorizer authz.Authorizer + log *logrus.Entry + resourceGetter ResourceGetter } // NewService returns a new assist gRPC service. @@ -46,10 +70,24 @@ func NewService(cfg *ServiceConfig) (*Service, error) { switch { case cfg.Backend == nil: return nil, trace.BadParameter("backend is required") + case cfg.Embeddings == nil: + return nil, trace.BadParameter("embeddings is required") + case cfg.Authorizer == nil: + return nil, trace.BadParameter("authorizer is required") + case cfg.ResourceGetter == nil: + return nil, trace.BadParameter("resource getter is required") + case cfg.Logger == nil: + cfg.Logger = logrus.WithField(trace.Component, "assist.service") } + // Embedder can be nil if the OpenAI API key is not set. return &Service{ - backend: cfg.Backend, + backend: cfg.Backend, + embeddings: cfg.Embeddings, + embedder: cfg.Embedder, + authorizer: cfg.Authorizer, + resourceGetter: cfg.ResourceGetter, + log: cfg.Logger, }, nil } @@ -92,6 +130,67 @@ func (a *Service) CreateAssistantMessage(ctx context.Context, req *assist.Create } // IsAssistEnabled returns true if the assist is enabled or not on the auth level. 
-func (a *Service) IsAssistEnabled(ctx context.Context, req *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) { +func (a *Service) IsAssistEnabled(ctx context.Context, _ *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) { + if a.embedder == nil { + // If the embedder is not configured, the assist is not enabled as we cannot compute embeddings. + return &assist.IsAssistEnabledResponse{Enabled: false}, nil + } + + // Check if assist can use the backend. return a.backend.IsAssistEnabled(ctx) } + +func (a *Service) GetAssistantEmbeddings(ctx context.Context, msg *assist.GetAssistantEmbeddingsRequest) (*assist.GetAssistantEmbeddingsResponse, error) { + // TODO(jakule): The kind needs to be updated when we add more resources. + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindNode, types.VerbRead, types.VerbList) + if err != nil { + return nil, trace.Wrap(err) + } + + if a.embedder == nil { + return nil, trace.BadParameter("assist is not configured in auth server") + } + + // Call the openAI API to get the embeddings for the query. + embeddings, err := a.embedder.ComputeEmbeddings(ctx, []string{msg.Query}) + if err != nil { + return nil, trace.Wrap(err) + } + if len(embeddings) == 0 { + return nil, trace.NotFound("OpenAI embeddings returned no results") + } + + // Use default values for the id and content, as we only care about the embeddings. + queryEmbeddings := ai.NewEmbedding(msg.Kind, "", embeddings[0], [32]byte{}) + documents := a.embeddings.GetRelevant(queryEmbeddings, int(msg.Limit), func(id string, embedding *ai.Embedding) bool { + // Run RBAC check on the embedded resource. 
+ node, err := a.resourceGetter.GetNode(ctx, defaults.Namespace, embedding.GetEmbeddedID()) + if err != nil { + a.log.Tracef("failed to get node %q: %v", embedding.GetName(), err) + return false + } + return authCtx.Checker.CheckAccess(node, services.AccessState{MFAVerified: true}) == nil + }) + + protoDocs := make([]*assist.EmbeddedDocument, 0, len(documents)) + for _, doc := range documents { + node, err := a.resourceGetter.GetNode(ctx, defaults.Namespace, doc.GetEmbeddedID()) + if err != nil { + return nil, trace.Wrap(err) + } + + content, err := ai.SerializeNode(node) + if err != nil { + return nil, trace.Wrap(err) + } + protoDocs = append(protoDocs, &assist.EmbeddedDocument{ + Id: doc.GetEmbeddedID(), + Content: string(content), + SimilarityScore: float32(doc.SimilarityScore), + }) + } + + return &assist.GetAssistantEmbeddingsResponse{ + Embeddings: protoDocs, + }, nil +} diff --git a/lib/auth/auth.go b/lib/auth/auth.go index fe8cb8957e4e9..8ce1cd2b541c6 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -65,6 +65,7 @@ import ( "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/api/utils/retryutils" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/ai" "github.com/gravitational/teleport/lib/auth/keystore" "github.com/gravitational/teleport/lib/auth/native" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" @@ -310,6 +311,8 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { fips: cfg.FIPS, loadAllCAs: cfg.LoadAllCAs, httpClientForAWSSTS: cfg.HTTPClientForAWSSTS, + embeddingsRetriever: cfg.EmbeddingRetriever, + embedder: cfg.EmbeddingClient, } as.inventory = inventory.NewController(&as, services, inventory.WithAuthServerID(cfg.HostUUID)) for _, o := range opts { @@ -621,6 +624,12 @@ type Server struct { // httpClientForAWSSTS overwrites the default HTTP client used for making // STS requests. 
httpClientForAWSSTS utils.HTTPDoClient + + // embeddingsRetriever is a retriever used to retrieve embeddings from the backend. + embeddingsRetriever *ai.SimpleRetriever + + // embedder is an embedder client used to generate embeddings. + embedder ai.Embedder } // SetSAMLService registers svc as the SAMLService that provides the SAML diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 1f9dd51218255..ab9c4be56e87e 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -318,6 +318,15 @@ func (a *ServerWithRoles) PluginsClient() pluginspb.PluginServiceClient { ) } +// EmbeddingClient allows ServerWithRoles to implement ClientI. +// It should not be called through ServerWithRoles, +// as it returns a dummy client that will always respond with "not implemented". +func (a *ServerWithRoles) EmbeddingClient() assist.AssistEmbeddingServiceClient { + return assist.NewAssistEmbeddingServiceClient( + utils.NewGRPCDummyClientConnection("EmbeddingClient() should not be called on ServerWithRoles"), + ) +} + // SAMLIdPClient allows ServerWithRoles to implement ClientI. // It should not be called through ServerWithRoles, // as it returns a dummy client that will always respond with "not implemented". 
diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 487151fa5e811..833c4827b69e7 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" + assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" loginrulepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/loginrule/v1" pluginspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" @@ -733,6 +734,9 @@ type ClientI interface { // "not implemented" errors (as per the default gRPC behavior). LoginRuleClient() loginrulepb.LoginRuleServiceClient + // EmbeddingClient returns a client to the Embedding gRPC service. + EmbeddingClient() assistpb.AssistEmbeddingServiceClient + // NewKeepAliver returns a new instance of keep aliver NewKeepAliver(ctx context.Context) (types.KeepAliver, error) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 3a70984289a1a..2c19ee5919a48 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5359,12 +5359,17 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { // Initialize and register the assist service. 
assistSrv, err := assistv1.NewService(&assistv1.ServiceConfig{ - Backend: cfg.AuthServer.Services, + Backend: cfg.AuthServer.Services, + Embeddings: cfg.AuthServer.embeddingsRetriever, + Embedder: cfg.AuthServer.embedder, + Authorizer: cfg.Authorizer, + ResourceGetter: cfg.AuthServer, }) if err != nil { return nil, trace.Wrap(err) } assist.RegisterAssistServiceServer(server, assistSrv) + assist.RegisterAssistEmbeddingServiceServer(server, assistSrv) // create server with no-op role to pass to JoinService server serverWithNopRole, err := serverWithNopRole(cfg) diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index eb326bab63b8e..f71c1fdd56c3c 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -35,6 +35,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/ai" "github.com/gravitational/teleport/lib/auth/keystore" "github.com/gravitational/teleport/lib/auth/native" authority "github.com/gravitational/teleport/lib/auth/testauthority" @@ -76,6 +77,8 @@ type TestAuthServerConfig struct { TraceClient otlptrace.Client // AuthPreferenceSpec is custom initial AuthPreference spec for the test. AuthPreferenceSpec *types.AuthPreferenceSpecV2 + // Embedder is required to enable the assist in the auth server. + Embedder ai.Embedder } // CheckAndSetDefaults checks and sets defaults @@ -98,6 +101,9 @@ func (cfg *TestAuthServerConfig) CheckAndSetDefaults() error { SecondFactor: constants.SecondFactorOff, } } + if cfg.Embedder == nil { + cfg.Embedder = &noopEmbedder{} + } return nil } @@ -187,6 +193,14 @@ func WithClock(clock clockwork.Clock) ServerOption { } } +// WithEmbedder is a functional server option that sets the server's embedder. 
+func WithEmbedder(embedder ai.Embedder) ServerOption { + return func(s *Server) error { + s.embedder = embedder + return nil + } +} + // TestAuthServer is auth server using local filesystem backend // and test certificate authority key generation that speeds up // keygen by using the same private key @@ -268,7 +282,11 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { RSAKeyPairSource: authority.New().GenerateKeyPair, }, }, - }, WithClock(cfg.Clock)) + EmbeddingRetriever: ai.NewSimpleRetriever(), + }, + WithClock(cfg.Clock), + WithEmbedder(cfg.Embedder), + ) if err != nil { return nil, trace.Wrap(err) } @@ -1163,3 +1181,10 @@ func CreateUserAndRoleWithoutRoles(clt clt, username string, allowedLogins []str return user, role, nil } + +// noopEmbedder is a no op implementation of the Embedder interface. +type noopEmbedder struct{} + +func (n noopEmbedder) ComputeEmbeddings(_ context.Context, _ []string) ([]ai.Vector64, error) { + return []ai.Vector64{}, nil +} diff --git a/lib/auth/init.go b/lib/auth/init.go index 8886670715639..0d3059f2f5adf 100644 --- a/lib/auth/init.go +++ b/lib/auth/init.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/api/utils/keys" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib" + "github.com/gravitational/teleport/lib/ai" "github.com/gravitational/teleport/lib/auth/keystore" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/backend" @@ -232,6 +233,12 @@ type InitConfig struct { // HTTPClientForAWSSTS overwrites the default HTTP client used for making // STS requests. Used in test. HTTPClientForAWSSTS utils.HTTPDoClient + + // EmbeddingRetriever is a retriever for embeddings. + EmbeddingRetriever *ai.SimpleRetriever + + // EmbeddingClient is a client that allows generating embeddings. 
+ EmbeddingClient ai.Embedder } // Init instantiates and configures an instance of AuthServer diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index 5baba26027ef4..ce9b58b12598f 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -625,7 +625,7 @@ func roleSpecForProxy(clusterName string) types.RoleSpecV6 { } } -// RoleSetForBuiltinRole returns RoleSet for embedded builtin role +// RoleSetForBuiltinRoles returns RoleSet for embedded builtin role func RoleSetForBuiltinRoles(clusterName string, recConfig types.SessionRecordingConfig, roles ...types.SystemRole) (services.RoleSet, error) { var definitions []types.Role for _, role := range roles { diff --git a/lib/service/service.go b/lib/service/service.go index ff684ffd0df87..b0fd894112ba7 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -258,6 +258,15 @@ const ( TeleportOKEvent = "TeleportOKEvent" ) +const ( + // embeddingInitialDelay is the time to wait before the first embedding + // routine is started. + embeddingInitialDelay = 10 * time.Second + // embeddingPeriod is the time between two embedding routines. + // A seventh jitter is applied on the period. + embeddingPeriod = time.Hour +) + // Connector has all resources process needs to connect to other parts of the // cluster: client and identity. 
type Connector struct { @@ -1619,6 +1628,13 @@ func (process *TeleportProcess) initAuthService() error { traceClt = clt } + var embedderClient ai.Embedder + if cfg.Auth.AssistAPIKey != "" { + embedderClient = ai.NewClient(cfg.Auth.AssistAPIKey) + } + + embeddingsRetriever := ai.NewSimpleRetriever() + // first, create the AuthServer authServer, err := auth.Init(auth.InitConfig{ Backend: b, @@ -1657,6 +1673,8 @@ func (process *TeleportProcess) initAuthService() error { LoadAllCAs: cfg.Auth.LoadAllCAs, Clock: cfg.Clock, HTTPClientForAWSSTS: cfg.Auth.HTTPClientForAWSSTS, + EmbeddingRetriever: embeddingsRetriever, + EmbeddingClient: embedderClient, }, func(as *auth.Server) error { if !process.Config.CachePolicy.Enabled { return nil @@ -1696,14 +1714,15 @@ func (process *TeleportProcess) initAuthService() error { } authServer.SetLockWatcher(lockWatcher) - if cfg.Auth.AssistAPIKey != "" { - openAIClient := ai.NewClient(cfg.Auth.AssistAPIKey) + if embedderClient != nil { + log.Debugf("Starting embedding watcher") embeddingProcessor := ai.NewEmbeddingProcessor(&ai.EmbeddingProcessorConfig{ - AIClient: openAIClient, - EmbeddingSrv: authServer, - NodeSrv: authServer, - Log: log, - Jitter: retryutils.NewFullJitter(), + AIClient: embedderClient, + EmbeddingsRetriever: embeddingsRetriever, + EmbeddingSrv: authServer, + NodeSrv: authServer, + Log: log, + Jitter: retryutils.NewFullJitter(), }) process.RegisterFunc("ai.embedding-processor", func() error { @@ -1725,7 +1744,7 @@ func (process *TeleportProcess) initAuthService() error { return nil } log.Debugf("Starting embedding processor") - return embeddingProcessor.Run(process.ExitContext(), ai.EmbeddingPeriod) + return embeddingProcessor.Run(process.ExitContext(), embeddingInitialDelay, embeddingPeriod) }) } diff --git a/lib/services/embeddings.go b/lib/services/embeddings.go index 2143b4195e5db..1149213bf69cb 100644 --- a/lib/services/embeddings.go +++ b/lib/services/embeddings.go @@ -31,6 +31,6 @@ type Embeddings interface { 
GetEmbedding(ctx context.Context, kind, resourceID string) (*ai.Embedding, error) // GetEmbeddings returns all embeddings for a given kind. GetEmbeddings(ctx context.Context, kind string) stream.Stream[*ai.Embedding] - // UpsertEmbedding creates or update a single ai.Embedding in the backend. + // UpsertEmbedding creates or updates a single ai.Embedding in the backend. UpsertEmbedding(ctx context.Context, embedding *ai.Embedding) (*ai.Embedding, error) } diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 25abe37fb42ae..2001d9caa6de5 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -416,7 +416,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - chat, err := assistClient.NewChat(ctx, authClient, conversationID, sctx.GetUser()) + chat, err := assistClient.NewChat(ctx, authClient, authClient.EmbeddingClient(), conversationID, sctx.GetUser()) if err != nil { return trace.Wrap(err) }