diff --git a/lib/ai/chat.go b/lib/ai/chat.go index 986880c00f244..6094f44d4ebfc 100644 --- a/lib/ai/chat.go +++ b/lib/ai/chat.go @@ -65,6 +65,11 @@ func (chat *Chat) Complete(ctx context.Context, userInput string, progressUpdate }, model.NewTokenCount(), nil } + return chat.Reply(ctx, userInput, progressUpdates) +} + +// Reply replies to the user input with a message from the assistant based on the current context. +func (chat *Chat) Reply(ctx context.Context, userInput string, progressUpdates func(*model.AgentAction)) (any, *model.TokenCount, error) { userMessage := openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: userInput, diff --git a/lib/ai/client.go b/lib/ai/client.go index f03377f29e3dc..9ae56effbbb2c 100644 --- a/lib/ai/client.go +++ b/lib/ai/client.go @@ -51,7 +51,34 @@ func NewClientFromConfig(config openai.ClientConfig) *Client { // toolsConfig contains all required clients and configuration for agent tools // to interact with Teleport. func (client *Client) NewChat(username string, toolsConfig model.ToolsConfig) (*Chat, error) { - agent, err := model.NewAgent(username, toolsConfig) + tools := []model.Tool{ + model.NewExecutionTool(), + } + if !toolsConfig.DisableEmbeddingsTool { + tools = append(tools, model.NewRetrievalTool(toolsConfig.EmbeddingsClient, toolsConfig.NodeClient, + toolsConfig.AccessChecker, username)) + } + agent, err := model.NewAgent(tools...) + if err != nil { + return nil, trace.Wrap(err) + } + return &Chat{ + client: client, + messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: model.PromptCharacter(username), + }, + }, + // Initialize a tokenizer for prompt token accounting. + // Cl100k is used by GPT-3 and GPT-4. 
+ tokenizer: codec.NewCl100kBase(), + agent: agent, + }, nil +} + +func (client *Client) NewCommand(username string) (*Chat, error) { + agent, err := model.NewAgent(model.NewGenerateTool()) if err != nil { return nil, trace.Wrap(err) } @@ -121,7 +148,7 @@ func (client *Client) CommandSummary(ctx context.Context, messages []openai.Chat return completion, tc, trace.Wrap(err) } -// ClassifyMessage takes a user message, a list of categories, and uses the AI mode as a zero shot classifier. +// ClassifyMessage takes a user message, a list of categories, and uses the AI mode as a zero-shot classifier. func (client *Client) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) { resp, err := client.svc.CreateChatCompletion( ctx, diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go index c111a8a9c5a08..3db8f7d07f856 100644 --- a/lib/ai/model/agent.go +++ b/lib/ai/model/agent.go @@ -47,23 +47,34 @@ const ( finalResponseHeader = "" ) -// NewAgent creates a new agent. The Assist agent which defines the model responsible for the Assist feature. -func NewAgent(username string, config ToolsConfig) (*Agent, error) { - err := config.CheckAndSetDefaults() - if err != nil { - return nil, trace.Wrap(err) - } +// NewExecutionTool creates a new execution tool. The execution tool is responsible for executing commands. +func NewExecutionTool() Tool { + return &commandExecutionTool{} +} - tools := []Tool{&commandExecutionTool{}} +// NewGenerateTool creates a new generation tool. The generation tool is responsible for generating Bash commands. +func NewGenerateTool() Tool { + return &commandGenerationTool{} +} - if !config.DisableEmbeddingsTool { - tools = append(tools, - &embeddingRetrievalTool{ - assistClient: config.EmbeddingsClient, - currentUser: username, - nodeClient: config.NodeClient, - userAccessChecker: config.AccessChecker, - }) +// NewRetrievalTool creates a new retrieval tool. 
The retrieval tool is responsible for retrieving embeddings. +func NewRetrievalTool(assistClient assist.AssistEmbeddingServiceClient, + nodeClient NodeGetter, + userAccessChecker services.AccessChecker, + currentUser string, +) Tool { + return &embeddingRetrievalTool{ + assistClient: assistClient, + currentUser: currentUser, + nodeClient: nodeClient, + userAccessChecker: userAccessChecker, + } +} + +// NewAgent creates a new agent. The Assist agent which defines the model responsible for the Assist feature. +func NewAgent(tools ...Tool) (*Agent, error) { + if len(tools) == 0 { + return nil, trace.BadParameter("at least one tool is required") } return &Agent{ @@ -264,6 +275,25 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres return stepOutput{finish: &agentFinish{output: completion}}, nil } + if tool, ok := tool.(*commandGenerationTool); ok { + input, err := tool.parseInput(action.Input) + if err != nil { + action := &AgentAction{ + Action: actionException, + Input: observationPrefix + "Invalid or incomplete response", + Log: thoughtPrefix + err.Error(), + } + + return stepOutput{action: action, observation: action.Input}, nil + } + completion := &GeneratedCommand{ + Command: input.Command, + } + + log.Tracef("agent decided on command generation, let's translate to an agentFinish") + return stepOutput{finish: &agentFinish{output: completion}}, nil + } + runOut, err := tool.Run(ctx, action.Input) if err != nil { return stepOutput{}, trace.Wrap(err) diff --git a/lib/ai/model/generationtool.go b/lib/ai/model/generationtool.go new file mode 100644 index 0000000000000..1d83940b13899 --- /dev/null +++ b/lib/ai/model/generationtool.go @@ -0,0 +1,80 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "context" + "fmt" + + "github.com/gravitational/trace" +) + +type commandGenerationTool struct{} + +type commandGenerationToolInput struct { + // Command is a unix command to execute. + Command string `json:"command"` +} + +func (c *commandGenerationTool) Name() string { + return "Command Generation" +} + +func (c *commandGenerationTool) Description() string { + // acknowledgement field is used to convince the LLM to return the JSON. + // Based on my testing, the LLM ignores the JSON when the schema has only one field. + // Adding additional "pseudo-fields" to the schema makes the LLM return the JSON. + return fmt.Sprintf(`Generate a Bash command. +The input must be a JSON object with the following schema: +%vjson +{ + "command": string, \\ The generated command + "acknowledgement": boolean \\ Set to true to ackowledge that you understand the formatting +} +%v +`, "```", "```") +} + +func (c *commandGenerationTool) Run(_ context.Context, _ string) (string, error) { + // This is stubbed because commandGenerationTool is handled specially. + // This is because execution of this tool breaks the loop and returns a command suggestion to the user. + // It is still handled as a tool because testing has shown that the LLM behaves better when it is treated as a tool. + // + // In addition, treating it as a Tool interface item simplifies the display and prompt assembly logic significantly. 
+ return "", trace.NotImplemented("not implemented") +} + +// parseInput is called in a special case if the planned tool is commandGenerationTool. +// This is because commandGenerationTool is handled differently from most other tools and forcibly terminates the thought loop. +func (*commandGenerationTool) parseInput(input string) (*commandGenerationToolInput, error) { + output, err := parseJSONFromModel[commandGenerationToolInput](input) + if err != nil { + return nil, err + } + + if output.Command == "" { + return nil, &invalidOutputError{ + coarse: "command generation: missing command", + detail: "command must be non-empty", + } + } + + // Ignore the acknowledgement field. + // We do not care about the value. Having the command is enough. + + return &output, nil +} diff --git a/lib/ai/model/messages.go b/lib/ai/model/messages.go index 7774afad27946..687314ba85331 100644 --- a/lib/ai/model/messages.go +++ b/lib/ai/model/messages.go @@ -50,3 +50,8 @@ type CompletionCommand struct { Nodes []string `json:"nodes,omitempty"` Labels []Label `json:"labels,omitempty"` } + +// GeneratedCommand represents a Bash command generated by LLM. +type GeneratedCommand struct { + Command string `json:"command"` +} diff --git a/lib/ai/model/prompt.go b/lib/ai/model/prompt.go index fd940c6c72773..54ceef1f66adf 100644 --- a/lib/ai/model/prompt.go +++ b/lib/ai/model/prompt.go @@ -111,7 +111,8 @@ func ConversationCommandResult(result map[string][]byte) string { message.WriteString(string(output)) message.WriteString("\n") } - message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary.") + message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary. " + + "For error messages suggest a solution if possible. 
The solution can contain a Linux command or a description.") return message.String() } diff --git a/lib/assist/assist.go b/lib/assist/assist.go index 3c8db4f3ef18c..62633d7c03d62 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -144,6 +144,31 @@ func (a *Assist) NewChat(ctx context.Context, assistService MessageService, return chat, nil } +// LightweightChat is a Teleport Assist chat that doesn't store the history +// of the conversation. +type LightweightChat struct { + assist *Assist + chat *ai.Chat +} + +// NewLightweightChat creates a new Assist chat what doesn't store the history +// of the conversation. +func (a *Assist) NewLightweightChat(username string) (*LightweightChat, error) { + aichat, err := a.client.NewCommand(username) // TODO(jakule): fix this after all in-flight PRs are merged + if err != nil { + return nil, trace.Wrap(err) + } + + return &LightweightChat{ + assist: a, + chat: aichat, + }, nil +} + +func (a *Assist) NewSSHCommand(username string) (*ai.Chat, error) { + return a.client.NewCommand(username) +} + // GenerateSummary generates a summary for the given message. func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, error) { return a.client.Summary(ctx, message) @@ -179,7 +204,7 @@ func (c *Chat) reloadMessages(ctx context.Context) error { } // ClassifyMessage takes a user message, a list of categories, and uses the AI -// mode as a zero shot classifier. It returns an error if the classification +// mode as a zero-shot classifier. It returns an error if the classification // result is not a valid class. func (a *Assist) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) { category, err := a.client.ClassifyMessage(ctx, message, classes) @@ -406,6 +431,63 @@ func (c *Chat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, use return tokenCount, nil } +// ProcessComplete processes a user message and returns the assistant's response. 
+func (c *LightweightChat) ProcessComplete(ctx context.Context, onMessage onMessageFunc, userInput string, +) (*model.TokenCount, error) { + progressUpdates := func(update *model.AgentAction) { + payload, err := json.Marshal(update) + if err != nil { + log.WithError(err).Debugf("Failed to marshal progress update: %v", update) + return + } + + if err := onMessage(MessageKindProgressUpdate, payload, c.assist.clock.Now().UTC()); err != nil { + log.WithError(err).Debugf("Failed to send progress update: %v", update) + return + } + } + + message, tokenCount, err := c.chat.Reply(ctx, userInput, progressUpdates) + if err != nil { + return nil, trace.Wrap(err) + } + + c.chat.Insert(openai.ChatMessageRoleUser, userInput) + + switch message := message.(type) { + case *model.Message: + c.chat.Insert(openai.ChatMessageRoleAssistant, message.Content) + if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), c.assist.clock.Now().UTC()); err != nil { + return nil, trace.Wrap(err) + } + case *model.GeneratedCommand: + c.chat.Insert(openai.ChatMessageRoleAssistant, message.Command) + if err := onMessage(MessageKindCommand, []byte(message.Command), c.assist.clock.Now().UTC()); err != nil { + return nil, trace.Wrap(err) + } + case *model.StreamingMessage: + if err := func() error { + var text strings.Builder + defer onMessage(MessageKindAssistantPartialFinalize, nil, c.assist.clock.Now().UTC()) + for part := range message.Parts { + text.WriteString(part) + + if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), c.assist.clock.Now().UTC()); err != nil { + return trace.Wrap(err) + } + } + c.chat.Insert(openai.ChatMessageRoleAssistant, text.String()) + return nil + }(); err != nil { + return nil, trace.Wrap(err) + } + default: + return nil, trace.Errorf("Unexpected message type: %T", message) + } + + return tokenCount, nil +} + func getOpenAITokenFromDefaultPlugin(ctx context.Context, proxyClient PluginGetter) (string, error) { // Try retrieving 
credentials from the plugin resource first openaiPlugin, err := proxyClient.PluginsClient().GetPlugin(ctx, &pluginsv1.GetPluginRequest{ diff --git a/lib/web/assistant.go b/lib/web/assistant.go index ce4c022006952..dcfb2108e38fb 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -41,6 +41,17 @@ import ( "github.com/gravitational/teleport/lib/reversetunnelclient" ) +// We can not know how many tokens we will consume in advance. +// Try to consume a small number of tokens first. +const lookaheadTokens = 100 + +const ( + // actionSSHGenerateCommand is a name of the action for generating SSH commands. + actionSSHGenerateCommand = "ssh-cmdgen" + // actionSSHExplainCommand is a name of the action for explaining terminal output in SSH session. + actionSSHExplainCommand = "ssh-explain" +) + // createAssistantConversationResponse is a response for POST /webapi/assistant/conversations. type createdAssistantConversationResponse struct { // ID is a conversation ID. @@ -315,6 +326,9 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, return conversationInfo, nil } +// assistant is a handler for GET /webapi/sites/:site/assistant. +// This handler covers the main chat conversation as well as the +// SSH completion (SSH command generation and output explanation). func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ) (any, error) { @@ -327,6 +341,7 @@ func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter } func (h *Handler) reportTokenUsage(usedTokens *model.TokenCount, lookaheadTokens int, conversationID string, authClient auth.ClientI) { + // Create a new context to not be bounded by the request timeout. 
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -375,8 +390,9 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, ) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") - if conversationID == "" { - return trace.BadParameter("conversation ID is required") + actionParam := r.URL.Query().Get("action") + if conversationID == "" && actionParam == "" { + return trace.BadParameter("conversation ID or action is required") } authClient, err := sctx.GetClient() @@ -454,8 +470,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, // Update the read deadline upon receiving a pong message. ws.SetPongHandler(func(_ string) error { - ws.SetReadDeadline(deadlineForInterval(keepAliveInterval)) - return nil + return trace.Wrap(ws.SetReadDeadline(deadlineForInterval(keepAliveInterval))) }) ws.SetCloseHandler(func(code int, text string) error { @@ -471,6 +486,82 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } + switch r.URL.Query().Get("action") { + case actionSSHGenerateCommand: + err = h.assistGenSSHCommandLoop(ctx, assistClient, ws, sctx.GetUser()) + case actionSSHExplainCommand: + err = h.assistSSHExplainOutputLoop(ctx, assistClient, ws) + default: + err = h.assistChatLoop(ctx, assistClient, authClient, conversationID, sctx, ws) + } + + return trace.Wrap(err) +} + +// assistSSHExplainOutputLoop reads the user's input and generates a command summary. 
+func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn) error { + _, payload, err := ws.ReadMessage() + if err != nil { + if wsIsClosed(err) { + return nil + } + return trace.Wrap(err) + } + + modelMessages := []*assistpb.AssistantMessage{ + { + Type: string(assist.MessageKindUserMessage), + Payload: string(payload), + }, + } + + summary, _, err := assistClient.GenerateCommandSummary(ctx, modelMessages, map[string][]byte{}) + if err != nil { + return trace.Wrap(err) + } + + if err := onMessageFn(ws, assist.MessageKindAssistantMessage, []byte(summary), h.clock.Now().UTC()); err != nil { + return trace.Wrap(err) + } + + //TODO(jakule): add token usage reporting when events are added to posthog + + return nil +} + +// assistSSHCommandLoop reads the user's input and generates a Linux command. +func (h *Handler) assistGenSSHCommandLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, username string) error { + chat, err := assistClient.NewLightweightChat(username) + if err != nil { + return trace.Wrap(err) + } + + for { + _, payload, err := ws.ReadMessage() + if err != nil { + if wsIsClosed(err) { + break + } + return trace.Wrap(err) + } + + _, err = chat.ProcessComplete(ctx, func(kind assist.MessageType, payload []byte, createdTime time.Time) error { + return onMessageFn(ws, kind, payload, createdTime) + }, string(payload)) + if err != nil { + return trace.Wrap(err) + } + + //TODO(jakule): add token usage reporting when events are added to posthog + } + return nil +} + +// assistChatLoop is the main chat loop for the assistant. +// It reads the user's input from provided WS and generates a response. 
+func (h *Handler) assistChatLoop(ctx context.Context, assistClient *assist.Assist, authClient auth.ClientI, + conversationID string, sctx *SessionContext, ws *websocket.Conn, +) error { ac, err := sctx.GetUserAccessChecker() if err != nil { return trace.Wrap(err) @@ -486,20 +577,13 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - // onMessageFn is called when a message is received from the OpenAI API. - onMessageFn := func(kind assist.MessageType, payload []byte, createdTime time.Time) error { - msg := &assistantMessage{ - Type: kind, - Payload: string(payload), - CreatedTime: createdTime.Format(time.RFC3339), - } - - return trace.Wrap(ws.WriteJSON(msg)) + onMessage := func(kind assist.MessageType, payload []byte, createdTime time.Time) error { + return trace.Wrap(onMessageFn(ws, kind, payload, createdTime)) } if chat.IsNewConversation() { // new conversation, generate a hello message - if _, err := chat.ProcessComplete(ctx, onMessageFn, ""); err != nil { + if _, err := chat.ProcessComplete(ctx, onMessage, ""); err != nil { return trace.Wrap(err) } } @@ -507,8 +591,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, for { _, payload, err := ws.ReadMessage() if err != nil { - if err == io.EOF || websocket.IsCloseError(err, websocket.CloseAbnormalClosure, - websocket.CloseGoingAway, websocket.CloseNormalClosure) { + if wsIsClosed(err) { break } return trace.Wrap(err) @@ -519,11 +602,8 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - // We can not know how many tokens we will consume in advance. - // Try to consume a small amount of tokens first. - const lookaheadTokens = 100 - if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) { - err := onMessageFn(assist.MessageKindError, []byte("You have reached the rate limit. 
Please try again later."), h.clock.Now().UTC()) + if !h.assistantLimiter.AllowN(h.clock.Now(), lookaheadTokens) { + err := onMessageFn(ws, assist.MessageKindError, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC()) if err != nil { return trace.Wrap(err) } @@ -531,7 +611,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, } //TODO(jakule): Should we sanitize the payload? - usedTokens, err := chat.ProcessComplete(ctx, onMessageFn, wsIncoming.Payload) + usedTokens, err := chat.ProcessComplete(ctx, onMessage, wsIncoming.Payload) if err != nil { return trace.Wrap(err) } @@ -544,3 +624,20 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, h.log.Debug("end assistant conversation loop") return nil } + +// wsIsClosed returns true if the error is caused by a closed websocket. +func wsIsClosed(err error) bool { + return err == io.EOF || websocket.IsCloseError(err, websocket.CloseAbnormalClosure, + websocket.CloseGoingAway, websocket.CloseNormalClosure) +} + +// onMessageFn is called when a message is received from the OpenAI API. 
+func onMessageFn(ws *websocket.Conn, kind assist.MessageType, payload []byte, createdTime time.Time) error { + msg := &assistantMessage{ + Type: kind, + Payload: string(payload), + CreatedTime: createdTime.Format(time.RFC3339), + } + + return trace.Wrap(ws.WriteJSON(msg)) +} diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 55c7dd1b96122..8632d07a7ab5f 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -151,16 +151,7 @@ func Test_runAssistant(t *testing.T) { if tc.setup != nil { tc.setup(t, s) } - - assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{ - types.NewRule(types.KindAssistant, services.RW()), - }, - }, - }) - require.NoError(t, err) - require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) + assistRole := allowAssistAccess(t, s) ctx := context.Background() authPack := s.authPack(t, "foo", assistRole.GetName()) @@ -168,7 +159,7 @@ func Test_runAssistant(t *testing.T) { conversationID := s.makeAssistConversation(t, ctx, authPack) // Make WS client and start the conversation - ws, err := s.makeAssistant(t, authPack, conversationID) + ws, err := s.makeAssistant(t, authPack, conversationID, "") require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -211,7 +202,7 @@ func Test_runAssistError(t *testing.T) { require.NoError(t, err) _, payload, err := ws.ReadMessage() - require.NoError(t, err) + require.NoError(t, err, "expected error message, payload: %s", payload) var msg assistantMessage err = json.Unmarshal(payload, &msg) @@ -250,24 +241,15 @@ func Test_runAssistError(t *testing.T) { openaiCfg.BaseURL = server.URL s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) - ctx := context.Background() - - assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{ - types.NewRule(types.KindAssistant, services.RW()), - }, - }, - }) - 
require.NoError(t, err) - require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) - + assistRole := allowAssistAccess(t, s) authPack := s.authPack(t, "foo", assistRole.GetName()) + + ctx := context.Background() // Create the conversation conversationID := s.makeAssistConversation(t, ctx, authPack) // Make WS client and start the conversation - ws, err := s.makeAssistant(t, authPack, conversationID) + ws, err := s.makeAssistant(t, authPack, conversationID, "") require.NoError(t, err) t.Cleanup(func() { // The TLS connection might or might not be closed, this is an implementation detail. @@ -287,59 +269,88 @@ func Test_runAssistError(t *testing.T) { require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure") } -// makeAssistConversation creates a new assist conversation and returns its ID -func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string { - clt := authPack.clt +func Test_SSHCommandGeneration(t *testing.T) { + t.Parallel() - resp, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "assistant", "conversations"), nil) + assertGenCommand := func(ws *websocket.Conn) { + _, payload, err := ws.ReadMessage() + require.NoError(t, err) + + var msg assistantMessage + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) + + // Expect "hello" message + require.Equal(t, assist.MessageKindProgressUpdate, msg.Type) + require.Contains(t, msg.Payload, "openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj") + } + + responses := []string{generateCommandResponse()} + server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) + t.Cleanup(server.Close) + + openaiCfg := openai.DefaultConfig("test-token") + openaiCfg.BaseURL = server.URL + s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) + + assistRole := allowAssistAccess(t, s) + authPack := s.authPack(t, "foo", assistRole.GetName()) + + // Make WS client and start the 
conversation + ws, err := s.makeAssistant(t, authPack, "", "ssh-cmdgen") require.NoError(t, err) + t.Cleanup(func() { + err := ws.Close() + require.NoError(t, err) + }) - convResp := struct { - ConversationID string `json:"id"` - }{} - err = json.Unmarshal(resp.Bytes(), &convResp) + err = ws.WriteMessage(websocket.TextMessage, []byte(`{"input:" "My cert expired!!! What is x509?"}`)) require.NoError(t, err) - return convResp.ConversationID + // verify responses + assertGenCommand(ws) } -// makeAssistant creates a new assistant websocket connection. -func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID string) (*websocket.Conn, error) { - u := url.URL{ - Host: s.url().Host, - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%s/assistant", currentSiteShortcut), - } +func Test_SSHCommandExplain(t *testing.T) { + t.Parallel() - q := u.Query() - q.Set("conversation_id", conversationID) - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) - u.RawQuery = q.Encode() + assertResponse := func(ws *websocket.Conn) { + _, payload, err := ws.ReadMessage() + require.NoError(t, err) - dialer := websocket.Dialer{} - dialer.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } + var msg assistantMessage + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) - header := http.Header{} - header.Add("Origin", "http://localhost") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) + // Expect "hello" message + require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) + require.Contains(t, msg.Payload, "The application has failed to connect to the database. 
The database is not running.") } - ws, resp, err := dialer.Dial(u.String(), header) - if err != nil { - res, err2 := io.ReadAll(resp.Body) - t.Log("response body:", string(res), err2) - return nil, trace.Wrap(err) - } + responses := []string{commandSummaryResponse()} + server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) + t.Cleanup(server.Close) - err = resp.Body.Close() - if err != nil { - return nil, trace.Wrap(err) - } + openaiCfg := openai.DefaultConfig("test-token") + openaiCfg.BaseURL = server.URL + s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) - return ws, nil + assistRole := allowAssistAccess(t, s) + authPack := s.authPack(t, "foo", assistRole.GetName()) + + // Make WS client and start the conversation + ws, err := s.makeAssistant(t, authPack, "", "ssh-explain") + require.NoError(t, err) + t.Cleanup(func() { + err := ws.Close() + require.NoError(t, err) + }) + + err = ws.WriteMessage(websocket.TextMessage, []byte(`{"input:" "listen tcp 0.0.0.0:5432: bind: address already in use"}`)) + require.NoError(t, err) + + // verify responses + assertResponse(ws) } func Test_generateAssistantTitle(t *testing.T) { @@ -360,14 +371,7 @@ func Test_generateAssistantTitle(t *testing.T) { OpenAIConfig: &openaiCfg, }) - assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{ - types.NewRule(types.KindAssistant, services.RW()), - }, - }, - }) - require.NoError(t, err) + assistRole := allowAssistAccess(t, s) require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) pack := s.authPack(t, "foo", assistRole.GetName()) @@ -388,7 +392,98 @@ func Test_generateAssistantTitle(t *testing.T) { require.NotEmpty(t, info.Title) } +func allowAssistAccess(t *testing.T, s *WebSuite) types.Role { + assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + types.NewRule(types.KindAssistant, services.RW()), + }, + }, 
+ }) + require.NoError(t, err) + require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) + + return assistRole +} + +// makeAssistConversation creates a new assist conversation and returns its ID +func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string { + clt := authPack.clt + + resp, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "assistant", "conversations"), nil) + require.NoError(t, err) + + convResp := struct { + ConversationID string `json:"id"` + }{} + err = json.Unmarshal(resp.Bytes(), &convResp) + require.NoError(t, err) + + return convResp.ConversationID +} + +// makeAssistant creates a new assistant websocket connection. +func (s *WebSuite) makeAssistant(_ *testing.T, pack *authPack, conversationID, action string) (*websocket.Conn, error) { + if action == "" && conversationID == "" { + return nil, trace.BadParameter("must specify either conversation_id or action") + } + + u := url.URL{ + Host: s.url().Host, + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%s/assistant", currentSiteShortcut), + } + + q := u.Query() + if conversationID != "" { + q.Set("conversation_id", conversationID) + } + + if action != "" { + q.Set("action", action) + } + + q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) + u.RawQuery = q.Encode() + + dialer := websocket.Dialer{} + dialer.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + header := http.Header{} + header.Add("Origin", "http://localhost") + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + ws, resp, err := dialer.Dial(u.String(), header) + if err != nil { + return nil, trace.Wrap(err) + } + + err = resp.Body.Close() + if err != nil { + return nil, trace.Wrap(err) + } + + return ws, nil +} + // generateTextResponse generates a response for a text completion func generateTextResponse() string { return "\nWhich node do you want to use?" 
} + +func generateCommandResponse() string { + return "```" + `json + { + "action": "Command Generation", + "action_input": "{\"command\":\"openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes -subj\"}" + } + ` + "```" +} + +func commandSummaryResponse() string { + return "The application has failed to connect to the database. The database is not running." +}