diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go index 064a02ae59807..ea96dc7db4cb7 100644 --- a/lib/ai/chat_test.go +++ b/lib/ai/chat_test.go @@ -321,13 +321,13 @@ func generateAccessRequestResponse(t *testing.T) string { func TestChat_Complete_AuditQuery(t *testing.T) { // Test setup: generate the responses that will be served by our OpenAI mock action := model.PlanOutput{ - Action: "Audit Query Generation", + Action: tools.AuditQueryGenerationToolName, ActionInput: "Lists user who connected to a server as root.", Reasoning: "foo", } selectedAction, err := json.Marshal(action) require.NoError(t, err) - generatedQuery := "SELECT user FROM session_start WHERE login='root'" + const generatedQuery = "SELECT user FROM session_start WHERE login='root'" responses := []string{ // The model must select the audit query tool diff --git a/lib/ai/client.go b/lib/ai/client.go index 532fa935fe65b..3ec1d29807551 100644 --- a/lib/ai/client.go +++ b/lib/ai/client.go @@ -92,6 +92,30 @@ func (client *Client) NewCommand(username string) *Chat { } } +func (client *Client) RunTool(ctx context.Context, toolContext *modeltools.ToolContext, toolName, toolInput string) (any, *tokens.TokenCount, error) { + tools := []modeltools.Tool{ + &modeltools.CommandExecutionTool{}, + &modeltools.EmbeddingRetrievalTool{}, + &modeltools.AuditQueryGenerationTool{LLM: client.svc}, + } + // The following tools are only available in the enterprise build. They will fail + // if included in OSS due to the lack of the required backend APIs. + if modules.GetModules().BuildType() == modules.BuildEnterprise { + tools = append(tools, &modeltools.AccessRequestCreateTool{}, + &modeltools.AccessRequestsListTool{}, + &modeltools.AccessRequestListRequestableRolesTool{}, + &modeltools.AccessRequestListRequestableResourcesTool{}) + } + agent := model.NewAgent(toolContext, tools...) + action := &model.AgentAction{ + Action: toolName, + Input: toolInput, + Reasoning: "Tool invoked directly", + } + + return agent.DoAction(ctx, client.svc, action) +} + func (client *Client) NewAuditQuery(username string) *Chat { toolContext := &modeltools.ToolContext{User: username} return &Chat{ diff --git a/lib/ai/client_test.go b/lib/ai/client_test.go new file mode 100644 index 0000000000000..2dda4461b75c4 --- /dev/null +++ b/lib/ai/client_test.go @@ -0,0 +1,110 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai + +import ( + "context" + "encoding/json" + "net/http/httptest" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/lib/ai/model/output" + "github.com/gravitational/teleport/lib/ai/model/tools" + "github.com/gravitational/teleport/lib/ai/testutils" +) + +func TestRunTool_AuditQueryGeneration(t *testing.T) { + // Test setup: starting a mock openai server and creating the client + const generatedQuery = "SELECT user FROM session_start WHERE login='root'" + + responses := []string{ + // Then the audit query tool chooses to request session.start events + "session.start", + // Finally the tool builds a query based on the provided schemas + generatedQuery, + } + server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses)) + t.Cleanup(server.Close) + + cfg := openai.DefaultConfig("secret-test-token") + cfg.BaseURL = server.URL + + client := NewClientFromConfig(cfg) + + // Doing the test: Check that the AuditQueryGeneration tool can be invoked + // through client.RunTool and validate its response. + ctx := context.Background() + toolCtx := &tools.ToolContext{User: "alice"} + response, _, err := client.RunTool(ctx, toolCtx, tools.AuditQueryGenerationToolName, "List users who connected to a server as root") + require.NoError(t, err) + message, ok := response.(*output.StreamingMessage) + require.True(t, ok) + require.Equal(t, generatedQuery, message.WaitAndConsume()) +} + +type mockEmbeddingGetter struct { + response []*assistpb.EmbeddedDocument +} + +func (m *mockEmbeddingGetter) GetAssistantEmbeddings(ctx context.Context, in *assistpb.GetAssistantEmbeddingsRequest, opts ...grpc.CallOption) (*assistpb.GetAssistantEmbeddingsResponse, error) { + return &assistpb.GetAssistantEmbeddingsResponse{Embeddings: m.response}, nil +} + +func TestRunTool_EmbeddingRetrieval(t *testing.T) { + // Test setup: starting a mock openai server and embedding getter, + // then create the client. + mock := &mockEmbeddingGetter{ + []*assistpb.EmbeddedDocument{ + { + Id: "1", + Content: "foo", + SimilarityScore: 1, + }, + { + Id: "2", + Content: "bar", + SimilarityScore: 0.9, + }, + }, + } + ctx := context.Background() + toolCtx := &tools.ToolContext{AssistEmbeddingServiceClient: mock} + + responses := make([]string, 0) + server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses)) + t.Cleanup(server.Close) + + cfg := openai.DefaultConfig("secret-test-token") + cfg.BaseURL = server.URL + client := NewClientFromConfig(cfg) + + // Doing the test: Check that the EmbeddingRetrieval tool can be invoked + // through client.RunTool and validate its response. + input := tools.EmbeddingRetrievalToolInput{Question: "Find foobar"} + inputText, err := json.Marshal(input) + require.NoError(t, err) + response, _, err := client.RunTool(ctx, toolCtx, "Nodes names and labels retrieval", string(inputText)) + require.NoError(t, err) + message, ok := response.(*output.Message) + require.True(t, ok) + require.Equal(t, "foo\nbar\n", message.Content) +} diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go index 812b32208fafc..e3522d5755a6b 100644 --- a/lib/ai/model/agent.go +++ b/lib/ai/model/agent.go @@ -133,6 +133,28 @@ func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHist } } +func (a *Agent) DoAction(ctx context.Context, llm *openai.Client, action *AgentAction) (any, *tokens.TokenCount, error) { + state := &executionState{ + llm: llm, + tokenCount: tokens.NewTokenCount(), + } + out, err := a.doAction(ctx, state, action) + if err != nil { + return nil, nil, trace.Wrap(err) + } + switch { + case out.finish != nil: + // If the tool already breaks execution, we don't have to do anything + return out.finish.output, state.tokenCount, nil + case out.observation != "": + // If the tool doesn't break execution and returns a single observation, + // we wrap the observation in a Message. + return &output.Message{Content: out.observation}, state.tokenCount, nil + default: + return nil, state.tokenCount, trace.Errorf("action %s did not end execution nor returned an observation", action.Action) + } +} + // stepOutput represents the inputs and outputs of a single thought step. type stepOutput struct { // if the agent is done, finish is set. @@ -189,6 +211,10 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres // If action is set, the agent is not done and called upon a tool. progressUpdates(action) + return a.doAction(ctx, state, action) +} + +func (a *Agent) doAction(ctx context.Context, state *executionState, action *AgentAction) (stepOutput, error) { var tool tools.Tool for _, candidate := range a.tools { if candidate.Name() == action.Action { diff --git a/lib/ai/model/tools/auditquery.go b/lib/ai/model/tools/auditquery.go index f60d855c416fc..86def094ba98a 100644 --- a/lib/ai/model/tools/auditquery.go +++ b/lib/ai/model/tools/auditquery.go @@ -30,12 +30,14 @@ import ( "github.com/gravitational/teleport/lib/ai/tokens" ) +const AuditQueryGenerationToolName = "Audit Query Generation" + type AuditQueryGenerationTool struct { LLM *openai.Client } func (t *AuditQueryGenerationTool) Name() string { - return "Audit Query Generation" + return AuditQueryGenerationToolName } func (t *AuditQueryGenerationTool) Description() string { diff --git a/lib/assist/assist.go b/lib/assist/assist.go index 6409b7f56e5fe..d9d52b9b5d8e4 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -176,6 +176,45 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e return a.client.Summary(ctx, message) } +// RunTool runs a model tool without an ai.Chat. +func (a *Assist) RunTool(ctx context.Context, onMessage onMessageFunc, toolName, userInput string, toolContext *tools.ToolContext, +) (*tokens.TokenCount, error) { + message, tc, err := a.client.RunTool(ctx, toolContext, toolName, userInput) + if err != nil { + return nil, trace.Wrap(err) + } + + switch message := message.(type) { + case *output.Message: + if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), a.clock.Now().UTC()); err != nil { + return nil, trace.Wrap(err) + } + case *output.GeneratedCommand: + if err := onMessage(MessageKindCommand, []byte(message.Command), a.clock.Now().UTC()); err != nil { + return nil, trace.Wrap(err) + } + case *output.StreamingMessage: + if err := func() error { + var text strings.Builder + defer onMessage(MessageKindAssistantPartialFinalize, nil, a.clock.Now().UTC()) + for part := range message.Parts { + text.WriteString(part) + + if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), a.clock.Now().UTC()); err != nil { + return trace.Wrap(err) + } + } + return nil + }(); err != nil { + return nil, trace.Wrap(err) + } + default: + return nil, trace.Errorf("Unexpected message type: %T", message) + } + + return tc, nil +} + // GenerateCommandSummary summarizes the output of a command executed on one or // many nodes. The conversation history is also sent into the prompt in order // to gather context and know what information is relevant in the command output. diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 12f96bde953c6..1b507d6c27825 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -48,6 +48,8 @@ const ( actionSSHGenerateCommand = "ssh-cmdgen" // actionSSHExplainCommand is a name of the action for explaining terminal output in SSH session. actionSSHExplainCommand = "ssh-explain" + // actionGenerateAuditQuery is the name of the action for generating audit queries. + actionGenerateAuditQuery = "audit-query" ) // createAssistantConversationResponse is a response for POST /webapi/assistant/conversations. @@ -489,6 +491,8 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, err = h.assistGenSSHCommandLoop(ctx, assistClient, ws, sctx.GetUser(), authClient) case actionSSHExplainCommand: err = h.assistSSHExplainOutputLoop(ctx, assistClient, ws, authClient) + case actionGenerateAuditQuery: + err = h.assistGenAuditQueryLoop(ctx, assistClient, ws, sctx.GetUser(), authClient) default: err = h.assistChatLoop(ctx, assistClient, authClient, conversationID, sctx, ws) } @@ -500,6 +504,49 @@ type usageReporter interface { SubmitUsageEvent(ctx context.Context, req *proto.SubmitUsageEventRequest) error } +// assistGenAuditQueryLoop reads the user's input and generates an audit query. +func (h *Handler) assistGenAuditQueryLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, username string, usageRep usageReporter) error { + for { + _, payload, err := ws.ReadMessage() + if err != nil { + if wsIsClosed(err) { + break + } + return trace.Wrap(err) + } + + onMessage := func(kind assist.MessageType, payload []byte, createdTime time.Time) error { + return onMessageFn(ws, kind, payload, createdTime) + } + + toolCtx := &tools.ToolContext{User: username} + + tokenCount, err := assistClient.RunTool(ctx, onMessage, tools.AuditQueryGenerationToolName, string(payload), toolCtx) + if err != nil { + return trace.Wrap(err) + } + + prompt, completion := tokens.CountTokens(tokenCount) + + usageEventReq := &clientproto.SubmitUsageEventRequest{ + Event: &usageeventsv1.UsageEventOneOf{ + Event: &usageeventsv1.UsageEventOneOf_AssistAction{ + AssistAction: &usageeventsv1.AssistAction{ + Action: actionGenerateAuditQuery, + TotalTokens: int64(completion + prompt), + PromptTokens: int64(prompt), + CompletionTokens: int64(completion), + }, + }, + }, + } + if err := usageRep.SubmitUsageEvent(ctx, usageEventReq); err != nil { + h.log.WithError(err).Warn("Failed to emit usage event") + } + } + return nil +} + // assistSSHExplainOutputLoop reads the user's input and generates a command summary. func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, usageRep usageReporter) error { _, payload, err := ws.ReadMessage()