Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions lib/ai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/tiktoken-go/tokenizer"

"github.com/gravitational/teleport/lib/ai/model"
"github.com/gravitational/teleport/lib/ai/model/output"
Expand All @@ -30,10 +29,9 @@ import (

// Chat represents a conversation between a user and an assistant with context memory.
type Chat struct {
	// client is the backing AI client used to issue completion requests.
	client *Client
	// messages is the running conversation history, in OpenAI message format.
	messages []openai.ChatCompletionMessage
	// agent plans and executes the assistant's responses for this conversation.
	agent *model.Agent
}

// Insert inserts a message into the conversation. Returns the index of the message.
Expand Down
4 changes: 2 additions & 2 deletions lib/ai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ func TestChat_Complete_AuditQuery(t *testing.T) {
require.NoError(t, err)

// We check that the agent returns the expected response
message, ok := result.(*output.Message)
message, ok := result.(*output.StreamingMessage)
require.True(t, ok)
require.Equal(t, generatedQuery, message.Content)
require.Equal(t, generatedQuery, message.WaitAndConsume())
}
11 changes: 2 additions & 9 deletions lib/ai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/tiktoken-go/tokenizer/codec"

"github.com/gravitational/teleport/lib/ai/embedding"
"github.com/gravitational/teleport/lib/ai/model"
Expand Down Expand Up @@ -75,10 +74,7 @@ func (client *Client) NewChat(toolContext *modeltools.ToolContext) *Chat {
Content: model.PromptCharacter(toolContext.User),
},
},
// Initialize a tokenizer for prompt token accounting.
// Cl100k is used by GPT-3 and GPT-4.
tokenizer: codec.NewCl100kBase(),
agent: model.NewAgent(toolContext, tools...),
agent: model.NewAgent(toolContext, tools...),
}
}

Expand All @@ -92,10 +88,7 @@ func (client *Client) NewCommand(username string) *Chat {
Content: model.PromptCharacter(username),
},
},
// Initialize a tokenizer for prompt token accounting.
// Cl100k is used by GPT-3 and GPT-4.
tokenizer: codec.NewCl100kBase(),
agent: model.NewAgent(toolContext, &modeltools.CommandGenerationTool{}),
agent: model.NewAgent(toolContext, &modeltools.CommandGenerationTool{}),
}
}

Expand Down
44 changes: 5 additions & 39 deletions lib/ai/model/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ package model
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"time"

Expand Down Expand Up @@ -266,14 +264,12 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
}

log.Tracef("Tool chose to query table '%s'", tableName)
query, err := tool.GenerateQuery(ctx, tableName, action.Input, state.tokenCount)
response, err := tool.GenerateQuery(ctx, tableName, action.Input, state.tokenCount)
if err != nil {
return stepOutput{}, trace.Wrap(err)
}
log.Tracef("Tool generated query: %s", query)

completion := &output.Message{Content: query}
return stepOutput{finish: &agentFinish{output: completion}}, nil
return stepOutput{finish: &agentFinish{output: response}}, nil
default:
runOut, err := tool.Run(ctx, a.toolCtx, action.Input)
if err != nil {
Expand Down Expand Up @@ -305,23 +301,7 @@ func (a *Agent) plan(ctx context.Context, state *executionState) (*AgentAction,
return nil, nil, trace.Wrap(err)
}

deltas := make(chan string)
go func() {
defer close(deltas)

for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
return
} else if err != nil {
log.Tracef("agent encountered an error while streaming: %v", err)
return
}

delta := response.Choices[0].Delta.Content
deltas <- delta
}
}()
deltas := output.StreamToDeltas(stream)

action, finish, completionTokenCounter, err := parsePlanningOutput(deltas)
state.tokenCount.AddCompletionCounter(completionTokenCounter)
Expand Down Expand Up @@ -391,25 +371,11 @@ func parsePlanningOutput(deltas <-chan string) (*AgentAction, *agentFinish, toke
text += delta

if strings.HasPrefix(text, finalResponseHeader) {
parts := make(chan string)
streamingTokenCounter, err := tokens.NewAsynchronousTokenCounter(text)
message, tc, err := output.NewStreamingMessage(deltas, text, finalResponseHeader)
if err != nil {
return nil, nil, nil, trace.Wrap(err)
}
go func() {
defer close(parts)

parts <- strings.TrimPrefix(text, finalResponseHeader)
for delta := range deltas {
parts <- delta
errCount := streamingTokenCounter.Add()
if errCount != nil {
log.WithError(errCount).Debug("Failed to add streamed completion text to the token counter")
}
}
}()

return nil, &agentFinish{output: &output.StreamingMessage{Parts: parts}}, streamingTokenCounter, nil
return nil, &agentFinish{output: message}, tc, nil
}
}

Expand Down
12 changes: 12 additions & 0 deletions lib/ai/model/output/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package output

import "strings"

