Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/daviddengcn/go-colortext v1.0.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.9.0 // indirect
github.com/docker/distribution v2.8.1+incompatible // indirect
github.com/dvsekhvalnov/jose2go v1.5.0 // indirect
github.com/elastic/elastic-transport-go/v8 v8.1.0 // indirect
Expand Down Expand Up @@ -332,12 +333,14 @@ require (
github.com/rs/zerolog v1.28.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 // indirect
github.com/sashabaranov/go-openai v1.9.3
github.com/shabbyrobe/gocovmerge v0.0.0-20190829150210-3e036491d500 // indirect
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 // indirect
github.com/siddontang/go-log v0.0.0-20180807004314-8d05993dda07 // indirect
github.com/spf13/cobra v1.6.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/thales-e-security/pool v0.0.2 // indirect
github.com/tiktoken-go/tokenizer v0.1.0
github.com/x448/float16 v0.8.4 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.1 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
github.com/dlclark/regexp2 v1.9.0 h1:pTK/l/3qYIKaRXuHnEnIf7Y5NxfRPfpb7dis6/gdlVI=
github.com/dlclark/regexp2 v1.9.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
Expand Down Expand Up @@ -1166,6 +1168,8 @@ github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 h1:GHRpF1pTW19a8tTFrMLUcfWwyC0pnifVo2ClaLq+hP8=
github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8=
github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/sashabaranov/go-openai v1.9.3 h1:uNak3Rn5pPsKRs9bdT7RqRZEyej/zdZOEI2/8wvrFtM=
github.com/sashabaranov/go-openai v1.9.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/schollz/progressbar/v3 v3.13.0 h1:9TeeWRcjW2qd05I8Kf9knPkW4vLM/hYoa6z9ABvxje8=
github.com/schollz/progressbar/v3 v3.13.0/go.mod h1:ZBYnSuLAX2LU8P8UiKN/KgF2DY58AJC8yfVYLPC8Ly4=
Expand Down Expand Up @@ -1235,6 +1239,8 @@ github.com/thales-e-security/pool v0.0.2 h1:RAPs4q2EbWsTit6tpzuvTFlgFRJ3S8Evf5gt
github.com/thales-e-security/pool v0.0.2/go.mod h1:qtpMm2+thHtqhLzTwgDBj/OuNnMpupY8mv0Phz0gjhU=
github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tiktoken-go/tokenizer v0.1.0 h1:c1fXriHSR/NmhMDTwUDLGiNhHwTV+ElABGvqhCWLRvY=
github.com/tiktoken-go/tokenizer v0.1.0/go.mod h1:7SZW3pZUKWLJRilTvWCa86TOVIiiJhYj3FQ5V3alWcg=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20200427203606-3cfed13b9966/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
Expand Down
229 changes: 229 additions & 0 deletions lib/ai/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai

import (
"context"
"encoding/json"
"errors"
"io"
"strings"

"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/tiktoken-go/tokenizer"
)

// maxResponseTokens caps how many tokens the model may generate for a single
// completion request (passed as MaxTokens on streaming requests).
const maxResponseTokens = 2000

// Chat represents a conversation between a user and an assistant with context memory.
type Chat struct {
	// client is the OpenAI API client used to issue completion requests.
	client *Client
	// messages is the running conversation log; it is sent in full as
	// context on every completion request.
	messages []openai.ChatCompletionMessage
	// tokenizer is used by PromptTokens to count the tokens in the prompt.
	tokenizer tokenizer.Codec
}

// Insert appends a message with the given role and content to the conversation
// log. It is typically used for user input, but system instruction messages are
// inserted the same way. The returned Message carries the position of the new
// entry in the log.
func (chat *Chat) Insert(role string, content string) Message {
	entry := openai.ChatCompletionMessage{
		Role:    role,
		Content: content,
	}
	chat.messages = append(chat.messages, entry)

	return Message{
		Role:    role,
		Content: content,
		Idx:     len(chat.messages) - 1,
	}
}

// PromptTokens uses the chat's tokenizer to calculate
// the total number of tokens in the prompt.
//
// Each message contributes its encoded content length plus a fixed per-message
// and per-role overhead, and the request itself adds a constant overhead.
//
// Ref: https://github.com/openai/openai-cookbook/blob/594fc6c952425810e9ea5bd1a275c8ca5f32e8f9/examples/How_to_count_tokens_with_tiktoken.ipynb
func (chat *Chat) PromptTokens() (int, error) {
	const (
		// perRequest is the number of tokens used up for each completion request.
		perRequest = 3
		// perRole is the number of tokens used to encode a message's role.
		perRole = 1
		// perMessage is the token "overhead" for each message.
		perMessage = 3
	)

	total := perRequest
	for _, msg := range chat.messages {
		encoded, _, err := chat.tokenizer.Encode(msg.Content)
		if err != nil {
			return 0, trace.Wrap(err)
		}
		total += len(encoded) + perRole + perMessage
	}

	return total, nil
}

// Summary creates a short summary for the given input by asking the model to
// generate a title with the promptSummarizeTitle system instruction.
// It returns an error if the request fails or the API returns no choices.
func (chat *Chat) Summary(ctx context.Context, message string) (string, error) {
	resp, err := chat.client.svc.CreateChatCompletion(
		ctx,
		openai.ChatCompletionRequest{
			Model: openai.GPT4,
			Messages: []openai.ChatCompletionMessage{
				{Role: openai.ChatMessageRoleSystem, Content: promptSummarizeTitle},
				{Role: openai.ChatMessageRoleUser, Content: message},
			},
		},
	)
	if err != nil {
		return "", trace.Wrap(err)
	}

	// Guard against an empty Choices slice; indexing it blindly would panic.
	if len(resp.Choices) == 0 {
		return "", trace.Errorf("completion API returned no choices")
	}

	return resp.Choices[0].Message.Content, nil
}

// Complete completes the conversation with a message from the assistant based on the current context.
// On success, it returns the message and the number of tokens used for the completion.
// Returned types:
//   - *Message: a plain (or canned initial) message from the assistant
//   - *CompletionCommand: a command parsed from a JSON payload, with NumTokens set
//   - *StreamingMessage: a message whose content is streamed via the Chunks channel
func (chat *Chat) Complete(ctx context.Context) (any, error) {
	var numTokens int

	// If the chat is empty, return the initial response we predefine instead
	// of querying GPT-4.
	if len(chat.messages) == 1 {
		return &Message{
			Role:    openai.ChatMessageRoleAssistant,
			Content: initialAIResponse,
			Idx:     len(chat.messages) - 1,
		}, nil
	}

	// Copy the current chat log to a new slice and append the suffix
	// instruction that asks the model to emit a structured payload when a
	// command is intended.
	messages := make([]openai.ChatCompletionMessage, len(chat.messages)+1)
	copy(messages, chat.messages)
	messages[len(messages)-1] = openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleUser,
		Content: promptExtractInstruction,
	}

	// Create a streaming completion request; we stream optimistically so the
	// response can be forwarded as it arrives when it is not a JSON payload.
	stream, err := chat.client.svc.CreateChatCompletionStream(
		ctx,
		openai.ChatCompletionRequest{
			Model:     openai.GPT4,
			Messages:  messages,
			MaxTokens: maxResponseTokens,
			Stream:    true,
		},
	)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	var (
		response openai.ChatCompletionStreamResponse
		trimmed  string
	)
	// Skip leading whitespace-only deltas; the first non-blank delta tells us
	// whether the response looks like a JSON payload.
	for trimmed == "" {
		response, err = stream.Recv()
		if err != nil {
			// Release the underlying HTTP response body on early exit.
			stream.Close()
			return nil, trace.Wrap(err)
		}
		numTokens++

		// Defensive: skip chunks that carry no choices instead of
		// panicking on Choices[0].
		if len(response.Choices) == 0 {
			continue
		}
		trimmed = strings.TrimSpace(response.Choices[0].Delta.Content)
	}

	// If it looks like a JSON payload, wait for the entire response and try
	// to parse it.
	if strings.HasPrefix(trimmed, "{") {
		// This branch consumes the stream fully, so close it on return.
		defer stream.Close()

		payload := strings.Builder{}
		payload.WriteString(response.Choices[0].Delta.Content)

		for {
			response, err := stream.Recv()
			if errors.Is(err, io.EOF) {
				break
			}
			if err != nil {
				return nil, trace.Wrap(err)
			}
			numTokens++

			if len(response.Choices) == 0 {
				continue
			}
			payload.WriteString(response.Choices[0].Delta.Content)
		}

		// If we can parse it, return the parsed payload, otherwise return a
		// non-streaming message with the accumulated text.
		var c CompletionCommand
		if err := json.Unmarshal([]byte(payload.String()), &c); err == nil {
			c.NumTokens = numTokens
			return &c, nil
		}
		return &Message{
			Role:      openai.ChatMessageRoleAssistant,
			Content:   payload.String(),
			Idx:       len(chat.messages) - 1,
			NumTokens: numTokens,
		}, nil
	}

	// If it doesn't look like a JSON payload, return a streaming message to
	// the caller. chunks is buffered so the first delta can be queued before
	// the caller starts receiving.
	chunks := make(chan string, 1)
	errCh := make(chan error)
	chunks <- response.Choices[0].Delta.Content
	go func() {
		defer close(chunks)
		// The goroutine owns the stream from here on; close it when the
		// stream is drained, fails, or the context is canceled.
		defer stream.Close()

		for {
			response, err := stream.Recv()
			switch {
			case errors.Is(err, io.EOF):
				return
			case err != nil:
				select {
				case <-ctx.Done():
				case errCh <- trace.Wrap(err):
				}
				return
			}

			if len(response.Choices) == 0 {
				continue
			}

			select {
			case chunks <- response.Choices[0].Delta.Content:
			case <-ctx.Done():
				return
			}
		}
	}()

	return &StreamingMessage{
		Role:   openai.ChatMessageRoleAssistant,
		Idx:    len(chat.messages) - 1,
		Chunks: chunks,
		Error:  errCh,
	}, nil
}
Loading