Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,12 @@ func (c *Client) TrustClient() trustpb.TrustServiceClient {
return trustpb.NewTrustServiceClient(c.conn)
}

// EmbeddingClient returns a plain Assist Embedding gRPC client that reuses
// the Client's underlying Auth connection.
func (c *Client) EmbeddingClient() assist.AssistEmbeddingServiceClient {
	conn := c.conn
	return assist.NewAssistEmbeddingServiceClient(conn)
}

// Ping gets basic info about the auth server.
func (c *Client) Ping(ctx context.Context) (proto.PingResponse, error) {
rsp, err := c.grpc.Ping(ctx, &proto.PingRequest{})
Expand Down Expand Up @@ -3966,3 +3972,11 @@ func (c *Client) UpdateAssistantConversationInfo(ctx context.Context, in *assist
}
return nil
}

// GetAssistantEmbeddings retrieves assistant embeddings for the given request
// by calling the Assist Embedding service over the underlying Auth gRPC
// connection. Any gRPC status error is converted to a trace error via
// trail.FromGRPC before being returned to the caller.
func (c *Client) GetAssistantEmbeddings(ctx context.Context, in *assist.GetAssistantEmbeddingsRequest) (*assist.GetAssistantEmbeddingsResponse, error) {
	result, err := c.EmbeddingClient().GetAssistantEmbeddings(ctx, in)
	if err != nil {
		return nil, trail.FromGRPC(err)
	}
	return result, nil
}
431 changes: 348 additions & 83 deletions api/gen/proto/go/assist/v1/assist.pb.go

Large diffs are not rendered by default.

93 changes: 93 additions & 0 deletions api/gen/proto/go/assist/v1/assist_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions api/proto/teleport/assist/v1/assist.proto
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,35 @@ message DeleteAssistantConversationRequest {
string username = 2;
}

// GetAssistantEmbeddingsRequest is a request to get embeddings.
message GetAssistantEmbeddingsRequest {
// username is a username of the user who requested the embeddings.
string username = 1;
// query is the query used for similarity search.
string query = 2;
// limit is the number of embeddings to return (also known as k).
uint32 limit = 3;
// kind is the kind of embeddings to return (e.g., node).
string kind = 4;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be a oneof? stringly typed enumerations are rather fragile/unclear and this doesn't describe all possible/valid states either.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why oneof and not enum then? My idea was to keep it flexible, but after adding more types, we will probably need to re-route them to different sources, so we will need to rely on that type anyways.

}

// EmbeddedDocument is a document with an embedding.
message EmbeddedDocument {
  // id is the id of the document.
  string id = 1;
  // content is the content of the document.
  string content = 2;
  // similarity_score is the similarity score of the document.
  float similarity_score = 3;
}

// GetAssistantEmbeddingsResponse is a response from the assistant service.
message GetAssistantEmbeddingsResponse {
// embeddings is the list of embeddings.
// The list is sorted by similarity score in descending order.
repeated EmbeddedDocument embeddings = 1;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we specify these are ordered?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

}

// AssistService is a service that provides an ability to communicate with the Teleport Assist.
service AssistService {
// CreateNewConversation creates a new conversation and returns the UUID of it.
Expand All @@ -144,3 +173,9 @@ service AssistService {
// IsAssistEnabled returns true if the assist is enabled or not on the auth level.
rpc IsAssistEnabled(IsAssistEnabledRequest) returns (IsAssistEnabledResponse);
}

// AssistEmbeddingService is a service that provides an ability to communicate with the Assist Embedding service.
service AssistEmbeddingService {
  // GetAssistantEmbeddings returns the embeddings for the given query.
  rpc GetAssistantEmbeddings(GetAssistantEmbeddingsRequest) returns (GetAssistantEmbeddingsResponse);
}
8 changes: 7 additions & 1 deletion lib/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Chat struct {
client *Client
messages []openai.ChatCompletionMessage
tokenizer tokenizer.Codec
agent *model.Agent
}

// Insert inserts a message into the conversation. Returns the index of the message.
Expand Down Expand Up @@ -70,10 +71,15 @@ func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) {
Content: userInput,
}

