-
Notifications
You must be signed in to change notification settings - Fork 2k
[Assist] Classify messages for usage analytics #28221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ package assist | |
| import ( | ||
| "context" | ||
| "encoding/json" | ||
| "strings" | ||
| "time" | ||
|
|
||
| "github.com/gravitational/trace" | ||
|
|
@@ -172,6 +173,23 @@ func (c *Chat) reloadMessages(ctx context.Context) error { | |
| return c.loadMessages(ctx) | ||
| } | ||
|
|
||
| // ClassifyMessage takes a user message, a list of categories, and uses the AI | ||
| // model as a zero-shot classifier. It returns an error if the classification | ||
| // result is not a valid class. | ||
| func (a *Assist) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) { | ||
| category, err := a.client.ClassifyMessage(ctx, message, classes) | ||
| if err != nil { | ||
| return "", trace.Wrap(err) | ||
| } | ||
|
|
||
| cleanedCategory := strings.ToLower(strings.Trim(category, ". ")) | ||
| if _, ok := classes[cleanedCategory]; ok { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd consider using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm afraid this change would break if we create a category name containing another category name. It can also lead to false positives. From what I tested, error rates are quite low. I'd prefer to wait for real data before doing those kinds of optimizations. |
||
| return cleanedCategory, nil | ||
| } | ||
|
|
||
| return "", trace.CompareFailed("classification failed, category '%s' is not a valid classes", cleanedCategory) | ||
| } | ||
|
|
||
| // loadMessages loads the messages from the database. | ||
| func (c *Chat) loadMessages(ctx context.Context) error { | ||
| // existing conversation, retrieve old messages | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -114,6 +114,51 @@ func TestChatComplete(t *testing.T) { | |
| }) | ||
| } | ||
|
|
||
| func TestClassifyMessage(t *testing.T) { | ||
| // Given an OpenAI server that returns a response for a chat completion request. | ||
| responses := []string{ | ||
| "troubleshooting", | ||
| "Troubleshooting", | ||
| "Troubleshooting.", | ||
| "non-existent", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have you ever seen something like "As a Large Language Model, I think that the input can be classified as troubleshooting", or something similar? 😅
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. From the few dozen tests I ran, I never observed such a response. gpt-3.5 tends to add uppercase letters and trailing dots; gpt-4 has not made any mistakes so far. |
||
| } | ||
|
|
||
| server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) | ||
| t.Cleanup(server.Close) | ||
|
|
||
| cfg := openai.DefaultConfig("secret-test-token") | ||
| cfg.BaseURL = server.URL + "/v1" | ||
|
|
||
| // And a chat client. | ||
| ctx := context.Background() | ||
| client, err := NewClient(ctx, &mockPluginGetter{}, &apiKeyMock{}, &cfg) | ||
| require.NoError(t, err) | ||
|
|
||
| t.Run("Valid class", func(t *testing.T) { | ||
| class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) | ||
| require.NoError(t, err) | ||
| require.Equal(t, class, "troubleshooting") | ||
| }) | ||
|
|
||
| t.Run("Valid class starting with upper-case", func(t *testing.T) { | ||
| class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) | ||
| require.NoError(t, err) | ||
| require.Equal(t, class, "troubleshooting") | ||
| }) | ||
|
|
||
| t.Run("Valid class starting with upper-case and ending with dot", func(t *testing.T) { | ||
| class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) | ||
| require.NoError(t, err) | ||
| require.Equal(t, class, "troubleshooting") | ||
| }) | ||
|
|
||
| t.Run("Model hallucinates", func(t *testing.T) { | ||
| class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses) | ||
| require.Error(t, err) | ||
| require.Empty(t, class) | ||
| }) | ||
| } | ||
|
|
||
| type apiKeyMock struct{} | ||
|
|
||
| // GetOpenAIAPIKey returns a mock API key. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| /* | ||
| Copyright 2023 Gravitational, Inc. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| */ | ||
|
|
||
| package assist | ||
|
|
||
| // MessageClasses contains the types of Assist messages we expect users to send. | ||
| // When running on Cloud, we attempt to classify user messages into one of those | ||
| // categories. If this succeeds, we send an event into the analytics pipeline. | ||
| // | ||
| // Keys are the category names; they are the values generated by the model and | ||
| // reported in the event. Values are the category descriptions. They are used to | ||
| // build the model prompt and give the model more context to improve the | ||
| // classification. | ||
| var MessageClasses = map[string]string{ | ||
| "command execution": "the user want to execute a command on one or many servers", | ||
| "troubleshooting": "the user wants to diagnose a problem or understand an error message", | ||
| "configuration": "the user wants to generate configuration for a software which is not Teleport", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How can the LLM know if the software is "Teleport"? Also, do we even care about Teleport/non-Teleport scenario?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The prompt contains a super-short Teleport description. From what I observed, this is enough for gpt-4 to reliably infer whether the request is about Teleport or not (gpt-3.5 showed some issues but was right most of the time). I suspect gpt-4 also learned what Teleport is and can leverage this information when classifying. I think it is important to separate Teleport-related requests from non-Teleport-related requests, as those two categories are not actionable in the same way. We can help the model answer Teleport configuration requests by giving it access to the docs and allowing it to link the user back to the docs. On the other hand, we rely on the generic part of the model to answer most other configuration questions (embedding man pages might help, but this is a long shot). From what I understand, we will want to know if users are asking "write me a working nginx configuration" or if they are asking "I want the teleport agent configuration to provide access to an SSH server and a database at the same time". |
||
| "manage resources": "the user wants to list/add/remove/edit resources connected to the Teleport cluster", | ||
| "access request": "the user requests access to one or many resources from the Teleport cluster", | ||
| "teleport setup": "the user wants help with its Teleport cluster, like setting up a new feature or knowing if something is feasible", | ||
| "other": "the user asks a question which is not IT nor Teleport-related", | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -308,6 +308,42 @@ func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID st | |
| return ws, nil | ||
| } | ||
|
|
||
| func Test_generateAssistantTitle(t *testing.T) { | ||
| // Test setup | ||
| t.Parallel() | ||
| ctx := context.Background() | ||
|
|
||
| responses := []string{"This is the message summary.", "troubleshooting"} | ||
| server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) | ||
| t.Cleanup(server.Close) | ||
|
|
||
| openaiCfg := openai.DefaultConfig("test-token") | ||
| openaiCfg.BaseURL = server.URL | ||
| s := newWebSuiteWithConfig(t, webSuiteConfig{ | ||
| ClusterFeatures: &authproto.Features{ | ||
| Cloud: true, | ||
| }, | ||
| OpenAIConfig: &openaiCfg, | ||
| }) | ||
|
|
||
| pack := s.authPack(t, "foo") | ||
|
|
||
| // Real test: we craft a request asking for a summary | ||
| endpoint := pack.clt.Endpoint("webapi", "assistant", "title", "summary") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm pretty sure that we have a wrapper for that. If not, we should probably add it. |
||
| req := generateAssistantTitleRequest{Message: "This is a test user message asking Teleport assist to do something."} | ||
|
|
||
| // Executing the request and validating the output is as expected | ||
| resp, err := pack.clt.PostJSON(ctx, endpoint, &req) | ||
| require.NoError(t, err) | ||
|
|
||
| var info conversationInfo | ||
| body, err := io.ReadAll(resp.Reader()) | ||
| require.NoError(t, err) | ||
| err = json.Unmarshal(body, &info) | ||
| require.NoError(t, err) | ||
| require.NotEmpty(t, info.Title) | ||
| } | ||
|
|
||
| // generateTextResponse generates a response for a text completion | ||
| func generateTextResponse() string { | ||
| return "```" + `json | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.