Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/ai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,13 @@ func generateAccessRequestResponse(t *testing.T) string {
func TestChat_Complete_AuditQuery(t *testing.T) {
// Test setup: generate the responses that will be served by our OpenAI mock
action := model.PlanOutput{
Action: "Audit Query Generation",
Action: tools.AuditQueryGenerationToolName,
ActionInput: "Lists user who connected to a server as root.",
Reasoning: "foo",
}
selectedAction, err := json.Marshal(action)
require.NoError(t, err)
generatedQuery := "SELECT user FROM session_start WHERE login='root'"
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"

responses := []string{
// The model must select the audit query tool
Expand Down
24 changes: 24 additions & 0 deletions lib/ai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,30 @@ func (client *Client) NewCommand(username string) *Chat {
}
}

func (client *Client) RunTool(ctx context.Context, toolContext *modeltools.ToolContext, toolName, toolInput string) (any, *tokens.TokenCount, error) {
tools := []modeltools.Tool{
&modeltools.CommandExecutionTool{},
&modeltools.EmbeddingRetrievalTool{},
&modeltools.AuditQueryGenerationTool{LLM: client.svc},
}
// The following tools are only available in the enterprise build. They will fail
// if included in OSS due to the lack of the required backend APIs.
if modules.GetModules().BuildType() == modules.BuildEnterprise {
tools = append(tools, &modeltools.AccessRequestCreateTool{},
&modeltools.AccessRequestsListTool{},
&modeltools.AccessRequestListRequestableRolesTool{},
&modeltools.AccessRequestListRequestableResourcesTool{})
}
agent := model.NewAgent(toolContext, tools...)
action := &model.AgentAction{
Action: toolName,
Input: toolInput,
Reasoning: "Tool invoked directly",
}

return agent.DoAction(ctx, client.svc, action)
}

func (client *Client) NewAuditQuery(username string) *Chat {
toolContext := &modeltools.ToolContext{User: username}
return &Chat{
Expand Down
110 changes: 110 additions & 0 deletions lib/ai/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai

import (
"context"
"encoding/json"
"net/http/httptest"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"

assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
"github.com/gravitational/teleport/lib/ai/model/output"
"github.com/gravitational/teleport/lib/ai/model/tools"
"github.com/gravitational/teleport/lib/ai/testutils"
)

func TestRunTool_AuditQueryGeneration(t *testing.T) {
// Test setup: starting a mock openai server and creating the client
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"

responses := []string{
// Then the audit query tool chooses to request session.start events
"session.start",
// Finally the tool builds a query based on the provided schemas
generatedQuery,
}
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL

client := NewClientFromConfig(cfg)

// Doing the test: Check that the AuditQueryGeneration tool can be invoked
// through client.RunTool and validate its response.
ctx := context.Background()
toolCtx := &tools.ToolContext{User: "alice"}
response, _, err := client.RunTool(ctx, toolCtx, tools.AuditQueryGenerationToolName, "List users who connected to a server as root")
require.NoError(t, err)
message, ok := response.(*output.StreamingMessage)
require.True(t, ok)
require.Equal(t, generatedQuery, message.WaitAndConsume())
}

type mockEmbeddingGetter struct {
response []*assistpb.EmbeddedDocument
}

func (m *mockEmbeddingGetter) GetAssistantEmbeddings(ctx context.Context, in *assistpb.GetAssistantEmbeddingsRequest, opts ...grpc.CallOption) (*assistpb.GetAssistantEmbeddingsResponse, error) {
return &assistpb.GetAssistantEmbeddingsResponse{Embeddings: m.response}, nil
}

func TestRunTool_EmbeddingRetrieval(t *testing.T) {
// Test setup: starting a mock openai server and embedding getter,
// then create the client.
mock := &mockEmbeddingGetter{
[]*assistpb.EmbeddedDocument{
{
Id: "1",
Content: "foo",
SimilarityScore: 1,
},
{
Id: "2",
Content: "bar",
SimilarityScore: 0.9,
},
},
}
ctx := context.Background()
toolCtx := &tools.ToolContext{AssistEmbeddingServiceClient: mock}

responses := make([]string, 0)
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL
client := NewClientFromConfig(cfg)

// Doing the test: Check that the EmbeddingRetrieval tool can be invoked
// through client.RunTool and validate its response.
input := tools.EmbeddingRetrievalToolInput{Question: "Find foobar"}
inputText, err := json.Marshal(input)
require.NoError(t, err)
response, _, err := client.RunTool(ctx, toolCtx, "Nodes names and labels retrieval", string(inputText))
require.NoError(t, err)
message, ok := response.(*output.Message)
require.True(t, ok)
require.Equal(t, "foo\nbar\n", message.Content)
}
26 changes: 26 additions & 0 deletions lib/ai/model/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,28 @@ func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHist
}
}

func (a *Agent) DoAction(ctx context.Context, llm *openai.Client, action *AgentAction) (any, *tokens.TokenCount, error) {
state := &executionState{
llm: llm,
tokenCount: tokens.NewTokenCount(),
}
out, err := a.doAction(ctx, state, action)
if err != nil {
return nil, nil, trace.Wrap(err)
}
switch {
case out.finish != nil:
// If the tool already breaks execution, we don't have to do anything
return out.finish.output, state.tokenCount, nil
case out.observation != "":
// If the tool doesn't break execution and returns a single observation,
// we wrap the observation in a Message.
return &output.Message{Content: out.observation}, state.tokenCount, nil
default:
return nil, state.tokenCount, trace.Errorf("action %s did not end execution nor returned an observation", action.Action)
}
}

// stepOutput represents the inputs and outputs of a single thought step.
type stepOutput struct {
// if the agent is done, finish is set.
Expand Down Expand Up @@ -189,6 +211,10 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
// If action is set, the agent is not done and called upon a tool.
progressUpdates(action)

return a.doAction(ctx, state, action)
}

func (a *Agent) doAction(ctx context.Context, state *executionState, action *AgentAction) (stepOutput, error) {
var tool tools.Tool
for _, candidate := range a.tools {
if candidate.Name() == action.Action {
Expand Down
4 changes: 3 additions & 1 deletion lib/ai/model/tools/auditquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ import (
"github.com/gravitational/teleport/lib/ai/tokens"
)

const AuditQueryGenerationToolName = "Audit Query Generation"

type AuditQueryGenerationTool struct {
LLM *openai.Client
}

func (t *AuditQueryGenerationTool) Name() string {
return "Audit Query Generation"
return AuditQueryGenerationToolName
}

func (t *AuditQueryGenerationTool) Description() string {
Expand Down
39 changes: 39 additions & 0 deletions lib/assist/assist.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,45 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e
return a.client.Summary(ctx, message)
}

// RunTool runs a model tool without an ai.Chat.
func (a *Assist) RunTool(ctx context.Context, onMessage onMessageFunc, toolName, userInput string, toolContext *tools.ToolContext,
) (*tokens.TokenCount, error) {
message, tc, err := a.client.RunTool(ctx, toolContext, toolName, userInput)
if err != nil {
return nil, trace.Wrap(err)
}

switch message := message.(type) {
case *output.Message:
if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), a.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *output.GeneratedCommand:
if err := onMessage(MessageKindCommand, []byte(message.Command), a.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *output.StreamingMessage:
if err := func() error {
var text strings.Builder
defer onMessage(MessageKindAssistantPartialFinalize, nil, a.clock.Now().UTC())
for part := range message.Parts {
text.WriteString(part)

if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), a.clock.Now().UTC()); err != nil {
return trace.Wrap(err)
}
}
return nil
}(); err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.Errorf("Unexpected message type: %T", message)
}

return tc, nil
}

// GenerateCommandSummary summarizes the output of a command executed on one or
// many nodes. The conversation history is also sent into the prompt in order
// to gather context and know what information is relevant in the command output.
Expand Down
47 changes: 47 additions & 0 deletions lib/web/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ const (
actionSSHGenerateCommand = "ssh-cmdgen"
// actionSSHExplainCommand is a name of the action for explaining terminal output in SSH session.
actionSSHExplainCommand = "ssh-explain"
// actionGenerateAuditQuery is the name of the action for generating audit queries.
actionGenerateAuditQuery = "audit-query"
)

// createAssistantConversationResponse is a response for POST /webapi/assistant/conversations.
Expand Down Expand Up @@ -489,6 +491,8 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
err = h.assistGenSSHCommandLoop(ctx, assistClient, ws, sctx.GetUser(), authClient)
case actionSSHExplainCommand:
err = h.assistSSHExplainOutputLoop(ctx, assistClient, ws, authClient)
case actionGenerateAuditQuery:
err = h.assistGenAuditQueryLoop(ctx, assistClient, ws, sctx.GetUser(), authClient)
default:
err = h.assistChatLoop(ctx, assistClient, authClient, conversationID, sctx, ws)
}
Expand All @@ -500,6 +504,49 @@ type usageReporter interface {
SubmitUsageEvent(ctx context.Context, req *proto.SubmitUsageEventRequest) error
}

// assistGenAuditQueryLoop reads the user's input and generates an audit query.
func (h *Handler) assistGenAuditQueryLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, username string, usageRep usageReporter) error {
for {
_, payload, err := ws.ReadMessage()
if err != nil {
if wsIsClosed(err) {
break
}
return trace.Wrap(err)
}

onMessage := func(kind assist.MessageType, payload []byte, createdTime time.Time) error {
return onMessageFn(ws, kind, payload, createdTime)
}

toolCtx := &tools.ToolContext{User: username}

tokenCount, err := assistClient.RunTool(ctx, onMessage, tools.AuditQueryGenerationToolName, string(payload), toolCtx)
if err != nil {
return trace.Wrap(err)
}

prompt, completion := tokens.CountTokens(tokenCount)

usageEventReq := &clientproto.SubmitUsageEventRequest{
Event: &usageeventsv1.UsageEventOneOf{
Event: &usageeventsv1.UsageEventOneOf_AssistAction{
AssistAction: &usageeventsv1.AssistAction{
Action: actionGenerateAuditQuery,
TotalTokens: int64(completion + prompt),
PromptTokens: int64(prompt),
CompletionTokens: int64(completion),
},
},
},
}
if err := usageRep.SubmitUsageEvent(ctx, usageEventReq); err != nil {
h.log.WithError(err).Warn("Failed to emit usage event")
}
}
return nil
}

// assistSSHExplainOutputLoop reads the user's input and generates a command summary.
func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, usageRep usageReporter) error {
_, payload, err := ws.ReadMessage()
Expand Down