Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (chat *Chat) GetMessages() []openai.ChatCompletionMessage {
// Message types:
// - CompletionCommand: a command from the assistant
// - Message: a text message from the assistant
func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) {
func (chat *Chat) Complete(ctx context.Context, userInput string, progressUpdates func(*model.AgentAction)) (any, error) {
// if the chat is empty, return the initial response we predefine instead of querying GPT-4
if len(chat.messages) == 1 {
return &model.Message{
Expand All @@ -71,7 +71,7 @@ func (chat *Chat) Complete(ctx context.Context, userInput string) (any, error) {
Content: userInput,
}

response, err := chat.agent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage)
response, err := chat.agent.PlanAndExecute(ctx, chat.client.svc, chat.messages, userMessage, progressUpdates)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
146 changes: 100 additions & 46 deletions lib/ai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ package ai

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/ai/model"
aitest "github.com/gravitational/teleport/lib/ai/testutils"
)

func TestChat_PromptTokens(t *testing.T) {
Expand All @@ -49,7 +51,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Hello",
},
},
want: 743,
want: 697,
},
{
name: "system and user messages",
Expand All @@ -63,7 +65,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Hi LLM.",
},
},
want: 751,
want: 705,
},
{
name: "tokenize our prompt",
Expand All @@ -77,7 +79,7 @@ func TestChat_PromptTokens(t *testing.T) {
Content: "Show me free disk space on localhost node.",
},
},
want: 954,
want: 908,
},
}

Expand All @@ -89,7 +91,17 @@ func TestChat_PromptTokens(t *testing.T) {
responses := []string{
generateCommandResponse(),
}
server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]
_, err := w.Write([]byte(dataBytes))
require.NoError(t, err, "Write error")

responses = responses[1:]
}))

t.Cleanup(server.Close)

cfg := openai.DefaultConfig("secret-test-token")
Expand All @@ -103,7 +115,7 @@ func TestChat_PromptTokens(t *testing.T) {
}

ctx := context.Background()
message, err := chat.Complete(ctx, "")
message, err := chat.Complete(ctx, "", func(aa *model.AgentAction) {})
require.NoError(t, err)
msg, ok := message.(interface{ UsedTokens() *model.TokensUsed })
require.True(t, ok)
Expand All @@ -117,51 +129,49 @@ func TestChat_PromptTokens(t *testing.T) {
func TestChat_Complete(t *testing.T) {
t.Parallel()

responses := []string{
generateTextResponse(),
generateCommandResponse(),
responses := [][]byte{
[]byte(generateTextResponse()),
[]byte(generateCommandResponse()),
}
server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")

responses = responses[1:]
}))
defer server.Close()

cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL + "/v1"
client := NewClientFromConfig(cfg)

chat := client.NewChat(nil, "Bob")

t.Run("initial message", func(t *testing.T) {
msgAny, err := chat.Complete(context.Background(), "Hello")
require.NoError(t, err)

msg, ok := msgAny.(*model.Message)
require.True(t, ok)
ctx := context.Background()
_, err := chat.Complete(ctx, "Hello", func(aa *model.AgentAction) {})
require.NoError(t, err)

expectedResp := &model.Message{
Content: "Hey, I'm Teleport - a powerful tool that can assist you in managing your Teleport cluster via OpenAI GPT-4.",
}
require.Equal(t, expectedResp.Content, msg.Content)
require.NotNil(t, msg.TokensUsed)
})
chat.Insert(openai.ChatMessageRoleUser, "Show me free disk space on localhost node.")

t.Run("text completion", func(t *testing.T) {
chat.Insert(openai.ChatMessageRoleUser, "Show me free disk space")

msg, err := chat.Complete(context.Background(), "")
msg, err := chat.Complete(ctx, "Show me free disk space", func(aa *model.AgentAction) {})
require.NoError(t, err)

require.IsType(t, &model.Message{}, msg)
streamingMessage := msg.(*model.Message)

const expectedResponse = "Which node do you want use?"

require.Equal(t, expectedResponse, streamingMessage.Content)
require.IsType(t, &model.StreamingMessage{}, msg)
streamingMessage := msg.(*model.StreamingMessage)
require.Equal(t, "Which ", <-streamingMessage.Parts)
require.Equal(t, "node do ", <-streamingMessage.Parts)
require.Equal(t, "you want ", <-streamingMessage.Parts)
require.Equal(t, "use?", <-streamingMessage.Parts)
})

t.Run("command completion", func(t *testing.T) {
chat.Insert(openai.ChatMessageRoleUser, "localhost")

msg, err := chat.Complete(context.Background(), "")
msg, err := chat.Complete(ctx, "localhost", func(aa *model.AgentAction) {})
require.NoError(t, err)

require.IsType(t, &model.CompletionCommand{}, msg)
Expand All @@ -174,20 +184,64 @@ func TestChat_Complete(t *testing.T) {

// generateTextResponse generates a response for a text completion
func generateTextResponse() string {
return "```" + `json
{
"action": "Final Answer",
"action_input": "Which node do you want use?"
}
` + "```"
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)

data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "<FINAL RESPONSE>Which ", "role": "assistant"}}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("event: message\n")...)

data = `{"id":"2","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "node do ", "role": "assistant"}}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("event: message\n")...)

data = `{"id":"3","object":"completion","created":1598069255,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "you want ", "role": "assistant"}}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("event: message\n")...)

data = `{"id":"4","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":{"content": "use?", "role": "assistant"}}]}`
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
dataBytes = append(dataBytes, []byte("event: done\n")...)

dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

return string(dataBytes)
}

// generateCommandResponse generates a response for the command "df -h" on the node "localhost"
func generateCommandResponse() string {
return "```" + `json
{
"action": "Command Execution",
"action_input": "{\"command\":\"df -h\",\"nodes\":[\"localhost\"],\"labels\":[]}"
dataBytes := []byte{}
dataBytes = append(dataBytes, []byte("event: message\n")...)

actionObj := model.PlanOutput{
Action: "Command Execution",
ActionInput: struct {
Command string `json:"command"`
Nodes []string `json:"nodes"`
}{"df -h", []string{"localhost"}},
}
actionJson, err := json.Marshal(actionObj)
if err != nil {
panic(err)
}
` + "```"

obj := struct {
Content string `json:"content"`
Role string `json:"role"`
}{
Content: string(actionJson),
Role: "assistant",
}
json, err := json.Marshal(obj)
if err != nil {
panic(err)
}

data := fmt.Sprintf(`{"id":"1","object":"completion","created":1598069254,"model":"gpt-4","choices":[{"index": 0, "delta":%v}]}`, string(json))
dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: done\n")...)
dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)

return string(dataBytes)
}
2 changes: 1 addition & 1 deletion lib/ai/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (e *EmbeddingProcessor) Run(ctx context.Context, initialDelay, period time.
}

func (e *EmbeddingProcessor) process(ctx context.Context) {
batch := NewBatchReducer[*nodeStringPair, []*Embedding](e.mapProcessFn,
batch := NewBatchReducer(e.mapProcessFn,
maxEmbeddingAPISize, // Max batch size allowed by OpenAI API,
)

Expand Down
Loading