Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {

// Check for self-hosted vs Cloud.
// TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud
if modules.GetModules().Features().Cloud {
if cfg.ClusterFeatures.GetCloud() {
h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity)
} else {
// Set up a limiter with "infinite limit", the "burst" parameter is ignored
Expand Down
9 changes: 8 additions & 1 deletion lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ type webSuiteConfig struct {

// OpenAIConfig is a custom OpenAI config for the test.
OpenAIConfig *openai.ClientConfig

// ClusterFeatures allows overriding default auth server features
ClusterFeatures *authproto.Features
}

func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
Expand Down Expand Up @@ -431,8 +434,12 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite {
fs, err := newDebugFileSystem()
require.NoError(t, err)

features := *modules.GetModules().Features().ToProto() // safe to dereference because ToProto creates a struct and return a pointer to it
if cfg.ClusterFeatures != nil {
features = *cfg.ClusterFeatures
}
handlerConfig := Config{
ClusterFeatures: *modules.GetModules().Features().ToProto(), // safe to dereference because ToProto creates a struct and return a pointer to it
ClusterFeatures: features,
Proxy: revTunServer,
AuthServers: utils.FromAddr(s.server.TLS.Addr()),
DomainName: s.server.ClusterName(),
Expand Down
20 changes: 16 additions & 4 deletions lib/web/assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"

authproto "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/lib/assist"
"github.com/gravitational/teleport/lib/client"
)
Expand Down Expand Up @@ -79,7 +80,8 @@ func Test_runAssistant(t *testing.T) {
testCases := []struct {
name string
responses [][]byte
setup func(*WebSuite)
cfg webSuiteConfig
setup func(*testing.T, *WebSuite)
act func(*testing.T, *websocket.Conn)
}{
{
Expand All @@ -105,7 +107,16 @@ func Test_runAssistant(t *testing.T) {
generateTextResponse(),
generateTextResponse(),
},
setup: func(s *WebSuite) {
cfg: webSuiteConfig{
ClusterFeatures: &authproto.Features{
Cloud: true,
},
},
setup: func(t *testing.T, s *WebSuite) {
// Assert that rate limiter is set up when Cloud feature is active,
// before replacing with a lower capacity rate-limiter for test purposes
require.Equal(t, assistantLimiterRate, s.webHandler.handler.assistantLimiter.Limit())

// 101 token capacity (lookaheadTokens+1) and a slow replenish rate
// to let the first completion request succeed, but not the second one
s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101)
Expand Down Expand Up @@ -150,10 +161,11 @@ func Test_runAssistant(t *testing.T) {

openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
tc.cfg.OpenAIConfig = &openaiCfg
s := newWebSuiteWithConfig(t, tc.cfg)

if tc.setup != nil {
tc.setup(s)
tc.setup(t, s)
}

ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
Expand Down