Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lib/assist/assist.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ const (
MessageKindAssistantPartialFinalize MessageType = "CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE"
// MessageKindSystemMessage is the type of Assist message that contains the system message.
MessageKindSystemMessage MessageType = "CHAT_MESSAGE_SYSTEM"
// MessageKindError is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like.
MessageKindError MessageType = "CHAT_MESSAGE_ERROR"
)

// Assist is the Teleport Assist client.
Expand Down Expand Up @@ -210,7 +212,7 @@ type TokensUsed struct {
// Prompt is a number of tokens used in the prompt.
Prompt int
// Completion is a number of tokens used in the completion.
Competition int
Completion int
}

// ProcessComplete processes the completion request and returns the number of tokens used.
Expand Down Expand Up @@ -386,8 +388,8 @@ func (c *Chat) ProcessComplete(ctx context.Context,
}

return &TokensUsed{
Prompt: promptTokens,
Competition: numTokens,
Prompt: promptTokens,
Completion: numTokens,
}, nil
}

Expand Down
27 changes: 26 additions & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/mod/semver"
"golang.org/x/time/rate"
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using this instead of oxy ratelimiter, as the latter is quite tightly coupled with individual http.Request-s.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add this PR command as a code command too? The reasoning won't be easily visible to people after this PR is merged.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

"google.golang.org/protobuf/encoding/protojson"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -92,6 +93,15 @@ import (
const (
// SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages.
SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details."

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// assistantTokensPerHour defines how many assistant rate limiter tokens are replenished every hour.
assistantTokensPerHour = 140
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

godoc

// assistantLimiterRate is the rate (in tokens per second)
// at which tokens for the assistant rate limiter are replenished
assistantLimiterRate = rate.Limit(assistantTokensPerHour / float64(time.Hour/time.Second))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
assistantLimiterRate = rate.Limit(assistantTokensPerHour / float64(time.Hour/time.Second))
assistantLimiterRate = rate.Limit(assistantTokensPerHour / time.Hour.Seconds())

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

time.Hour.Seconds() sadly can not be used in a const definition

// assistantLimiterCapacity is the total capacity of the token bucket for the assistant rate limiter.
// The bucket starts full, prefilled for a week.
assistantLimiterCapacity = assistantTokensPerHour * 24 * 7
)

// healthCheckAppServerFunc defines a function used to perform a health check
Expand All @@ -111,7 +121,13 @@ type Handler struct {
clock clockwork.Clock
limiter *limiter.RateLimiter
highLimiter *limiter.RateLimiter
healthCheckAppServer healthCheckAppServerFunc
// assistantLimiter limits the amount of tokens that can be consumed
// by OpenAI API calls when using a shared key.
// golang.org/x/time/rate is used, as the oxy ratelimiter
// is quite tightly tied to individual http.Requests,
// and instead we want to consume arbitrary amounts of tokens.
assistantLimiter *rate.Limiter
healthCheckAppServer healthCheckAppServerFunc
// sshPort specifies the SSH proxy port extracted
// from configuration
sshPort string
Expand Down Expand Up @@ -301,6 +317,15 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
healthCheckAppServer: cfg.HealthCheckAppServer,
}

// Check for self-hosted vs Cloud.
// TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have an issue to trace it? If yes, can you link it here?

if modules.GetModules().Features().Cloud {
h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity)
} else {
// Set up a limiter with "infinite limit", the "burst" parameter is ignored
h.assistantLimiter = rate.NewLimiter(rate.Inf, 0)
}

// for properly handling url-encoded parameter values.
h.UseRawPath = true

Expand Down
23 changes: 21 additions & 2 deletions lib/web/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,17 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
return trace.Wrap(err)
}

// We can not know how many tokens we will consume in advance.
// Try to consume a small amount of tokens first.
const lookaheadTokens = 100
if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) {
err := onMessageFn(assist.MessageKindError, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC())
if err != nil {
return trace.Wrap(err)
}
continue
Comment thread
justinas marked this conversation as resolved.
Outdated
}

//TODO(jakule): Should we sanitize the payload?
if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil {
return trace.Wrap(err)
Expand All @@ -415,14 +426,22 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
return trace.Wrap(err)
}

// Once we know how many tokens were consumed for prompt+completion,
// consume the remaining tokens from the rate limiter bucket.
extraTokens := usedTokens.Prompt + usedTokens.Completion - lookaheadTokens
if extraTokens < 0 {
extraTokens = 0
}
h.assistantLimiter.ReserveN(time.Now(), extraTokens)

