-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add rate limiting to Assist #26011
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rate limiting to Assist #26011
Changes from all commits
385793f
bcd3b38
41b1c4c
54a4ebf
b598ee0
c52f349
984c1b8
beb658c
9b5f09e
5eb9fb5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -51,6 +51,7 @@ import ( | |||||
| "golang.org/x/crypto/ssh" | ||||||
| "golang.org/x/exp/slices" | ||||||
| "golang.org/x/mod/semver" | ||||||
| "golang.org/x/time/rate" | ||||||
| "google.golang.org/protobuf/encoding/protojson" | ||||||
|
|
||||||
| "github.com/gravitational/teleport" | ||||||
|
|
@@ -92,6 +93,15 @@ import ( | |||||
| const ( | ||||||
| // SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages. | ||||||
| SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details." | ||||||
|
|
||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Latest discussion about the rate limit: https://gravitational.slack.com/archives/C0509EYASCW/p1683738890958669 |
||||||
| // assistantTokensPerHour defines how many assistant rate limiter tokens are replenished every hour. | ||||||
| assistantTokensPerHour = 140 | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. godoc |
||||||
| // assistantLimiterRate is the rate (in tokens per second) | ||||||
| // at which tokens for the assistant rate limiter are replenished | ||||||
| assistantLimiterRate = rate.Limit(assistantTokensPerHour / float64(time.Hour/time.Second)) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit:
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||||||
| // assistantLimiterCapacity is the total capacity of the token bucket for the assistant rate limiter. | ||||||
| // The bucket starts full, prefilled for a week. | ||||||
| assistantLimiterCapacity = assistantTokensPerHour * 24 * 7 | ||||||
| ) | ||||||
|
|
||||||
| // healthCheckAppServerFunc defines a function used to perform a health check | ||||||
|
|
@@ -111,7 +121,13 @@ type Handler struct { | |||||
| clock clockwork.Clock | ||||||
| limiter *limiter.RateLimiter | ||||||
| highLimiter *limiter.RateLimiter | ||||||
| healthCheckAppServer healthCheckAppServerFunc | ||||||
| // assistantLimiter limits the amount of tokens that can be consumed | ||||||
| // by OpenAI API calls when using a shared key. | ||||||
| // golang.org/x/time/rate is used, as the oxy ratelimiter | ||||||
| // is quite tightly tied to individual http.Requests, | ||||||
| // and instead we want to consume arbitrary amounts of tokens. | ||||||
| assistantLimiter *rate.Limiter | ||||||
| healthCheckAppServer healthCheckAppServerFunc | ||||||
| // sshPort specifies the SSH proxy port extracted | ||||||
| // from configuration | ||||||
| sshPort string | ||||||
|
|
@@ -301,6 +317,15 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { | |||||
| healthCheckAppServer: cfg.HealthCheckAppServer, | ||||||
| } | ||||||
|
|
||||||
| // Check for self-hosted vs Cloud. | ||||||
| // TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we have an issue to track it? If yes, can you link it here? |
||||||
| if modules.GetModules().Features().Cloud { | ||||||
| h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity) | ||||||
| } else { | ||||||
| // Set up a limiter with "infinite limit", the "burst" parameter is ignored | ||||||
| h.assistantLimiter = rate.NewLimiter(rate.Inf, 0) | ||||||
| } | ||||||
|
|
||||||
| // for properly handling url-encoded parameter values. | ||||||
| h.UseRawPath = true | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -405,6 +405,17 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, | |
| return trace.Wrap(err) | ||
| } | ||
|
|
||
| // We cannot know how many tokens we will consume in advance. | ||
| // Try to consume a small amount of tokens first. | ||
| const lookaheadTokens = 100 | ||
| if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) { | ||
| err := onMessageFn(assist.MessageKindError, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC()) | ||
| if err != nil { | ||
| return trace.Wrap(err) | ||
| } | ||
| continue | ||
|
justinas marked this conversation as resolved.
Outdated
|
||
| } | ||
|
|
||
| //TODO(jakule): Should we sanitize the payload? | ||
| if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil { | ||
| return trace.Wrap(err) | ||
|
|
@@ -415,14 +426,22 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, | |
| return trace.Wrap(err) | ||
| } | ||
|
|
||
| // Once we know how many tokens were consumed for prompt+completion, | ||
| // consume the remaining tokens from the rate limiter bucket. | ||
| extraTokens := usedTokens.Prompt + usedTokens.Completion - lookaheadTokens | ||
| if extraTokens < 0 { | ||
| extraTokens = 0 | ||
| } | ||
| h.assistantLimiter.ReserveN(time.Now(), extraTokens) | ||
|
|
||
| usageEventReq := &proto.SubmitUsageEventRequest{ | ||
| Event: &usageeventsv1.UsageEventOneOf{ | ||
| Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{ | ||
| AssistCompletion: &usageeventsv1.AssistCompletionEvent{ | ||
| ConversationId: conversationID, | ||
| TotalTokens: int64(usedTokens.Prompt + usedTokens.Competition), | ||
| TotalTokens: int64(usedTokens.Prompt + usedTokens.Completion), | ||
| PromptTokens: int64(usedTokens.Prompt), | ||
| CompletionTokens: int64(usedTokens.Competition), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Lol, sorry for that 😅 |
||
| CompletionTokens: int64(usedTokens.Completion), | ||
| }, | ||
| }, | ||
| }, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -152,7 +152,10 @@ async function convertServerMessage( | |
| message: ServerMessage, | ||
| clusterId: string | ||
| ): Promise<MessagesAction> { | ||
| if (message.type === 'CHAT_MESSAGE_ASSISTANT') { | ||
| if ( | ||
| message.type === 'CHAT_MESSAGE_ASSISTANT' || | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @ryanclark Can you take a look at it? |
||
| message.type === 'CHAT_MESSAGE_ERROR' | ||
| ) { | ||
| const newMessage: Message = { | ||
| author: Author.Teleport, | ||
| timestamp: message.created_time, | ||
|
|
@@ -263,6 +266,8 @@ async function convertServerMessage( | |
|
|
||
| return (messages: Message[]) => messages.push(newMessage); | ||
| } | ||
|
|
||
| throw new Error('unrecognized message type'); | ||
| } | ||
|
|
||
| function findIntersection<T>(elems: T[][]): T[] { | ||
|
|
@@ -364,9 +369,12 @@ export function MessagesContextProvider( | |
| if (lastMessage !== null) { | ||
| const value = JSON.parse(lastMessage.data) as ServerMessage; | ||
|
|
||
| // When a streaming message ends, or a non-streaming message arrives | ||
| if ( | ||
| value.type === 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE' || | ||
| value.type === 'COMMAND' | ||
| value.type === 'COMMAND' || | ||
| value.type === 'CHAT_MESSAGE_ASSISTANT' || | ||
| value.type === 'CHAT_MESSAGE_ERROR' | ||
| ) { | ||
| setResponding(false); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using this instead of oxy ratelimiter, as the latter is quite tightly coupled with individual `http.Request`-s.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add this PR comment as a code comment too? The reasoning won't be easily visible to people after this PR is merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added.