Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions lib/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/tiktoken-go/tokenizer"

"github.com/gravitational/teleport/lib/ai/model"
"github.com/gravitational/teleport/lib/ai/model/output"
Expand All @@ -30,10 +29,9 @@ import (

// Chat represents a conversation between a user and an assistant with context memory.
type Chat struct {
client *Client
messages []openai.ChatCompletionMessage
tokenizer tokenizer.Codec
agent *model.Agent
client *Client
messages []openai.ChatCompletionMessage
agent *model.Agent
}

// Insert inserts a message into the conversation. Returns the index of the message.
Expand Down
8 changes: 4 additions & 4 deletions lib/ai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,13 @@ func generateAccessRequestResponse(t *testing.T) string {
func TestChat_Complete_AuditQuery(t *testing.T) {
// Test setup: generate the responses that will be served by our OpenAI mock
action := model.PlanOutput{
Action: "Audit Query Generation",
Action: tools.AuditQueryGenerationToolName,
ActionInput: "Lists user who connected to a server as root.",
Reasoning: "foo",
}
selectedAction, err := json.Marshal(action)
require.NoError(t, err)
generatedQuery := "SELECT user FROM session_start WHERE login='root'"
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"

responses := []string{
// The model must select the audit query tool
Expand Down Expand Up @@ -356,7 +356,7 @@ func TestChat_Complete_AuditQuery(t *testing.T) {
require.NoError(t, err)

// We check that the agent returns the expected response
message, ok := result.(*output.Message)
message, ok := result.(*output.StreamingMessage)
require.True(t, ok)
require.Equal(t, generatedQuery, message.Content)
require.Equal(t, generatedQuery, message.WaitAndConsume())
}
35 changes: 26 additions & 9 deletions lib/ai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/tiktoken-go/tokenizer/codec"

"github.com/gravitational/teleport/lib/ai/embedding"
"github.com/gravitational/teleport/lib/ai/model"
Expand Down Expand Up @@ -75,10 +74,7 @@ func (client *Client) NewChat(toolContext *modeltools.ToolContext) *Chat {
Content: model.PromptCharacter(toolContext.User),
},
},
// Initialize a tokenizer for prompt token accounting.
// Cl100k is used by GPT-3 and GPT-4.
tokenizer: codec.NewCl100kBase(),
agent: model.NewAgent(toolContext, tools...),
agent: model.NewAgent(toolContext, tools...),
}
}

Expand All @@ -92,13 +88,34 @@ func (client *Client) NewCommand(username string) *Chat {
Content: model.PromptCharacter(username),
},
},
// Initialize a tokenizer for prompt token accounting.
// Cl100k is used by GPT-3 and GPT-4.
tokenizer: codec.NewCl100kBase(),
agent: model.NewAgent(toolContext, &modeltools.CommandGenerationTool{}),
agent: model.NewAgent(toolContext, &modeltools.CommandGenerationTool{}),
}
}

func (client *Client) RunTool(ctx context.Context, toolContext *modeltools.ToolContext, toolName, toolInput string) (any, *tokens.TokenCount, error) {
tools := []modeltools.Tool{
&modeltools.CommandExecutionTool{},
&modeltools.EmbeddingRetrievalTool{},
&modeltools.AuditQueryGenerationTool{LLM: client.svc},
}
// The following tools are only available in the enterprise build. They will fail
// if included in OSS due to the lack of the required backend APIs.
if modules.GetModules().BuildType() == modules.BuildEnterprise {
tools = append(tools, &modeltools.AccessRequestCreateTool{},
&modeltools.AccessRequestsListTool{},
&modeltools.AccessRequestListRequestableRolesTool{},
&modeltools.AccessRequestListRequestableResourcesTool{})
}
agent := model.NewAgent(toolContext, tools...)
action := &model.AgentAction{
Action: toolName,
Input: toolInput,
Reasoning: "Tool invoked directly",
}

return agent.DoAction(ctx, client.svc, action)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We should wrap the error. I'm not sure what the other returns are from the context here.

Suggested change
return agent.DoAction(ctx, client.svc, action)
?, ?, err = agent.DoAction(ctx, client.svc, action)
return ?, ?, trace.Wrap(err)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should, but I also don't think we should do that in a backport. We can address those comments in a separate PR targeting master.

}

