From c5c165208ebd92e334ac46682432477a8a2932a0 Mon Sep 17 00:00:00 2001 From: Justinas Stankevicius Date: Tue, 16 May 2023 16:32:47 +0300 Subject: [PATCH] Fix Assist rate-limiting in Cloud When Proxy is separate from Auth, Proxy 'modules' will not contain meaningful data. Instead, one must use ClusterFeatures fetched from the Auth server --- lib/web/apiserver.go | 2 +- lib/web/apiserver_test.go | 9 ++++++++- lib/web/assistant_test.go | 20 ++++++++++++++++---- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index eafa271380aa9..d6931ebad6650 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -318,7 +318,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { // Check for self-hosted vs Cloud. // TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud - if modules.GetModules().Features().Cloud { + if cfg.ClusterFeatures.GetCloud() { h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity) } else { // Set up a limiter with "infinite limit", the "burst" parameter is ignored diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index b4aa3e9d169da..2af62d6c07533 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -181,6 +181,9 @@ type webSuiteConfig struct { // OpenAIConfig is a custom OpenAI config for the test. OpenAIConfig *openai.ClientConfig + + // ClusterFeatures allows overriding default auth server features + ClusterFeatures *authproto.Features } func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { @@ -433,8 +436,12 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { fs, err := newDebugFileSystem() require.NoError(t, err) + features := *modules.GetModules().Features().ToProto() // safe to dereference because ToProto creates a struct and return a pointer to it + if cfg.ClusterFeatures != nil { + features = *cfg.ClusterFeatures + } handlerConfig := Config{ - ClusterFeatures: *modules.GetModules().Features().ToProto(), // safe to dereference because ToProto creates a struct and return a pointer to it + ClusterFeatures: features, Proxy: revTunServer, AuthServers: utils.FromAddr(s.server.TLS.Addr()), DomainName: s.server.ClusterName(), diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 25cff80c560cd..8f15b6a82bf51 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -34,6 +34,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/time/rate" + authproto "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/client" ) @@ -79,7 +80,8 @@ func Test_runAssistant(t *testing.T) { testCases := []struct { name string responses [][]byte - setup func(*WebSuite) + cfg webSuiteConfig + setup func(*testing.T, *WebSuite) act func(*testing.T, *websocket.Conn) }{ { @@ -105,7 +107,16 @@ func Test_runAssistant(t *testing.T) { generateTextResponse(), generateTextResponse(), }, - setup: func(s *WebSuite) { + cfg: webSuiteConfig{ + ClusterFeatures: &authproto.Features{ + Cloud: true, + }, + }, + setup: func(t *testing.T, s *WebSuite) { + // Assert that rate limiter is set up when Cloud feature is active, + // before replacing with a lower capacity rate-limiter for test purposes + require.Equal(t, assistantLimiterRate, s.webHandler.handler.assistantLimiter.Limit()) + // 101 token capacity (lookaheadTokens+1) and a slow replenish rate // to let the first completion request succeed, but not the second one s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101) @@ -150,10 +161,11 @@ func Test_runAssistant(t *testing.T) { openaiCfg := openai.DefaultConfig("test-token") openaiCfg.BaseURL = server.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) + tc.cfg.OpenAIConfig = &openaiCfg + s := newWebSuiteWithConfig(t, tc.cfg) if tc.setup != nil { - tc.setup(s) + tc.setup(t, s) } ws, err := s.makeAssistant(t, s.authPack(t, "foo"))