diff --git a/lib/ai/client.go b/lib/ai/client.go index aa14ab64c20ff..72f72925e7102 100644 --- a/lib/ai/client.go +++ b/lib/ai/client.go @@ -102,3 +102,23 @@ func (client *Client) CommandSummary(ctx context.Context, messages []openai.Chat return resp.Choices[0].Message.Content, nil } + +// ClassifyMessage takes a user message, a list of categories, and uses the AI mode as a zero shot classifier. +func (client *Client) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) { + resp, err := client.svc.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4, + Messages: []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleSystem, Content: model.MessageClassificationPrompt(classes)}, + {Role: openai.ChatMessageRoleUser, Content: message}, + }, + }, + ) + + if err != nil { + return "", trace.Wrap(err) + } + + return resp.Choices[0].Message.Content, nil +} diff --git a/lib/ai/model/prompt.go b/lib/ai/model/prompt.go index 3b728d6bfb949..94b29e91ced1d 100644 --- a/lib/ai/model/prompt.go +++ b/lib/ai/model/prompt.go @@ -117,3 +117,18 @@ func ConversationCommandResult(result map[string][]byte) string { message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary.") return message.String() } + +func MessageClassificationPrompt(classes map[string]string) string { + var classList strings.Builder + for name, description := range classes { + classList.WriteString(fmt.Sprintf("- `%s` (%s)\n", name, description)) + } + + return fmt.Sprintf(`Teleport is a tool that provides access to servers, kubernetes clusters, databases, and applications. All connected Teleport resources are called a cluster. Server resources might be called nodes. + +Classify the provided message between the following categories: + +%v + +Answer only with the category name. Nothing else.`, classList.String()) +} diff --git a/lib/assist/assist.go b/lib/assist/assist.go index e0c34d8a007cf..c5a29a9d5c6d8 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -21,6 +21,7 @@ package assist import ( "context" "encoding/json" + "strings" "time" "github.com/gravitational/trace" @@ -172,6 +173,23 @@ func (c *Chat) reloadMessages(ctx context.Context) error { return c.loadMessages(ctx) } +// ClassifyMessage takes a user message, a list of categories, and uses the AI +// mode as a zero shot classifier. It returns an error if the classification +// result is not a valid class. +func (a *Assist) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) { + category, err := a.client.ClassifyMessage(ctx, message, classes) + if err != nil { + return "", trace.Wrap(err) + } + + cleanedCategory := strings.ToLower(strings.Trim(category, ". ")) + if _, ok := classes[cleanedCategory]; ok { + return cleanedCategory, nil + } + + return "", trace.CompareFailed("classification failed, category '%s' is not a valid classes", cleanedCategory) +} + // loadMessages loads the messages from the database. func (c *Chat) loadMessages(ctx context.Context) error { // existing conversation, retrieve old messages diff --git a/lib/assist/assist_test.go b/lib/assist/assist_test.go index 2ac2141cc3fa5..30547794fbaee 100644 --- a/lib/assist/assist_test.go +++ b/lib/assist/assist_test.go @@ -114,6 +114,51 @@ func TestChatComplete(t *testing.T) { }) } +func TestClassifyMessage(t *testing.T) { + // Given an OpenAI server that returns a response for a chat completion request. + responses := []string{ + "troubleshooting", + "Troubleshooting", + "Troubleshooting.", + "non-existent", + } + + server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) + t.Cleanup(server.Close) + + cfg := openai.DefaultConfig("secret-test-token") + cfg.BaseURL = server.URL + "/v1" + + // And a chat client. + ctx := context.Background() + client, err := NewClient(ctx, &mockPluginGetter{}, &apiKeyMock{}, &cfg) + require.NoError(t, err) + + t.Run("Valid class", func(t *testing.T) { + class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) + require.NoError(t, err) + require.Equal(t, class, "troubleshooting") + }) + + t.Run("Valid class starting with upper-case", func(t *testing.T) { + class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) + require.NoError(t, err) + require.Equal(t, class, "troubleshooting") + }) + + t.Run("Valid class starting with upper-case and ending with dot", func(t *testing.T) { + class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) + require.NoError(t, err) + require.Equal(t, class, "troubleshooting") + }) + + t.Run("Model hallucinates", func(t *testing.T) { + class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) + require.Error(t, err) + require.Empty(t, class) + }) +} + type apiKeyMock struct{} // GetOpenAIAPIKey returns a mock API key. diff --git a/lib/assist/constants.go b/lib/assist/constants.go new file mode 100644 index 0000000000000..230c68f6b37e7 --- /dev/null +++ b/lib/assist/constants.go @@ -0,0 +1,35 @@ +/* + Copyright 2023 Gravitational, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package assist + +// MessageClasses contains type of assist message we expect users to send. +// When running on Cloud we attempt to classify user messages in one of those +// categories. If this succeeds, we send an event into the analytics pipeline. +// +// Keys are the category names, those are the ones reported in the event and +// generated by the model. Values are the category description. They are used to +// build the model prompt and allow to provide more context to the model to +// improve the classification. +var MessageClasses = map[string]string{ + "command execution": "the user want to execute a command on one or many servers", + "troubleshooting": "the user wants to diagnose a problem or understand an error message", + "configuration": "the user wants to generate configuration for a software which is not Teleport", + "manage resources": "the user wants to list/add/remove/edit resources connected to the Teleport cluster", + "access request": "the user requests access to one or many resources from the Teleport cluster", + "teleport setup": "the user wants help with its Teleport cluster, like setting up a new feature or knowing if something is feasible", + "other": "the user asks a question which is not IT nor Teleport-related", +} diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 6845dcb9f9c40..9ea879838e186 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/httplib" + "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/reversetunnelclient" ) @@ -250,7 +251,7 @@ type generateAssistantTitleRequest struct { Message string `json:"message"` } -// generateAssistantTitle is a handler for POST /webapi/assistant/conversations/:conversation_id/generate_title. +// generateAssistantTitle is a handler for POST /webapi/assistant/title/summary. func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, _ httprouter.Params, sctx *SessionContext, ) (any, error) { @@ -283,6 +284,21 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, Title: titleSummary, } + // We only want to emmit + if modules.GetModules().Features().Cloud { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + class, err := client.ClassifyMessage(ctx, req.Message, assist.MessageClasses) + if err != nil { + return + } + h.log.Debugf("message classified as '%s'", class) + // TODO(shaka): emit event here to report the message class + }() + + } + return conversationInfo, nil } diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 60b45d871336a..c163b255e9f50 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -308,6 +308,42 @@ func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID st return ws, nil } +func Test_generateAssistantTitle(t *testing.T) { + // Test setup + t.Parallel() + ctx := context.Background() + + responses := []string{"This is the message summary.", "troubleshooting"} + server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) + t.Cleanup(server.Close) + + openaiCfg := openai.DefaultConfig("test-token") + openaiCfg.BaseURL = server.URL + s := newWebSuiteWithConfig(t, webSuiteConfig{ + ClusterFeatures: &authproto.Features{ + Cloud: true, + }, + OpenAIConfig: &openaiCfg, + }) + + pack := s.authPack(t, "foo") + + // Real test: we craft a request asking for a summary + endpoint := pack.clt.Endpoint("webapi", "assistant", "title", "summary") + req := generateAssistantTitleRequest{Message: "This is a test user message asking Teleport assist to do something."} + + // Executing the request and validating the output is as expected + resp, err := pack.clt.PostJSON(ctx, endpoint, &req) + require.NoError(t, err) + + var info conversationInfo + body, err := io.ReadAll(resp.Reader()) + require.NoError(t, err) + err = json.Unmarshal(body, &info) + require.NoError(t, err) + require.NotEmpty(t, info.Title) +} + // generateTextResponse generates a response for a text completion func generateTextResponse() string { return "```" + `json