func (client *Client) NewAuditQuery(username string) *Chat {
toolContext := &modeltools.ToolContext{User: username}
return &Chat{
Expand Down
110 changes: 110 additions & 0 deletions lib/ai/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai

import (
"context"
"encoding/json"
"net/http/httptest"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"

assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
"github.com/gravitational/teleport/lib/ai/model/output"
"github.com/gravitational/teleport/lib/ai/model/tools"
"github.com/gravitational/teleport/lib/ai/testutils"
)

func TestRunTool_AuditQueryGeneration(t *testing.T) {
// Test setup: starting a mock openai server and creating the client
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"

responses := []string{
// Then the audit query tool chooses to request session.start events
"session.start",
// Finally the tool builds a query based on the provided schemas
generatedQuery,
}
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL

client := NewClientFromConfig(cfg)

// Doing the test: Check that the AuditQueryGeneration tool can be invoked
// through client.RunTool and validate its response.
ctx := context.Background()
toolCtx := &tools.ToolContext{User: "alice"}
response, _, err := client.RunTool(ctx, toolCtx, tools.AuditQueryGenerationToolName, "List users who connected to a server as root")
require.NoError(t, err)
message, ok := response.(*output.StreamingMessage)
require.True(t, ok)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: add more context so that we get more than "should be true" if the assertion fails

Suggested change
require.True(t, ok)
require.True(t, ok, "expected response to be an output.StreamingMessage, got %T, response)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could change, but I also don't think we should do that in a backport. We can address those comments in a separate PR targeting master.

require.Equal(t, generatedQuery, message.WaitAndConsume())
}

// mockEmbeddingGetter is a test double for the assist embedding service
// client; it serves a canned list of embedded documents.
type mockEmbeddingGetter struct {
	// response is returned verbatim by every GetAssistantEmbeddings call.
	response []*assistpb.EmbeddedDocument
}

// GetAssistantEmbeddings returns the mock's canned embeddings and never fails.
// The request and call options are ignored.
func (m *mockEmbeddingGetter) GetAssistantEmbeddings(ctx context.Context, in *assistpb.GetAssistantEmbeddingsRequest, opts ...grpc.CallOption) (*assistpb.GetAssistantEmbeddingsResponse, error) {
	return &assistpb.GetAssistantEmbeddingsResponse{Embeddings: m.response}, nil
}

func TestRunTool_EmbeddingRetrieval(t *testing.T) {
// Test setup: starting a mock openai server and embedding getter,
// then create the client.
mock := &mockEmbeddingGetter{
[]*assistpb.EmbeddedDocument{
{
Id: "1",
Content: "foo",
SimilarityScore: 1,
},
{
Id: "2",
Content: "bar",
SimilarityScore: 0.9,
},
},
}
ctx := context.Background()
toolCtx := &tools.ToolContext{AssistEmbeddingServiceClient: mock}

responses := make([]string, 0)
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL
client := NewClientFromConfig(cfg)

// Doing the test: Check that the EmbeddingRetrieval tool can be invoked
// through client.RunTool and validate its response.
input := tools.EmbeddingRetrievalToolInput{Question: "Find foobar"}
inputText, err := json.Marshal(input)
require.NoError(t, err)
response, _, err := client.RunTool(ctx, toolCtx, "Nodes names and labels retrieval", string(inputText))
require.NoError(t, err)
message, ok := response.(*output.Message)
require.True(t, ok)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: add more context so that we get more than "should be true" if the assertion fails

Suggested change
require.True(t, ok)
require.True(t, ok, "expected response to be an output.StreamingMessage, got %T, response

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could change, but I also don't think we should do that in a backport. We can address those comments in a separate PR targeting master.

require.Equal(t, "foo\nbar\n", message.Content)
}
70 changes: 31 additions & 39 deletions lib/ai/model/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ package model
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"time"

Expand Down Expand Up @@ -135,6 +133,28 @@ func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHist
}
}

// DoAction executes a single agent action directly, without running the
// planning loop. It returns the action's result and the token usage counted
// during execution. An error is returned if the underlying tool fails or if
// the action neither finished execution nor produced an observation.
func (a *Agent) DoAction(ctx context.Context, llm *openai.Client, action *AgentAction) (any, *tokens.TokenCount, error) {
	state := &executionState{
		llm:        llm,
		tokenCount: tokens.NewTokenCount(),
	}
	out, err := a.doAction(ctx, state, action)
	if err != nil {
		return nil, nil, trace.Wrap(err)
	}
	switch {
	case out.finish != nil:
		// If the tool already breaks execution, we don't have to do anything
		return out.finish.output, state.tokenCount, nil
	case out.observation != "":
		// If the tool doesn't break execution and returns a single observation,
		// we wrap the observation in a Message.
		return &output.Message{Content: out.observation}, state.tokenCount, nil
	default:
		// A tool must either end execution or observe something; anything
		// else is a tool misbehavior.
		return nil, state.tokenCount, trace.Errorf("action %s did not end execution nor returned an observation", action.Action)
	}
}

// stepOutput represents the inputs and outputs of a single thought step.
type stepOutput struct {
// if the agent is done, finish is set.
Expand Down Expand Up @@ -191,6 +211,10 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
// If action is set, the agent is not done and called upon a tool.
progressUpdates(action)

return a.doAction(ctx, state, action)
}

func (a *Agent) doAction(ctx context.Context, state *executionState, action *AgentAction) (stepOutput, error) {
var tool tools.Tool
for _, candidate := range a.tools {
if candidate.Name() == action.Action {
Expand Down Expand Up @@ -266,14 +290,12 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
}

log.Tracef("Tool chose to query table '%s'", tableName)
query, err := tool.GenerateQuery(ctx, tableName, action.Input, state.tokenCount)
response, err := tool.GenerateQuery(ctx, tableName, action.Input, state.tokenCount)
if err != nil {
return stepOutput{}, trace.Wrap(err)
}
log.Tracef("Tool generated query: %s", query)

completion := &output.Message{Content: query}
return stepOutput{finish: &agentFinish{output: completion}}, nil
return stepOutput{finish: &agentFinish{output: response}}, nil
default:
runOut, err := tool.Run(ctx, a.toolCtx, action.Input)
if err != nil {
Expand Down Expand Up @@ -305,23 +327,7 @@ func (a *Agent) plan(ctx context.Context, state *executionState) (*AgentAction,
return nil, nil, trace.Wrap(err)
}

deltas := make(chan string)
go func() {
defer close(deltas)

for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
return
} else if err != nil {
log.Tracef("agent encountered an error while streaming: %v", err)
return
}

delta := response.Choices[0].Delta.Content
deltas <- delta
}
}()
deltas := output.StreamToDeltas(stream)

action, finish, completionTokenCounter, err := parsePlanningOutput(deltas)
state.tokenCount.AddCompletionCounter(completionTokenCounter)
Expand Down Expand Up @@ -391,25 +397,11 @@ func parsePlanningOutput(deltas <-chan string) (*AgentAction, *agentFinish, toke
text += delta

if strings.HasPrefix(text, finalResponseHeader) {
parts := make(chan string)
streamingTokenCounter, err := tokens.NewAsynchronousTokenCounter(text)
message, tc, err := output.NewStreamingMessage(deltas, text, finalResponseHeader)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}
go func() {
defer close(parts)

parts <- strings.TrimPrefix(text, finalResponseHeader)
for delta := range deltas {
parts <- delta
errCount := streamingTokenCounter.Add()
if errCount != nil {
log.WithError(errCount).Debug("Failed to add streamed completion text to the token counter")
}
}
}()

return nil, &agentFinish{output: &output.StreamingMessage{Parts: parts}}, streamingTokenCounter, nil
return nil, &agentFinish{output: message}, tc, nil
}
}

Expand Down
12 changes: 12 additions & 0 deletions lib/ai/model/output/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package output

import "strings"

// Message represents a new message within a live conversation.
type Message struct {
Content string
Expand All @@ -26,6 +28,16 @@ type StreamingMessage struct {
Parts <-chan string
}

// WaitAndConsume blocks until the message stream is over and returns the
// full concatenated message text. It drains the Parts channel, so it can
// only be called once per message.
func (msg *StreamingMessage) WaitAndConsume() string {
	var collected []string
	for chunk := range msg.Parts {
		collected = append(collected, chunk)
	}
	return strings.Join(collected, "")
}

// Label represents a label returned by OpenAI's completion API.
type Label struct {
Key string `json:"key"`
Expand Down
Loading