diff --git a/lib/ai/chat.go b/lib/ai/chat.go index 33381637116ff..e4698054e8f2d 100644 --- a/lib/ai/chat.go +++ b/lib/ai/chat.go @@ -36,6 +36,8 @@ type Message struct { Role string `json:"role"` Content string `json:"content"` Idx int `json:"idx"` + // NumTokens is the number of completion tokens for the (non-streaming) message + NumTokens int `json:"-"` } // Chat represents a conversation between a user and an assistant with context memory. @@ -95,6 +97,8 @@ type CompletionCommand struct { Command string `json:"command,omitempty"` Nodes []string `json:"nodes,omitempty"` Labels []Label `json:"labels,omitempty"` + // NumTokens is the number of completion tokens for the (non-streaming) message + NumTokens int `json:"-"` } // Summary create a short summary for the given input. @@ -135,7 +139,7 @@ type StreamingMessage struct { // Complete completes the conversation with a message from the assistant based on the current context. // On success, it returns the message and the number of tokens used for the completion. -func (chat *Chat) Complete(ctx context.Context) (any, int, error) { +func (chat *Chat) Complete(ctx context.Context) (any, error) { var numTokens int // if the chat is empty, return the initial response we predefine instead of querying GPT-4 @@ -144,7 +148,7 @@ func (chat *Chat) Complete(ctx context.Context) (any, int, error) { Role: openai.ChatMessageRoleAssistant, Content: initialAIResponse, Idx: len(chat.messages) - 1, - }, numTokens, nil + }, nil } // if not, copy the current chat log to a new slice and append the suffix instruction @@ -167,7 +171,7 @@ func (chat *Chat) Complete(ctx context.Context) (any, int, error) { }, ) if err != nil { - return nil, numTokens, trace.Wrap(err) + return nil, trace.Wrap(err) } // fetch the first delta to check for a possible JSON payload @@ -175,7 +179,7 @@ func (chat *Chat) Complete(ctx context.Context) (any, int, error) { top: response, err = stream.Recv() if err != nil { - return nil, numTokens, trace.Wrap(err) + return nil, trace.Wrap(err) } numTokens++ @@ -194,7 +198,7 @@ top: case errors.Is(err, io.EOF): break outer case err != nil: - return nil, numTokens, trace.Wrap(err) + return nil, trace.Wrap(err) } numTokens++ @@ -206,13 +210,15 @@ top: err = json.Unmarshal([]byte(payload), &c) switch err { case nil: - return &c, numTokens, nil + c.NumTokens = numTokens + return &c, nil default: return &Message{ - Role: openai.ChatMessageRoleAssistant, - Content: payload, - Idx: len(chat.messages) - 1, - }, numTokens, nil + Role: openai.ChatMessageRoleAssistant, + Content: payload, + Idx: len(chat.messages) - 1, + NumTokens: numTokens, + }, nil } } @@ -246,5 +252,5 @@ top: Idx: len(chat.messages) - 1, Chunks: chunks, Error: errCh, - }, numTokens, nil + }, nil } diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 7019bdbf63e24..f0017df340d20 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -463,8 +463,10 @@ func tryFindEmbeddedCommand(message string) *ai.CompletionCommand { func processComplete(ctx context.Context, h *Handler, chat *ai.Chat, conversationID string, ws *websocket.Conn, authClient auth.ClientI, ) (int, error) { + var numTokens int + // query the assistant and fetch an answer - message, numTokens, err := chat.Complete(ctx) + message, err := chat.Complete(ctx) if err != nil { return numTokens, trace.Wrap(err) } @@ -577,6 +579,7 @@ func processComplete(ctx context.Context, h *Handler, chat *ai.Chat, conversatio } } case *ai.Message: + numTokens = message.NumTokens // write assistant message to both in-memory chain and persistent storage chat.Insert(message.Role, message.Content) protoMsg := &proto.AssistantMessage{ @@ -593,6 +596,7 @@ func processComplete(ctx context.Context, h *Handler, chat *ai.Chat, conversatio return numTokens, trace.Wrap(err) } case *ai.CompletionCommand: + numTokens = message.NumTokens payload := commandPayload{ Command: message.Command, Nodes: message.Nodes,