usageEventReq := &proto.SubmitUsageEventRequest{
Event: &usageeventsv1.UsageEventOneOf{
Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{
AssistCompletion: &usageeventsv1.AssistCompletionEvent{
ConversationId: conversationID,
TotalTokens: int64(usedTokens.Prompt + usedTokens.Competition),
TotalTokens: int64(usedTokens.Prompt + usedTokens.Completion),
PromptTokens: int64(usedTokens.Prompt),
CompletionTokens: int64(usedTokens.Competition),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol, sorry for that 😅

CompletionTokens: int64(usedTokens.Completion),
},
},
},
Expand Down
160 changes: 115 additions & 45 deletions lib/web/assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"

"github.com/gravitational/teleport/lib/assist"
"github.com/gravitational/teleport/lib/client"
Expand All @@ -40,71 +41,140 @@ import (
func Test_runAssistant(t *testing.T) {
t.Parallel()

responses := [][]byte{
generateTextResponse(),
readPartialMessage := func(t *testing.T, ws *websocket.Conn) string {
var msg assistantMessage
_, payload, err := ws.ReadMessage()
require.NoError(t, err)

err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

require.Equal(t, assist.MessageKindAssistantPartialMessage, msg.Type)
return msg.Payload
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
readStreamEnd := func(t *testing.T, ws *websocket.Conn) {
var msg assistantMessage
_, payload, err := ws.ReadMessage()
require.NoError(t, err)

require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type)
}

readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) {
var msg assistantMessage
_, payload, err := ws.ReadMessage()
require.NoError(t, err)

responses = responses[1:]
}))
defer server.Close()
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
require.Equal(t, assist.MessageKindError, msg.Type)
require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.")
}

ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, ws.Close()) })
testCases := []struct {
name string
responses [][]byte
setup func(*WebSuite)
act func(*testing.T, *websocket.Conn)
}{
{
name: "normal",
responses: [][]byte{
generateTextResponse(),
},
act: func(t *testing.T, ws *websocket.Conn) {
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)

require.Contains(t, readPartialMessage(t, ws), "Which")
require.Contains(t, readPartialMessage(t, ws), "node do")
require.Contains(t, readPartialMessage(t, ws), "you want")
require.Contains(t, readPartialMessage(t, ws), "use?")

readStreamEnd(t, ws)
},
},
{
name: "rate limited",
responses: [][]byte{
generateTextResponse(),
generateTextResponse(),
},
setup: func(s *WebSuite) {
// 101 token capacity (lookaheadTokens+1) and a slow replenish rate
// to let the first completion request succeed, but not the second one
s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101)

},
act: func(t *testing.T, ws *websocket.Conn) {
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)

require.Contains(t, readPartialMessage(t, ws), "Which")
require.Contains(t, readPartialMessage(t, ws), "node do")
require.Contains(t, readPartialMessage(t, ws), "you want")
require.Contains(t, readPartialMessage(t, ws), "use?")

readStreamEnd(t, ws)

err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`))
require.NoError(t, err)

readRateLimitedMessage(t, ws)
},
},
}

_, payload, err := ws.ReadMessage()
require.NoError(t, err)
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
responses := tc.responses
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)
require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]

require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")

err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)
responses = responses[1:]
}))
t.Cleanup(server.Close)

readPartialMessage := func() string {
_, payload, err = ws.ReadMessage()
require.NoError(t, err)
openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})

err = json.Unmarshal(payload, &msg)
require.NoError(t, err)
if tc.setup != nil {
tc.setup(s)
}

require.Equal(t, assist.MessageKindAssistantPartialMessage, msg.Type)
return msg.Payload
}
ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, ws.Close()) })

require.Contains(t, readPartialMessage(), "Which")
require.Contains(t, readPartialMessage(), "node do")
require.Contains(t, readPartialMessage(), "you want")
require.Contains(t, readPartialMessage(), "use?")
_, payload, err := ws.ReadMessage()
require.NoError(t, err)

readStraemEnd := func() {
_, payload, err = ws.ReadMessage()
require.NoError(t, err)
var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)

err = json.Unmarshal(payload, &msg)
require.NoError(t, err)
// Expect "hello" message
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
require.Contains(t, msg.Payload, "Hey, I'm Teleport")

require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type)
tc.act(t, ws)
})
}

readStraemEnd()
}

func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) {
Expand Down
12 changes: 10 additions & 2 deletions web/packages/teleport/src/Assist/contexts/messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ async function convertServerMessage(
message: ServerMessage,
clusterId: string
): Promise<MessagesAction> {
if (message.type === 'CHAT_MESSAGE_ASSISTANT') {
if (
message.type === 'CHAT_MESSAGE_ASSISTANT' ||
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ryanclark Can you take a look at it?

message.type === 'CHAT_MESSAGE_ERROR'
) {
const newMessage: Message = {
author: Author.Teleport,
timestamp: message.created_time,
Expand Down Expand Up @@ -263,6 +266,8 @@ async function convertServerMessage(

return (messages: Message[]) => messages.push(newMessage);
}

throw new Error('unrecognized message type');
}

function findIntersection<T>(elems: T[][]): T[] {
Expand Down Expand Up @@ -364,9 +369,12 @@ export function MessagesContextProvider(
if (lastMessage !== null) {
const value = JSON.parse(lastMessage.data) as ServerMessage;

// When a streaming message ends, or a non-streaming message arrives
if (
value.type === 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE' ||
value.type === 'COMMAND'
value.type === 'COMMAND' ||
value.type === 'CHAT_MESSAGE_ASSISTANT' ||
value.type === 'CHAT_MESSAGE_ERROR'
) {
setResponding(false);
}
Expand Down