// Message represents a new message within a live conversation.
type Message struct {
Content string
Expand All @@ -26,6 +28,16 @@ type StreamingMessage struct {
Parts <-chan string
}

// WaitAndConsume waits until the message stream is over and returns the full message.
// This can only be called once on a message as it empties its Parts channel.
func (msg *StreamingMessage) WaitAndConsume() string {
sb := strings.Builder{}
for part := range msg.Parts {
Comment thread
hugoShaka marked this conversation as resolved.
sb.WriteString(part)
}
return sb.String()
}

// Label represents a label returned by OpenAI's completion API.
type Label struct {
Key string `json:"key"`
Expand Down
81 changes: 81 additions & 0 deletions lib/ai/model/output/streaming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package output
Comment thread
hugoShaka marked this conversation as resolved.

import (
"errors"
"io"
"strings"

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport/lib/ai/tokens"
)

// StreamToDeltas converts an openai.CompletionStream into a channel of strings.
// This channel can then be consumed manually to search for specific markers,
// or directly converted into a StreamingMessage with NewStreamingMessage.
func StreamToDeltas(stream *openai.ChatCompletionStream) chan string {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this function need to be public? For me, it sounds like some internal helper that should not be called outside of the module.

Copy link
Copy Markdown
Contributor Author

@hugoShaka hugoShaka Aug 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lib/ai/model and lib/ai/model/tools need to do this and are in separate packages. I don't think we need to expose this outside of lib/ai/model/... though. lib/ai/model manually processes the deltas, while lib/ai/model/tools can create the streaming message directly, so we can't make a single StreamtoStramingMessage() function

Copy link
Copy Markdown
Contributor Author

@hugoShaka hugoShaka Aug 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can create a new lib/ai/model/utils package, or a lib/ai/model/internal one

deltas := make(chan string)
go func() {
defer close(deltas)

for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
return
} else if err != nil {
log.Tracef("agent encountered an error while streaming: %v", err)
return
}

delta := response.Choices[0].Delta.Content
deltas <- delta
}
}()
return deltas
}

// NewStreamingMessage takes a string channel and converts it to
// a StreamingMessage.
// If content was already streamed, it must be passed through the alreadyStreamed parameter.
// If the already streamed content contains a prefix that must be stripped
// (like a marker to identify the kind of response the model is providing),
// the prefix can be passed through the prefix parameter. It will be stripped
// but will still be reflected in the token count.
func NewStreamingMessage(deltas <-chan string, alreadyStreamed, prefix string) (*StreamingMessage, *tokens.AsynchronousTokenCounter, error) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same here, it sounds more like an internal helper rather than something that can be used outside of this module.

parts := make(chan string)
streamingTokenCounter, err := tokens.NewAsynchronousTokenCounter(alreadyStreamed)
if err != nil {
return nil, nil, trace.Wrap(err)
}
go func() {
defer close(parts)

parts <- strings.TrimPrefix(alreadyStreamed, prefix)
for delta := range deltas {
parts <- delta
errCount := streamingTokenCounter.Add()
if errCount != nil {
log.WithError(errCount).Debug("Failed to add streamed completion text to the token counter")
}
}
}()
return &StreamingMessage{Parts: parts}, streamingTokenCounter, nil
}
22 changes: 12 additions & 10 deletions lib/ai/model/tools/auditquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/sashabaranov/go-openai"

"github.com/gravitational/teleport/gen/go/eventschema"
"github.com/gravitational/teleport/lib/ai/model/output"
"github.com/gravitational/teleport/lib/ai/tokens"
)

Expand Down Expand Up @@ -114,14 +115,14 @@ You MUST RESPOND ONLY with a single table name. If no table can answer the quest

// GenerateQuery takes an event type, fetches its schema, and calls the LLM to
// generate SQL and answer the user query.
func (t *AuditQueryGenerationTool) GenerateQuery(ctx context.Context, eventType, input string, tc *tokens.TokenCount) (string, error) {
func (t *AuditQueryGenerationTool) GenerateQuery(ctx context.Context, eventType, input string, tc *tokens.TokenCount) (*output.StreamingMessage, error) {
eventSchema, err := eventschema.GetEventSchemaFromType(eventType)
if err != nil {
return "", trace.Wrap(err)
return nil, trace.Wrap(err)
}
tableSchema, err := eventSchema.TableSchema()
if err != nil {
return "", trace.Wrap(err)
return nil, trace.Wrap(err)
}

prompt := []openai.ChatCompletionMessage{
Expand All @@ -144,28 +145,29 @@ Today's date is DATE('%s')`, time.Now().Format("2006-01-02")),
}
promptTokens, err := tokens.NewPromptTokenCounter(prompt)
if err != nil {
return "", trace.Wrap(err)
return nil, trace.Wrap(err)
}
tc.AddPromptCounter(promptTokens)

response, err := t.LLM.CreateChatCompletion(
stream, err := t.LLM.CreateChatCompletionStream(
ctx,
openai.ChatCompletionRequest{
Model: openai.GPT4,
Messages: prompt,
Temperature: 0,
Stream: true,
},
)
if err != nil {
return "", trace.Wrap(err)
return nil, trace.Wrap(err)
}

completion := response.Choices[0].Message.Content
completionTokens, err := tokens.NewSynchronousTokenCounter(completion)
deltas := output.StreamToDeltas(stream)
message, completionTokens, err := output.NewStreamingMessage(deltas, "", "")
if err != nil {
return "", trace.Wrap(err)
return nil, trace.Wrap(err)
}
tc.AddCompletionCounter(completionTokens)

return completion, nil
return message, nil
}