diff --git a/lib/ai/chat.go b/lib/ai/chat.go
index 12e4347b83eb0..0ab5aa061adf1 100644
--- a/lib/ai/chat.go
+++ b/lib/ai/chat.go
@@ -21,7 +21,6 @@ import (
 	"github.com/gravitational/trace"
 	"github.com/sashabaranov/go-openai"
-	"github.com/tiktoken-go/tokenizer"
 
 	"github.com/gravitational/teleport/lib/ai/model"
 	"github.com/gravitational/teleport/lib/ai/model/output"
@@ -30,10 +29,9 @@ import (
 
 // Chat represents a conversation between a user and an assistant with context memory.
 type Chat struct {
-	client    *Client
-	messages  []openai.ChatCompletionMessage
-	tokenizer tokenizer.Codec
-	agent     *model.Agent
+	client   *Client
+	messages []openai.ChatCompletionMessage
+	agent    *model.Agent
 }
 
 // Insert inserts a message into the conversation. Returns the index of the message.
diff --git a/lib/ai/chat_test.go b/lib/ai/chat_test.go
index af0d38883cd42..064a02ae59807 100644
--- a/lib/ai/chat_test.go
+++ b/lib/ai/chat_test.go
@@ -356,7 +356,7 @@ func TestChat_Complete_AuditQuery(t *testing.T) {
 	require.NoError(t, err)
 
 	// We check that the agent returns the expected response
-	message, ok := result.(*output.Message)
+	message, ok := result.(*output.StreamingMessage)
 	require.True(t, ok)
-	require.Equal(t, generatedQuery, message.Content)
+	require.Equal(t, generatedQuery, message.WaitAndConsume())
 }
diff --git a/lib/ai/client.go b/lib/ai/client.go
index 6f7be50ecf2b1..532fa935fe65b 100644
--- a/lib/ai/client.go
+++ b/lib/ai/client.go
@@ -21,7 +21,6 @@ import (
 	"github.com/gravitational/trace"
 	"github.com/sashabaranov/go-openai"
-	"github.com/tiktoken-go/tokenizer/codec"
 
 	"github.com/gravitational/teleport/lib/ai/embedding"
 	"github.com/gravitational/teleport/lib/ai/model"
@@ -75,10 +74,7 @@ func (client *Client) NewChat(toolContext *modeltools.ToolContext) *Chat {
 			Content: model.PromptCharacter(toolContext.User),
 		},
 	},
-	// Initialize a tokenizer for prompt token accounting.
-	// Cl100k is used by GPT-3 and GPT-4.
-	tokenizer: codec.NewCl100kBase(),
-	agent:     model.NewAgent(toolContext, tools...),
+	agent: model.NewAgent(toolContext, tools...),
 	}
 }
 
@@ -92,10 +88,7 @@ func (client *Client) NewCommand(username string) *Chat {
 			Content: model.PromptCharacter(username),
 		},
 	},
-	// Initialize a tokenizer for prompt token accounting.
-	// Cl100k is used by GPT-3 and GPT-4.
-	tokenizer: codec.NewCl100kBase(),
-	agent:     model.NewAgent(toolContext, &modeltools.CommandGenerationTool{}),
+	agent: model.NewAgent(toolContext, &modeltools.CommandGenerationTool{}),
 	}
 }
diff --git a/lib/ai/model/agent.go b/lib/ai/model/agent.go
index aaec4061e1eda..812b32208fafc 100644
--- a/lib/ai/model/agent.go
+++ b/lib/ai/model/agent.go
@@ -19,9 +19,7 @@ package model
 import (
 	"context"
 	"encoding/json"
-	"errors"
 	"fmt"
-	"io"
 	"strings"
 	"time"
 
@@ -266,14 +264,12 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
 		}
 
 		log.Tracef("Tool chose to query table '%s'", tableName)
-		query, err := tool.GenerateQuery(ctx, tableName, action.Input, state.tokenCount)
+		response, err := tool.GenerateQuery(ctx, tableName, action.Input, state.tokenCount)
 		if err != nil {
 			return stepOutput{}, trace.Wrap(err)
 		}
-		log.Tracef("Tool generated query: %s", query)
-		completion := &output.Message{Content: query}
-		return stepOutput{finish: &agentFinish{output: completion}}, nil
+		return stepOutput{finish: &agentFinish{output: response}}, nil
 
 	default:
 		runOut, err := tool.Run(ctx, a.toolCtx, action.Input)
 		if err != nil {
@@ -305,23 +301,7 @@ func (a *Agent) plan(ctx context.Context, state *executionState) (*AgentAction,
 		return nil, nil, trace.Wrap(err)
 	}
 
-	deltas := make(chan string)
-	go func() {
-		defer close(deltas)
-
-		for {
-			response, err := stream.Recv()
-			if errors.Is(err, io.EOF) {
-				return
-			} else if err != nil {
-				log.Tracef("agent encountered an error while streaming: %v", err)
-				return
-			}
-
-			delta := response.Choices[0].Delta.Content
-			deltas <- delta
-		}
-	}()
+	deltas := output.StreamToDeltas(stream)
 
 	action, finish, completionTokenCounter, err := parsePlanningOutput(deltas)
 	state.tokenCount.AddCompletionCounter(completionTokenCounter)
@@ -391,25 +371,11 @@ func parsePlanningOutput(deltas <-chan string) (*AgentAction, *agentFinish, toke
 		text += delta
 
 		if strings.HasPrefix(text, finalResponseHeader) {
-			parts := make(chan string)
-			streamingTokenCounter, err := tokens.NewAsynchronousTokenCounter(text)
+			message, tc, err := output.NewStreamingMessage(deltas, text, finalResponseHeader)
 			if err != nil {
 				return nil, nil, nil, trace.Wrap(err)
 			}
-			go func() {
-				defer close(parts)
-
-				parts <- strings.TrimPrefix(text, finalResponseHeader)
-				for delta := range deltas {
-					parts <- delta
-					errCount := streamingTokenCounter.Add()
-					if errCount != nil {
-						log.WithError(errCount).Debug("Failed to add streamed completion text to the token counter")
-					}
-				}
-			}()
-
-			return nil, &agentFinish{output: &output.StreamingMessage{Parts: parts}}, streamingTokenCounter, nil
+			return nil, &agentFinish{output: message}, tc, nil
 		}
 	}
diff --git a/lib/ai/model/output/messages.go b/lib/ai/model/output/messages.go
index 94f0c2c5244ce..e175fff98ca37 100644
--- a/lib/ai/model/output/messages.go
+++ b/lib/ai/model/output/messages.go
@@ -16,6 +16,8 @@
 
 package output
 
+import "strings"
+
 // Message represents a new message within a live conversation.
 type Message struct {
 	Content string
@@ -26,6 +28,16 @@ type StreamingMessage struct {
 	Parts <-chan string
 }
 
+// WaitAndConsume waits until the message stream is over and returns the full message.
+// This can only be called once on a message as it empties its Parts channel.
+func (msg *StreamingMessage) WaitAndConsume() string {
+	sb := strings.Builder{}
+	for part := range msg.Parts {
+		sb.WriteString(part)
+	}
+	return sb.String()
+}
+
 // Label represents a label returned by OpenAI's completion API.
 type Label struct {
 	Key string `json:"key"`
diff --git a/lib/ai/model/output/streaming.go b/lib/ai/model/output/streaming.go
new file mode 100644
index 0000000000000..783df5227f07f
--- /dev/null
+++ b/lib/ai/model/output/streaming.go
@@ -0,0 +1,81 @@
+/*
+ * Copyright 2023 Gravitational, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package output
+
+import (
+	"errors"
+	"io"
+	"strings"
+
+	"github.com/gravitational/trace"
+	"github.com/sashabaranov/go-openai"
+	log "github.com/sirupsen/logrus"
+
+	"github.com/gravitational/teleport/lib/ai/tokens"
+)
+
+// StreamToDeltas converts an openai.ChatCompletionStream into a channel of strings.
+// This channel can then be consumed manually to search for specific markers,
+// or directly converted into a StreamingMessage with NewStreamingMessage.
+func StreamToDeltas(stream *openai.ChatCompletionStream) chan string {
+	deltas := make(chan string)
+	go func() {
+		defer close(deltas)
+
+		for {
+			response, err := stream.Recv()
+			if errors.Is(err, io.EOF) {
+				return
+			} else if err != nil {
+				log.Tracef("agent encountered an error while streaming: %v", err)
+				return
+			}
+
+			delta := response.Choices[0].Delta.Content
+			deltas <- delta
+		}
+	}()
+	return deltas
+}
+
+// NewStreamingMessage takes a string channel and converts it to
+// a StreamingMessage.
+// If content was already streamed, it must be passed through the alreadyStreamed parameter.
+// If the already streamed content contains a prefix that must be stripped
+// (like a marker to identify the kind of response the model is providing),
+// the prefix can be passed through the prefix parameter. It will be stripped
+// but will still be reflected in the token count.
+func NewStreamingMessage(deltas <-chan string, alreadyStreamed, prefix string) (*StreamingMessage, *tokens.AsynchronousTokenCounter, error) {
+	parts := make(chan string)
+	streamingTokenCounter, err := tokens.NewAsynchronousTokenCounter(alreadyStreamed)
+	if err != nil {
+		return nil, nil, trace.Wrap(err)
+	}
+	go func() {
+		defer close(parts)
+
+		parts <- strings.TrimPrefix(alreadyStreamed, prefix)
+		for delta := range deltas {
+			parts <- delta
+			if addErr := streamingTokenCounter.Add(); addErr != nil {
+				log.WithError(addErr).Debug("Failed to add streamed completion text to the token counter")
+			}
+		}
+	}()
+	return &StreamingMessage{Parts: parts}, streamingTokenCounter, nil
+}
diff --git a/lib/ai/model/tools/auditquery.go b/lib/ai/model/tools/auditquery.go
index 63b8550966444..f60d855c416fc 100644
--- a/lib/ai/model/tools/auditquery.go
+++ b/lib/ai/model/tools/auditquery.go
@@ -26,6 +26,7 @@ import (
 	"github.com/sashabaranov/go-openai"
 
 	"github.com/gravitational/teleport/gen/go/eventschema"
+	"github.com/gravitational/teleport/lib/ai/model/output"
 	"github.com/gravitational/teleport/lib/ai/tokens"
 )
 
@@ -114,14 +115,14 @@ You MUST RESPOND ONLY with a single table name. If no table can answer the quest
 
 // GenerateQuery takes an event type, fetches its schema, and calls the LLM to
 // generate SQL and answer the user query.
-func (t *AuditQueryGenerationTool) GenerateQuery(ctx context.Context, eventType, input string, tc *tokens.TokenCount) (string, error) {
+func (t *AuditQueryGenerationTool) GenerateQuery(ctx context.Context, eventType, input string, tc *tokens.TokenCount) (*output.StreamingMessage, error) {
 	eventSchema, err := eventschema.GetEventSchemaFromType(eventType)
 	if err != nil {
-		return "", trace.Wrap(err)
+		return nil, trace.Wrap(err)
 	}
 	tableSchema, err := eventSchema.TableSchema()
 	if err != nil {
-		return "", trace.Wrap(err)
+		return nil, trace.Wrap(err)
 	}
 
 	prompt := []openai.ChatCompletionMessage{
@@ -144,28 +145,29 @@ Today's date is DATE('%s')`, time.Now().Format("2006-01-02")),
 	}
 
 	promptTokens, err := tokens.NewPromptTokenCounter(prompt)
 	if err != nil {
-		return "", trace.Wrap(err)
+		return nil, trace.Wrap(err)
 	}
 	tc.AddPromptCounter(promptTokens)
 
-	response, err := t.LLM.CreateChatCompletion(
+	stream, err := t.LLM.CreateChatCompletionStream(
 		ctx,
 		openai.ChatCompletionRequest{
 			Model:       openai.GPT4,
 			Messages:    prompt,
 			Temperature: 0,
+			Stream:      true,
 		},
 	)
 	if err != nil {
-		return "", trace.Wrap(err)
+		return nil, trace.Wrap(err)
 	}
 
-	completion := response.Choices[0].Message.Content
-	completionTokens, err := tokens.NewSynchronousTokenCounter(completion)
+	deltas := output.StreamToDeltas(stream)
+	message, completionTokens, err := output.NewStreamingMessage(deltas, "", "")
 	if err != nil {
-		return "", trace.Wrap(err)
+		return nil, trace.Wrap(err)
 	}
 	tc.AddCompletionCounter(completionTokens)
 
-	return completion, nil
+	return message, nil
 }
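
Outside the patch itself, a minimal usage sketch of how the pieces introduced here compose. The consumeStream function and its arguments are hypothetical; StreamToDeltas, NewStreamingMessage, WaitAndConsume, and AddCompletionCounter are the APIs added or exercised by this diff, so this mirrors what GenerateQuery now does internally.

package example

import (
	"github.com/gravitational/trace"
	"github.com/sashabaranov/go-openai"

	"github.com/gravitational/teleport/lib/ai/model/output"
	"github.com/gravitational/teleport/lib/ai/tokens"
)

// consumeStream is a hypothetical caller, not part of the patch. It shows
// the new flow end to end: the OpenAI stream becomes a channel of deltas,
// the deltas become a StreamingMessage with asynchronous token accounting,
// and the caller either forwards message.Parts or blocks on WaitAndConsume.
func consumeStream(stream *openai.ChatCompletionStream, tc *tokens.TokenCount) (string, error) {
	// Convert the OpenAI stream into a plain channel of text deltas.
	deltas := output.StreamToDeltas(stream)

	// Nothing was read off the stream before this point, so both
	// alreadyStreamed and prefix are empty, as in GenerateQuery above.
	message, completionTokens, err := output.NewStreamingMessage(deltas, "", "")
	if err != nil {
		return "", trace.Wrap(err)
	}

	// Register the asynchronous counter so completion tokens are counted
	// as the stream drains, mirroring tc.AddCompletionCounter in the diff.
	tc.AddCompletionCounter(completionTokens)

	// WaitAndConsume empties message.Parts and returns the assembled text;
	// it can only be called once per message.
	return message.WaitAndConsume(), nil
}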
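
A second sketch covers the prefix-handling path that parsePlanningOutput now delegates to NewStreamingMessage: text buffered while scanning for finalResponseHeader is passed as alreadyStreamed and the marker as prefix, so the marker is stripped from the emitted parts but still reflected in the token count. The marker value and channel contents below are fabricated for illustration, and the token counter is discarded here even though real callers register it.

package example_test

import (
	"fmt"

	"github.com/gravitational/teleport/lib/ai/model/output"
)

// Example_prefixStripping uses a hand-built channel (no OpenAI stream) to
// show that NewStreamingMessage trims the prefix from the buffered text
// while later deltas pass through unchanged.
func Example_prefixStripping() {
	// Hypothetical marker, standing in for finalResponseHeader.
	const header = "<FINAL RESPONSE>"

	deltas := make(chan string, 2)
	deltas <- " two"
	deltas <- " three"
	close(deltas)

	// "one" was already consumed from the stream together with the header
	// before the caller recognized the marker.
	message, _, err := output.NewStreamingMessage(deltas, header+"one", header)
	if err != nil {
		panic(err)
	}

	fmt.Println(message.WaitAndConsume())
	// Output: one two three
}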