response, err := model.AssistAgent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage)
response, err := chat.agent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage)
if err != nil {
return nil, trace.Wrap(err)
}

return response, nil
}

// Clear drops every message accumulated so far, restarting the conversation
// with an empty (non-nil) history.
func (chat *Chat) Clear() {
	chat.messages = make([]openai.ChatCompletionMessage, 0)
}
10 changes: 5 additions & 5 deletions lib/ai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Hello",
},
},
want: 632,
want: 703,
},
{
name: "system and user messages",
Expand All @@ -63,7 +63,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Hi LLM.",
},
},
want: 640,
want: 711,
},
{
name: "tokenize our prompt",
Expand All @@ -77,7 +77,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Show me free disk space on localhost node.",
},
},
want: 843,
want: 914,
},
}

Expand All @@ -96,7 +96,7 @@ func TestChat_PromptTokens(t *testing.T) {
cfg.BaseURL = server.URL + "/v1"

client := NewClientFromConfig(cfg)
chat := client.NewChat("Bob")
chat := client.NewChat(nil, "Bob")

for _, message := range tt.messages {
chat.Insert(message.Role, message.Content)
Expand Down Expand Up @@ -128,7 +128,7 @@ func TestChat_Complete(t *testing.T) {
cfg.BaseURL = server.URL + "/v1"
client := NewClientFromConfig(cfg)

chat := client.NewChat("Bob")
chat := client.NewChat(nil, "Bob")

t.Run("initial message", func(t *testing.T) {
msgAny, err := chat.Complete(context.Background(), "Hello")
Expand Down
5 changes: 4 additions & 1 deletion lib/ai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/sashabaranov/go-openai"
"github.com/tiktoken-go/tokenizer/codec"

"github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
"github.com/gravitational/teleport/lib/ai/model"
)

Expand All @@ -43,7 +44,8 @@ func NewClientFromConfig(config openai.ClientConfig) *Client {

// NewChat creates a new chat. The username is set in the conversation context,
// so that the AI can use it to personalize the conversation.
func (client *Client) NewChat(username string) *Chat {
// embeddingServiceClient is used to get the embeddings from the Auth Server.
func (client *Client) NewChat(embeddingServiceClient assist.AssistEmbeddingServiceClient, username string) *Chat {
return &Chat{
client: client,
messages: []openai.ChatCompletionMessage{
Expand All @@ -55,6 +57,7 @@ func (client *Client) NewChat(username string) *Chat {
// Initialize a tokenizer for prompt token accounting.
// Cl100k is used by GPT-3 and GPT-4.
tokenizer: codec.NewCl100kBase(),
agent: model.NewAgent(embeddingServiceClient, username),
}
}

Expand Down
8 changes: 2 additions & 6 deletions lib/ai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package ai
import (
"context"
"crypto/sha256"
"time"

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
Expand All @@ -28,9 +27,6 @@ import (

const (
maxOpenAIEmbeddingsPerRequest = 1000
// EmbeddingPeriod is the time between two embedding routines.
// A seventh jitter is applied on the period.
EmbeddingPeriod = time.Hour
)

// EmbeddingHash is the hash function that should be used to compute embedding
Expand Down Expand Up @@ -92,11 +88,11 @@ func NewEmbedding(kind, id string, vector Vector64, hash Sha256Hash) *Embedding
}

// Embedder is implemented for batch text embedding. Embedding can happen in
// place (with an embedding model for example) or be done by a remote embedding
// place (with an embedding model, for example) or be done by a remote embedding
// service like OpenAI.
type Embedder interface {
// ComputeEmbeddings computes the embeddings of multiple strings.
// The embedding list follows the input order (e.g. result[i] is the
// The embedding list follows the input order (e.g., result[i] is the
// embedding of input[i]).
ComputeEmbeddings(ctx context.Context, input []string) ([]Vector64, error)
}
Expand Down
Loading