assist: Refactor token counting #29224

Merged: hugoShaka merged 2 commits into master from hugo/fix-e1805-assist-token-count on Jul 21, 2023

Conversation

@hugoShaka (Contributor) commented Jul 17, 2023

Fixes https://github.com/gravitational/teleport.e/issues/1805

This PR refactors token counting by decoupling token counts from message responses. With the actor model, tokens can be consumed in multiple ways (picking tools, invoking them, ...) that don't necessarily end in a final action (sometimes we return a nextStep instead). Streaming responses were another challenge: the agent returned before the completion was over (it returned a routine streaming the deltas sent by the model).

This PR introduces a TokenCounter interface that abstracts synchronous and asynchronous token counting. All token-consuming operations must return a TokenCounter. TokenCounters are stored in the agent state and returned once the agent exits. Finally, the token counters are evaluated asynchronously to give the streaming completion requests enough time to finish.
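
As a minimal sketch of the shape this takes: the TokenCounter interface name comes from the PR, but the aggregate type and its method below are my approximation, not necessarily the exact implementation.

```go
// TokenCounter is a promise of a token count. A synchronous counter is
// already resolved; an asynchronous one resolves once its completion
// stream has finished.
type TokenCounter interface {
	TokenCount() int
}

// TokenCounters collects every counter produced by the token-consuming
// operations of a single agent run (tool picking, tool invocation, ...).
type TokenCounters []TokenCounter

// CountAll evaluates the collected counters. Calling this after the agent
// exits, and asynchronously, gives streaming completion requests time to
// finish before their tokens are tallied.
func (tc TokenCounters) CountAll() int {
	total := 0
	for _, counter := range tc {
		total += counter.TokenCount()
	}
	return total
}
```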

@hugoShaka force-pushed the hugo/fix-e1805-assist-token-count branch from dffc0ea to c636d4a on July 18, 2023 14:09
@hugoShaka marked this pull request as ready for review on July 18, 2023 16:46
Comment thread lib/web/assistant.go Outdated
Comment thread lib/ai/model/tokencount.go Outdated

// NewPromptTokenCounter takes a list of openai.ChatCompletionMessage and
// computes how many tokens are used by sending those messages to the model.
func NewPromptTokenCounter(prompt []openai.ChatCompletionMessage) (*PromptTokenCounter, error) {

Contributor:

Can this function return the count only? The PromptTokenCounter sounds a bit redundant.

@hugoShaka (Author) replied Jul 19, 2023:

I merged the two static counter types, got rid of the struct, and answered part of the comment here: #29224 (comment)

Comment thread lib/ai/model/agent.go Outdated
Comment thread lib/ai/model/tokencount.go Outdated

// NewSynchronousTokenCounter takes the completion request output and
// computes how many tokens were used by the model to generate this result.
func NewSynchronousTokenCounter(completion string) (*SynchronousTokenCounter, error) {

Contributor:

nit: Same as above. Do we need to return the token counter struct when we could just return the count?

@hugoShaka (Author):

The point of abstracting the different types of token counting behind an interface was to hide whether counting is immediate or asynchronous. From the caller's point of view, everything is a promise of a token count, even if we already computed it.

For prompt tokens we could return the count directly, as they are always synchronously countable. For completion tokens that's not the case.

To simplify the prompt and synchronous counters, I removed the struct and made them an integer type alias.
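
As a rough sketch of that simplification (the alias name here is hypothetical):

```go
// staticTokenCounter is an already-resolved token count. The integer
// itself satisfies the TokenCounter interface, so prompt tokens and
// non-streaming completion tokens need no struct or synchronization.
type staticTokenCounter int

// TokenCount returns immediately: the count was computed at creation time.
func (s staticTokenCounter) TokenCount() int {
	return int(s)
}
```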

Comment thread lib/ai/model/tokencount.go Outdated
Comment thread lib/ai/model/agent.go Outdated
Comment thread lib/ai/model/agent.go Outdated
go func() {
	defer close(parts)
	defer func() {
		errCount := streamingTokenCounter.Finish()

Contributor:

I don't think I fully understand this pattern. You create the counter, add all the tokens, and then call TokenCount() in one thread and Finish() in the other? Why can't TokenCount() just count all tokens and return the value?

@hugoShaka (Author):

I revamped the logic to remove the wait logic and the Finish() function. I thought we were streaming completely asynchronously to the front end, but we are actually waiting for the stream to end in assist.go. Any TokenCount() invocation after that line will return the correct count.
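
My reading of the resulting pattern, as a hedged sketch (names are illustrative; the real code lives in lib/ai/model/tokencount.go):

```go
import "sync"

// streamingTokenCounter accumulates tokens as deltas arrive from the model.
// The streaming goroutine calls Add; the caller reads TokenCount only after
// the stream has ended (assist.go waits for the stream before reporting
// usage), so the count it observes is complete.
type streamingTokenCounter struct {
	mu    sync.Mutex
	count int
}

// Add records the tokens contained in one streamed delta.
func (s *streamingTokenCounter) Add(tokens int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.count += tokens
}

// TokenCount returns the tokens counted so far; it satisfies TokenCounter.
func (s *streamingTokenCounter) TokenCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.count
}
```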

@justinas (Contributor) left a comment:

Besides Jakub's notes, generally looks good.

Comment thread lib/ai/model/agent.go
// parseJSONFromModel parses a JSON object from the model output and attempts to sanitize contaminant text
// to avoid triggering self-correction due to some natural language being bundled with the JSON.
// The output type is generic, and thus the structure of the expected JSON varies depending on T.
func parseJSONFromModel[T any](text string) (T, *invalidOutputError) {

Contributor:

Is this intentional? I think the best practice is to always return error and let the caller cast if necessary.

Contributor:

I added this some time ago, and it is intentional. That may be the best practice, but when we have a single error type called privately from one or two places, casting to error is detrimental to working with the code effectively: we throw away information, force ourselves to write boilerplate for handling the other case, and have to remember to update the handling without compiler errors when we add new special types.

@hugoShaka (Author):

I changed the function signature because it messes with the error handling later.

[screenshot of the failing error check omitted]

Without the signature change, the check considers err != nil even though the error value is a nil pointer of type *model.invalidOutputError.

@hugoShaka (Author) replied Jul 19, 2023:

Here's a playground example demonstrating what happens without this change: https://go.dev/play/p/lD_J3gOIccf

By implicitly setting the err type to error a few lines earlier, I changed the behaviour of the error check, so I either had to change the function signature or store the returned error in a new variable whose type is not error. I preferred the first solution, as this is a footgun: it took me 20 minutes to understand why this was happening.
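
For readers who don't want to open the playground link, this is the standard Go typed-nil footgun the thread is about; the types below are stand-ins for the real ones in lib/ai/model:

```go
package main

import "fmt"

// invalidOutputError stands in for the unexported error type discussed above.
type invalidOutputError struct{ msg string }

func (e *invalidOutputError) Error() string { return e.msg }

// parse stands in for parseJSONFromModel: on success it returns a nil
// *invalidOutputError.
func parse() (string, *invalidOutputError) {
	return "ok", nil
}

func main() {
	// With the concrete return type, the nil check behaves as expected.
	if _, err := parse(); err != nil {
		fmt.Println("never reached:", err)
	}

	// Assigning the typed nil to an `error` interface variable a few lines
	// earlier is enough to flip the check: the interface now carries a
	// non-nil type (*invalidOutputError) with a nil value, so err != nil.
	var err error
	_, err = parse()
	fmt.Println(err != nil) // prints "true" — the footgun
}
```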

Comment thread lib/ai/model/tokencount.go Outdated
Comment thread lib/ai/model/tokencount.go Outdated

// NewPromptTokenCounter takes a list of openai.ChatCompletionMessage and
// computes how many tokens are used by sending those messages to the model.
func NewPromptTokenCounter(prompt []openai.ChatCompletionMessage) (*PromptTokenCounter, error) {

Contributor:

Do we need to split this into two classes and add a relatively large amount of boilerplate compared to the business logic? Would it be sufficient to maintain one token counter class that keeps track of both? It's just a lot of newlines, godocs, and misc methods that don't really do anything.

@hugoShaka (Author):

I merged the prompt and synchronous token counters into the same type and eliminated the struct.

@jakule (Contributor) left a comment:

Looks good, thanks for all the fixes!

Comment thread lib/ai/model/tokencount.go Outdated
Comment thread lib/ai/model/tokencount.go Outdated
Comment thread lib/ai/model/tokencount.go Outdated
Comment thread lib/ai/model/tokencount.go Outdated
Comment thread lib/web/assistant.go Outdated

@hugoShaka force-pushed the hugo/fix-e1805-assist-token-count branch from b283dc4 to d18563a on July 20, 2023 18:28
@hugoShaka enabled auto-merge on July 20, 2023 18:28
@hugoShaka added this pull request to the merge queue on Jul 21, 2023
Merged via the queue into master with commit 2b15263 on Jul 21, 2023
@hugoShaka deleted the hugo/fix-e1805-assist-token-count branch on July 21, 2023 22:02

@public-teleport-github-review-bot:

@hugoShaka See the table below for backport results.

Branch      | Result
branch/v13  | Create PR
