Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions lib/ai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,23 @@ func (client *Client) CommandSummary(ctx context.Context, messages []openai.Chat

return resp.Choices[0].Message.Content, nil
}

// ClassifyMessage takes a user message, a list of categories, and uses the AI mode as a zero shot classifier.
func (client *Client) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) {
Comment thread
hugoShaka marked this conversation as resolved.
resp, err := client.svc.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4,
Messages: []openai.ChatCompletionMessage{
{Role: openai.ChatMessageRoleSystem, Content: model.MessageClassificationPrompt(classes)},
{Role: openai.ChatMessageRoleUser, Content: message},
},
},
)

if err != nil {
return "", trace.Wrap(err)
}

return resp.Choices[0].Message.Content, nil
}
15 changes: 15 additions & 0 deletions lib/ai/model/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,18 @@ func ConversationCommandResult(result map[string][]byte) string {
message.WriteString("Based on the chat history, extract relevant information out of the command output and write a summary.")
return message.String()
}

func MessageClassificationPrompt(classes map[string]string) string {
var classList strings.Builder
for name, description := range classes {
classList.WriteString(fmt.Sprintf("- `%s` (%s)\n", name, description))
}

return fmt.Sprintf(`Teleport is a tool that provides access to servers, kubernetes clusters, databases, and applications. All connected Teleport resources are called a cluster. Server resources might be called nodes.

Classify the provided message between the following categories:

%v

Answer only with the category name. Nothing else.`, classList.String())
}
18 changes: 18 additions & 0 deletions lib/assist/assist.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package assist
import (
"context"
"encoding/json"
"strings"
"time"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -172,6 +173,23 @@ func (c *Chat) reloadMessages(ctx context.Context) error {
return c.loadMessages(ctx)
}

// ClassifyMessage takes a user message, a list of categories, and uses the AI
// mode as a zero shot classifier. It returns an error if the classification
// result is not a valid class.
func (a *Assist) ClassifyMessage(ctx context.Context, message string, classes map[string]string) (string, error) {
category, err := a.client.ClassifyMessage(ctx, message, classes)
if err != nil {
return "", trace.Wrap(err)
}

cleanedCategory := strings.ToLower(strings.Trim(category, ". "))
if _, ok := classes[cleanedCategory]; ok {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd consider using strings.Contains instead of exact match to support some random strings added by ChatGPT maybe?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid this change would break if we create a category name containing another category name. It can also lead to false positives. From what I tested, error rates are quite low. I'd prefer to wait for real data before doing those kinds of optimizations.

return cleanedCategory, nil
}

return "", trace.CompareFailed("classification failed, category '%s' is not a valid classes", cleanedCategory)
}

// loadMessages loads the messages from the database.
func (c *Chat) loadMessages(ctx context.Context) error {
// existing conversation, retrieve old messages
Expand Down
45 changes: 45 additions & 0 deletions lib/assist/assist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,51 @@ func TestChatComplete(t *testing.T) {
})
}

func TestClassifyMessage(t *testing.T) {
// Given an OpenAI server that returns a response for a chat completion request.
responses := []string{
"troubleshooting",
"Troubleshooting",
"Troubleshooting.",
"non-existent",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you ever seen something like "As a Large Language Model, I think that the input can be classified as troubleshooting", or something similar? 😅

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the few dozen tests I ran, I never observed such a response. gpt-3.5 tend to add uppercase and dots, gpt-4 never made mistakes so far

}

server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL + "/v1"

// And a chat client.
ctx := context.Background()
client, err := NewClient(ctx, &mockPluginGetter{}, &apiKeyMock{}, &cfg)
require.NoError(t, err)

t.Run("Valid class", func(t *testing.T) {
class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses)
require.NoError(t, err)
require.Equal(t, class, "troubleshooting")
})

t.Run("Valid class starting with upper-case", func(t *testing.T) {
class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses)
require.NoError(t, err)
require.Equal(t, class, "troubleshooting")
})

t.Run("Valid class starting with upper-case and ending with dot", func(t *testing.T) {
class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses)
require.NoError(t, err)
require.Equal(t, class, "troubleshooting")
})

t.Run("Model hallucinates", func(t *testing.T) {
class, err := client.ClassifyMessage(ctx, "whatever", MessageClasses)
require.Error(t, err)
require.Empty(t, class)
})
}

