diff --git a/lib/ai/testutils/http.go b/lib/ai/testutils/http.go new file mode 100644 index 0000000000000..98cf6a8ff21b2 --- /dev/null +++ b/lib/ai/testutils/http.go @@ -0,0 +1,129 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package testutils + +import ( + "encoding/json" + "net/http" + "strconv" + "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" +) + +// GetTestHandlerFn returns a handler function that can be used to OpenAI API used by +// the chat API. It takes a list of responses that will be returned in order. +func GetTestHandlerFn(t *testing.T, responses []string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || !(r.URL.Path == "/chat/completions") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return + } + + switch r.Header.Get("Accept") { + case "application/json; charset=utf-8", "application/json": + responses = messageResponse(w, r, t, responses) + case "text/event-stream": + responses = streamResponse(w, t, responses) + default: + http.Error(w, "Unexpected request", http.StatusBadRequest) + } + } +} + +func streamResponse(w http.ResponseWriter, t *testing.T, responses []string) []string { + w.Header().Set("Content-Type", "text/event-stream") + + if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return responses + } + + resp := &openai.ChatCompletionStreamResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "completion", + Created: time.Now().Unix(), + Model: openai.GPT4, + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: responses[0], + Role: openai.ChatMessageRoleAssistant, + }, + FinishReason: "", + }, + }, + } + + respBytes, err := json.Marshal(resp) + assert.NoError(t, err, "Marshal error") + + _, err = w.Write([]byte("data: ")) + assert.NoError(t, err, "Write error") + _, err = w.Write(respBytes) + assert.NoError(t, err, "Write error") + _, err = w.Write([]byte("\n\nevent: done\ndata: [DONE]\n\n")) + assert.NoError(t, err, "Write error") + + return responses[1:] +} + +func messageResponse(w http.ResponseWriter, r *http.Request, t *testing.T, responses []string) []string { + w.Header().Set("Content-Type", "application/json") + + req := &openai.ChatCompletionRequest{} + err := json.NewDecoder(r.Body).Decode(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + + // Use assert as require doesn't work when called from a goroutine + if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return responses + } + + dataBytes := responses[0] + + resp := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + Model: req.Model, + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: dataBytes, + Name: "", + }, + }, + }, + Usage: openai.Usage{}, + } + + respBytes, err := json.Marshal(resp) + assert.NoError(t, err, "Marshal error") + + _, err = w.Write(respBytes) + assert.NoError(t, err, "Write error") + + return responses[1:] +} diff --git a/lib/auth/assist/assistv1/service.go b/lib/auth/assist/assistv1/service.go index dfc38666eaa0c..29a71d815875b 100644 --- a/lib/auth/assist/assistv1/service.go +++ b/lib/auth/assist/assistv1/service.go @@ -22,23 +22,31 @@ import ( "context" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/emptypb" + "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/services" ) // ServiceConfig holds configuration options for // the assist gRPC service. type ServiceConfig struct { - Backend services.Assistant + Backend services.Assistant + Authorizer authz.Authorizer + Logger *logrus.Entry } // Service implements the teleport.assist.v1.AssistService RPC service. type Service struct { assist.UnimplementedAssistServiceServer - backend services.Assistant + backend services.Assistant + authorizer authz.Authorizer + log *logrus.Entry } // NewService returns a new assist gRPC service. @@ -46,22 +54,46 @@ func NewService(cfg *ServiceConfig) (*Service, error) { switch { case cfg.Backend == nil: return nil, trace.BadParameter("backend is required") + case cfg.Authorizer == nil: + return nil, trace.BadParameter("authorizer is required") + case cfg.Logger == nil: + cfg.Logger = logrus.WithField(trace.Component, "assist.service") } return &Service{ - backend: cfg.Backend, + backend: cfg.Backend, + authorizer: cfg.Authorizer, + log: cfg.Logger, }, nil } // CreateAssistantConversation creates a new conversation entry in the backend. func (a *Service) CreateAssistantConversation(ctx context.Context, req *assist.CreateAssistantConversationRequest) (*assist.CreateAssistantConversationResponse, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbCreate) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to create conversation for user %q", authCtx.User.GetName(), req.Username) + } + resp, err := a.backend.CreateAssistantConversation(ctx, req) return resp, trace.Wrap(err) } // UpdateAssistantConversationInfo updates the conversation info for a conversation. -func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, request *assist.UpdateAssistantConversationInfoRequest) (*emptypb.Empty, error) { - err := a.backend.UpdateAssistantConversationInfo(ctx, request) +func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, req *assist.UpdateAssistantConversationInfoRequest) (*emptypb.Empty, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbUpdate) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to update conversation for user %q", authCtx.User.GetName(), req.Username) + } + + err = a.backend.UpdateAssistantConversationInfo(ctx, req) if err != nil { return &emptypb.Empty{}, trace.Wrap(err) } @@ -71,22 +103,79 @@ func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, request * // GetAssistantConversations returns all conversations started by a user. func (a *Service) GetAssistantConversations(ctx context.Context, req *assist.GetAssistantConversationsRequest) (*assist.GetAssistantConversationsResponse, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbList) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to list conversations for user %q", authCtx.User.GetName(), req.GetUsername()) + } + resp, err := a.backend.GetAssistantConversations(ctx, req) return resp, trace.Wrap(err) } // GetAssistantMessages returns all messages with given conversation ID. func (a *Service) GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbRead) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to get messages for user %q", authCtx.User.GetName(), req.GetUsername()) + } + resp, err := a.backend.GetAssistantMessages(ctx, req) return resp, trace.Wrap(err) } // CreateAssistantMessage adds the message to the backend. func (a *Service) CreateAssistantMessage(ctx context.Context, req *assist.CreateAssistantMessageRequest) (*emptypb.Empty, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbCreate) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to create message for user %q", authCtx.User.GetName(), req.GetUsername()) + } + return &emptypb.Empty{}, trace.Wrap(a.backend.CreateAssistantMessage(ctx, req)) } // IsAssistEnabled returns true if the assist is enabled or not on the auth level. -func (a *Service) IsAssistEnabled(ctx context.Context, req *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) { +func (a *Service) IsAssistEnabled(ctx context.Context, _ *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) { + authCtx, err := a.authorizer.Authorize(ctx) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + // Check if this endpoint is called by a user or Proxy. + if authz.IsLocalUser(*authCtx) { + checkErr := authCtx.Checker.CheckAccessToRule( + &services.Context{User: authCtx.User}, + defaults.Namespace, types.KindAssistant, types.VerbRead, + false, /* silent */ + ) + if checkErr != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + } else { + // This endpoint is called from Proxy to check if the assist is enabled. + // Proxy credentials are used instead of the user credentials. + requestedByProxy := authz.HasBuiltinRole(*authCtx, string(types.RoleProxy)) + if !requestedByProxy { + return nil, trace.AccessDenied("only proxy is allowed to call IsAssistEnabled endpoint") + } + } + + // Check if assist can use the backend. return a.backend.IsAssistEnabled(ctx) } + +// userHasAccess returns true if the user should have access to the resource. +func userHasAccess(authCtx *authz.Context, req interface{ GetUsername() string }) bool { + return !authz.IsCurrentUser(*authCtx, req.GetUsername()) && !authz.HasBuiltinRole(*authCtx, string(types.RoleAdmin)) +} diff --git a/lib/auth/assist/assistv1/service_test.go b/lib/auth/assist/assistv1/service_test.go new file mode 100644 index 0000000000000..4857dfac142b8 --- /dev/null +++ b/lib/auth/assist/assistv1/service_test.go @@ -0,0 +1,301 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package assistv1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/tlsca" +) + +const ( + defaultUser = "test-user" + noAccessUser = "user-no-access" +) + +func TestService_CreateAssistantConversation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + username string + req *assistpb.CreateAssistantConversationRequest + wantErr assert.ErrorAssertionFunc + assertResponse func(t *testing.T, resp *assistpb.CreateAssistantConversationResponse) + }{ + { + name: "success", + username: defaultUser, + req: &assistpb.CreateAssistantConversationRequest{ + Username: defaultUser, + CreatedTime: timestamppb.Now(), + }, + wantErr: assert.NoError, + assertResponse: func(t *testing.T, resp *assistpb.CreateAssistantConversationResponse) { + require.NotEmpty(t, resp.GetId()) + }, + }, + { + name: "access denies - RBAC", + username: noAccessUser, + req: &assistpb.CreateAssistantConversationRequest{ + Username: noAccessUser, + CreatedTime: timestamppb.Now(), + }, + wantErr: assert.Error, + }, + { + name: "access denied - different user", + username: defaultUser, + req: &assistpb.CreateAssistantConversationRequest{ + Username: noAccessUser, + CreatedTime: timestamppb.Now(), + }, + wantErr: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxs, svc := initSvc(t) + + got, err := svc.CreateAssistantConversation(ctxs[tt.username], tt.req) + tt.wantErr(t, err) + + if tt.assertResponse != nil { + tt.assertResponse(t, got) + } + }) + } +} + +func TestService_GetAssistantConversations(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + username string + req *assistpb.GetAssistantConversationsRequest + wantErr assert.ErrorAssertionFunc + assertResponse func(t *testing.T, resp *assistpb.CreateAssistantConversationResponse) + }{ + { + name: "success", + username: defaultUser, + req: &assistpb.GetAssistantConversationsRequest{ + Username: defaultUser, + }, + wantErr: assert.NoError, + }, + { + name: "access denies - RBAC", + username: noAccessUser, + req: &assistpb.GetAssistantConversationsRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + { + name: "access denied - different user", + username: defaultUser, + req: &assistpb.GetAssistantConversationsRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxs, svc := initSvc(t) + + _, err := svc.GetAssistantConversations(ctxs[tt.username], tt.req) + tt.wantErr(t, err) + }) + } +} + +func TestService_InsertAssistantMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + username string + req *assistpb.CreateAssistantMessageRequest + wantErr assert.ErrorAssertionFunc + }{ + { + name: "success", + username: defaultUser, + req: &assistpb.CreateAssistantMessageRequest{ + Username: defaultUser, + Message: &assistpb.AssistantMessage{ + Type: "CHAT_MESSAGE_ASSISTANT", + CreatedTime: timestamppb.Now(), + Payload: "Blah", + }, + }, + wantErr: assert.NoError, + }, + { + name: "access denies - RBAC", + username: noAccessUser, + req: &assistpb.CreateAssistantMessageRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + { + name: "access denied - different user", + username: defaultUser, + req: &assistpb.CreateAssistantMessageRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxs, svc := initSvc(t) + + // Create a conversation that we can remove, so we don't hit "conversation doesn't exist" error + convMsg, err := svc.backend.CreateAssistantConversation(ctxs[tt.username], &assistpb.CreateAssistantConversationRequest{ + Username: tt.username, + CreatedTime: timestamppb.Now(), + }) + require.NoError(t, err) + + conversationID := convMsg.GetId() + + tt.req.ConversationId = conversationID + + _, err = svc.CreateAssistantMessage(ctxs[tt.username], tt.req) + tt.wantErr(t, err) + }) + } +} + +func initSvc(t *testing.T) (map[string]context.Context, *Service) { + ctx := context.Background() + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + + clusterConfigSvc, err := local.NewClusterConfigurationService(backend) + require.NoError(t, err) + trustSvc := local.NewCAService(backend) + roleSvc := local.NewAccessService(backend) + userSvc := local.NewIdentityService(backend) + + require.NoError(t, clusterConfigSvc.SetAuthPreference(ctx, types.DefaultAuthPreference())) + require.NoError(t, clusterConfigSvc.SetClusterAuditConfig(ctx, types.DefaultClusterAuditConfig())) + require.NoError(t, clusterConfigSvc.SetClusterNetworkingConfig(ctx, types.DefaultClusterNetworkingConfig())) + require.NoError(t, clusterConfigSvc.SetSessionRecordingConfig(ctx, types.DefaultSessionRecordingConfig())) + + accessPoint := struct { + services.ClusterConfiguration + services.Trust + services.RoleGetter + services.UserGetter + }{ + ClusterConfiguration: clusterConfigSvc, + Trust: trustSvc, + RoleGetter: roleSvc, + UserGetter: userSvc, + } + + accessService := local.NewAccessService(backend) + eventService := local.NewEventsService(backend) + lockWatcher, err := services.NewLockWatcher(ctx, services.LockWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Client: eventService, + Component: "test", + }, + LockGetter: accessService, + }) + require.NoError(t, err) + + authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{ + ClusterName: "test-cluster", + AccessPoint: accessPoint, + LockWatcher: lockWatcher, + }) + require.NoError(t, err) + + roles := map[string]types.Role{} + + role, err := types.NewRole("allow-rules", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + { + Resources: []string{types.KindAssistant}, + Verbs: []string{types.VerbList, types.VerbRead, types.VerbUpdate, types.VerbCreate, types.VerbDelete}, + }, + }, + }, + }) + require.NoError(t, err) + + roles[defaultUser] = role + + roleNoAccess, err := types.NewRole("no-rules", types.RoleSpecV6{ + Allow: types.RoleConditions{}, + }) + require.NoError(t, err) + roles["user-no-access"] = roleNoAccess + + ctxs := make(map[string]context.Context, len(roles)) + for username, role := range roles { + err = roleSvc.CreateRole(ctx, role) + require.NoError(t, err) + + user, err := types.NewUser(username) + user.AddRole(role.GetName()) + require.NoError(t, err) + + err = userSvc.CreateUser(user) + require.NoError(t, err) + + ctx = authz.ContextWithUser(ctx, authz.LocalUser{ + Username: user.GetName(), + Identity: tlsca.Identity{ + Username: user.GetName(), + Groups: []string{role.GetName()}, + }, + }) + ctxs[user.GetName()] = ctx + } + + svc, err := NewService(&ServiceConfig{ + Backend: local.NewAssistService(backend), + Authorizer: authorizer, + }) + require.NoError(t, err) + + return ctxs, svc +} diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 6f8e805043fb7..fa83d048139bb 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -212,7 +212,7 @@ func (a *ServerWithRoles) serverAction() error { // whether any of the given roles match the role set. func (a *ServerWithRoles) hasBuiltinRole(roles ...types.SystemRole) bool { for _, role := range roles { - if HasBuiltinRole(a.context, string(role)) { + if authz.HasBuiltinRole(a.context, string(role)) { return true } } @@ -221,15 +221,11 @@ func (a *ServerWithRoles) hasBuiltinRole(roles ...types.SystemRole) bool { // HasBuiltinRole checks if the identity is a builtin role with the matching // name. +// Deprecated: use authz.HasBuiltinRole instead. func HasBuiltinRole(authContext authz.Context, name string) bool { - if _, ok := authContext.Identity.(authz.BuiltinRole); !ok { - return false - } - if !authContext.Checker.HasRole(name) { - return false - } - - return true + // TODO(jakule): This function can be removed once teleport.e is updated + // to use authz.HasBuiltinRole. + return authz.HasBuiltinRole(authContext, name) } // HasRemoteBuiltinRole checks if the identity is a remote builtin role with the @@ -5255,8 +5251,8 @@ func (a *ServerWithRoles) checkAccessToNode(node types.Server) error { // In addition, allow proxy (and remote proxy) to access all nodes for its // smart resolution address resolution. Once the smart resolution logic is // moved to the auth server, this logic can be removed. - builtinRole := HasBuiltinRole(a.context, string(types.RoleAdmin)) || - HasBuiltinRole(a.context, string(types.RoleProxy)) || + builtinRole := authz.HasBuiltinRole(a.context, string(types.RoleAdmin)) || + authz.HasBuiltinRole(a.context, string(types.RoleProxy)) || HasRemoteBuiltinRole(a.context, string(types.RoleRemoteProxy)) if builtinRole { @@ -6114,56 +6110,32 @@ func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, // CreateAssistantConversation creates a new conversation entry in the backend. func (a *ServerWithRoles) CreateAssistantConversation(ctx context.Context, req *assist.CreateAssistantConversationRequest) (*assist.CreateAssistantConversationResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbCreate); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.CreateAssistantConversation(ctx, req) + return nil, trace.NotImplemented("CreateAssistantConversation must not be called on auth.ServerWithRoles") } // GetAssistantConversations returns all conversations started by a user. func (a *ServerWithRoles) GetAssistantConversations(ctx context.Context, request *assist.GetAssistantConversationsRequest) (*assist.GetAssistantConversationsResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbList); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.GetAssistantConversations(ctx, request) + return nil, trace.NotImplemented("GetAssistantConversations must not be called on auth.ServerWithRoles") } // GetAssistantMessages returns all messages with given conversation ID. func (a *ServerWithRoles) GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbRead); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.GetAssistantMessages(ctx, req) + return nil, trace.NotImplemented("GetAssistantMessages must not be called on auth.ServerWithRoles") } // IsAssistEnabled returns true if the assist is enabled or not on the auth level. func (a *ServerWithRoles) IsAssistEnabled(ctx context.Context) (*assist.IsAssistEnabledResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbRead); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.IsAssistEnabled(ctx) + return nil, trace.NotImplemented("IsAssistEnabled must not be called on auth.ServerWithRoles") } // CreateAssistantMessage adds the message to the backend. func (a *ServerWithRoles) CreateAssistantMessage(ctx context.Context, msg *assist.CreateAssistantMessageRequest) error { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbUpdate); err != nil { - return trace.Wrap(err) - } - - return a.authServer.CreateAssistantMessage(ctx, msg) + return trace.NotImplemented("CreateAssistantMessage must not be called on auth.ServerWithRoles") } // UpdateAssistantConversationInfo updates the conversation info. func (a *ServerWithRoles) UpdateAssistantConversationInfo(ctx context.Context, msg *assist.UpdateAssistantConversationInfoRequest) error { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbUpdate); err != nil { - return trace.Wrap(err) - } - - return a.authServer.UpdateAssistantConversationInfo(ctx, msg) + return trace.NotImplemented("UpdateAssistantConversationInfo must not be called on auth.ServerWithRoles") } // CloneHTTPClient creates a new HTTP client with the same configuration. diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 9973b39db57fd..7e41b8560839b 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5144,7 +5144,8 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { // Initialize and register the assist service. assistSrv, err := assistv1.NewService(&assistv1.ServiceConfig{ - Backend: cfg.AuthServer.Services, + Backend: cfg.AuthServer.Services, + Authorizer: cfg.Authorizer, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index ec7d43b487d92..e3f784ce2984b 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -1294,3 +1294,27 @@ func UserFromContext(ctx context.Context) (IdentityGetter, error) { } return user, nil } + +// HasBuiltinRole checks if the identity is a builtin role with the matching +// name. +func HasBuiltinRole(authContext Context, name string) bool { + if _, ok := authContext.Identity.(BuiltinRole); !ok { + return false + } + if !authContext.Checker.HasRole(name) { + return false + } + + return true +} + +// IsLocalUser checks if the identity is a local user. +func IsLocalUser(authContext Context) bool { + _, ok := authContext.Identity.(LocalUser) + return ok +} + +// IsCurrentUser checks if the identity is a local user matching the given username +func IsCurrentUser(authContext Context, username string) bool { + return IsLocalUser(authContext) && authContext.User.GetName() == username +} diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5af2877d01148..3a07d1ce13fa3 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -630,7 +630,7 @@ type authPack struct { // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. -func (s *WebSuite) authPack(t *testing.T, user string) *authPack { +func (s *WebSuite) authPack(t *testing.T, user string, roles ...string) *authPack { login := s.user pass := "abc123" rawSecret := "def456" @@ -644,7 +644,7 @@ func (s *WebSuite) authPack(t *testing.T, user string) *authPack { err = s.server.Auth().SetAuthPreference(s.ctx, ap) require.NoError(t, err) - s.createUser(t, user, login, pass, otpSecret) + s.createUser(t, user, login, pass, otpSecret, roles...) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) @@ -683,7 +683,7 @@ func (s *WebSuite) authPack(t *testing.T, user string) *authPack { } } -func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string) { +func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string, roles ...string) { teleUser, err := types.NewUser(user) require.NoError(t, err) role := services.RoleForUser(teleUser) @@ -695,6 +695,10 @@ func (s *WebSuite) createUser(t *testing.T, user string, login string, pass stri require.NoError(t, err) teleUser.AddRole(role.GetName()) + for _, r := range roles { + teleUser.AddRole(r) + } + teleUser.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "some-auth-user"}, }) diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 8f15b6a82bf51..bb836a182f5cc 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -35,8 +35,10 @@ import ( "golang.org/x/time/rate" authproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/services" ) func Test_runAssistant(t *testing.T) { @@ -168,7 +170,17 @@ func Test_runAssistant(t *testing.T) { tc.setup(t, s) } - ws, err := s.makeAssistant(t, s.authPack(t, "foo")) + assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + types.NewRule(types.KindAssistant, services.RW()), + }, + }, + }) + require.NoError(t, err) + require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) + + ws, err := s.makeAssistant(t, s.authPack(t, "foo", assistRole.GetName())) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -186,7 +198,6 @@ func Test_runAssistant(t *testing.T) { tc.act(t, ws) }) } - } func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) { diff --git a/lib/web/command.go b/lib/web/command.go index aba2f3729da2a..f378c4fb27699 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -443,7 +443,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl return nil, trace.NotImplemented("MFA is not supported for command execution") } - //TODO(jakule): Implement MFA support + // TODO(jakule): Implement MFA support nc, err := t.connectToHost(ctx, t.ws, tc, mfaAuth) if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failure connecting to host") diff --git a/lib/web/command_test.go b/lib/web/command_test.go index 4cb3d7b0a904d..c7d5a40f39fa1 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -35,15 +35,32 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" ) +const ( + testCommand = "echo txlxport | sed 's/x/e/g'" + testUser = "foo" +) + func TestExecuteCommand(t *testing.T) { t.Parallel() s := newWebSuite(t) - ws, _, err := s.makeCommand(t, s.authPack(t, "foo")) + assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + types.NewRule(types.KindAssistant, services.RW()), + }, + }, + }) + require.NoError(t, err) + require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) + + ws, _, err := s.makeCommand(t, s.authPack(t, testUser, assistRole.GetName()), uuid.New()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -52,13 +69,13 @@ func TestExecuteCommand(t *testing.T) { require.NoError(t, waitForCommandOutput(stream, "teleport")) } -func (s *WebSuite) makeCommand(t *testing.T, pack *authPack) (*websocket.Conn, *session.Session, error) { +func (s *WebSuite) makeCommand(t *testing.T, pack *authPack, conversationID uuid.UUID) (*websocket.Conn, *session.Session, error) { req := CommandRequest{ Query: fmt.Sprintf("name == \"%s\"", s.srvID), Login: pack.login, - ConversationID: uuid.New().String(), + ConversationID: conversationID.String(), ExecutionID: uuid.New().String(), - Command: "echo txlxport | sed 's/x/e/g'", + Command: testCommand, } u := url.URL{