type apiKeyMock struct{}

// GetOpenAIAPIKey returns a mock API key.
Expand Down
35 changes: 35 additions & 0 deletions lib/assist/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package assist

// MessageClasses contains type of assist message we expect users to send.
// When running on Cloud we attempt to classify user messages in one of those
// categories. If this succeeds, we send an event into the analytics pipeline.
//
// Keys are the category names, those are the ones reported in the event and
// generated by the model. Values are the category description. They are used to
// build the model prompt and allow to provide more context to the model to
// improve the classification.
var MessageClasses = map[string]string{
"command execution": "the user want to execute a command on one or many servers",
"troubleshooting": "the user wants to diagnose a problem or understand an error message",
"configuration": "the user wants to generate configuration for a software which is not Teleport",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can the LLM know if the software is "Teleport"? Also, do we even care about Teleport/non-Teleport scenario?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prompt contains a super-short Teleport description. From what I observed this is enough for gpt-4 to reliably infer if the request is about Teleport or not. (gpt-3.5 showed some issues but was right most of the time). I suspect gpt-4 also learned what Teleport is and can leverage this information when classifying.

I think this is important to separate Teleport-related requests from non-teleport related requests as those two categories are not actionable in the same way. We can help the model to answer Teleport configuration requests by giving it access to the docs and allowing it to link back the user to the docs. On the other hand, we rely on the generic part of the model to answer most other configuration questions (embedding man pages might help but this is a long shot).

From what I understand we will want to know if users are asking "write me a working nginx configuration" or if they are asking "I want the teleport agent configuration to provide access to a ssh server and a database at the same time".

"manage resources": "the user wants to list/add/remove/edit resources connected to the Teleport cluster",
"access request": "the user requests access to one or many resources from the Teleport cluster",
"teleport setup": "the user wants help with its Teleport cluster, like setting up a new feature or knowing if something is feasible",
"other": "the user asks a question which is not IT nor Teleport-related",
}
18 changes: 17 additions & 1 deletion lib/web/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/gravitational/teleport/lib/assist"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/reversetunnelclient"
)

Expand Down Expand Up @@ -250,7 +251,7 @@ type generateAssistantTitleRequest struct {
Message string `json:"message"`
}

// generateAssistantTitle is a handler for POST /webapi/assistant/conversations/:conversation_id/generate_title.
// generateAssistantTitle is a handler for POST /webapi/assistant/title/summary.
func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request,
_ httprouter.Params, sctx *SessionContext,
) (any, error) {
Expand Down Expand Up @@ -283,6 +284,21 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request,
Title: titleSummary,
}

// We only want to emmit
if modules.GetModules().Features().Cloud {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
class, err := client.ClassifyMessage(ctx, req.Message, assist.MessageClasses)
if err != nil {
return
}
h.log.Debugf("message classified as '%s'", class)
// TODO(shaka): emit event here to report the message class
}()

}

return conversationInfo, nil
}

Expand Down
36 changes: 36 additions & 0 deletions lib/web/assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,42 @@ func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack, conversationID st
return ws, nil
}

func Test_generateAssistantTitle(t *testing.T) {
// Test setup
t.Parallel()
ctx := context.Background()

responses := []string{"This is the message summary.", "troubleshooting"}
server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)

openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{
ClusterFeatures: &authproto.Features{
Cloud: true,
},
OpenAIConfig: &openaiCfg,
})

pack := s.authPack(t, "foo")

// Real test: we craft a request asking for a summary
endpoint := pack.clt.Endpoint("webapi", "assistant", "title", "summary")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure that we have a wrapper for that. If not, we should probably add it.

req := generateAssistantTitleRequest{Message: "This is a test user message asking Teleport assist to do something."}

// Executing the request and validating the output is as expected
resp, err := pack.clt.PostJSON(ctx, endpoint, &req)
require.NoError(t, err)

var info conversationInfo
body, err := io.ReadAll(resp.Reader())
require.NoError(t, err)
err = json.Unmarshal(body, &info)
require.NoError(t, err)
require.NotEmpty(t, info.Title)
}

// generateTextResponse generates a response for a text completion
func generateTextResponse() string {
return "```" + `json
Expand Down