From fa9280ebd8644ff489fb6ce892cffaab16cc1aa3 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Wed, 2 Aug 2023 16:55:01 -0400 Subject: [PATCH 1/3] Fix authorization rules to the Assistant and UserPreferences service (#29481) * "Add authorization rules to the Assistant and UserPreferences service" This commit introduces authorization rules into the Assistant service to restrict operations based on the authenticated user's role permissions. Now each method in the Assistant service checks if the authenticated user has necessary permissions to perform the requested operation. The permissions are checked via defined RBAC rules. A user requires specific permissions to perform various operations such as creating a conversation, updating a conversation, fetching a user's conversations, deleting a conversation, and adding a message to a conversation. Also, even if a user has necessary permissions, they cannot perform operations for a different user. Each user can only access their own data. * Add missing logger * "Refactor user preferences request handling" This commit refactors how GetUserPreferences and UpsertUserPreferences handle requests. The `username` field is removed from request parameters. Instead of having the client send the user's username in a request, the server now automatically uses the username of the authenticated user making the request. This change improves the security by preventing a user from attempting to fetch or manipulate another user's preferences. Removed tests were specifically testing the old, insecure behavior. * Refactor to use authz.HasBuiltinRole Refactored code in the 'auth_with_roles.go' file to use 'authz.HasBuiltinRole' instead of 'HasBuiltinRole'. This change is in line with recommended practices for deprecation and makes the code more standard and easier to manage. The original 'HasBuiltinRole' function is marked as deprecated and will be removed in future once 'teleport.e' is updated to use 'authz.HasBuiltinRole'. * Reserve removed username again? * Fix UT * Add local user permissions checks in authz This commit introduces two new methods in permissions.go to check if a user is a local user, and if a given action is performed by a local user. These permission checks are then used to replace existing checks in service.go, when performing actions like creating conversation, updating, listing, etc. This simplifies checks and provides a more consolidated and unified method for verifying user actions. * Fix tests * Tweak RBAC * Address review comments * Separate client and server interfaces for user preference services. * Apply core review suggestions * Apply suggestions from code review Co-authored-by: Brian Joerger --------- Co-authored-by: joerger --- lib/ai/testutils/http.go | 129 ++++++++++ lib/auth/assist/assistv1/service.go | 101 +++++++- lib/auth/assist/assistv1/service_test.go | 308 +++++++++++++++++++++++ lib/auth/auth_with_roles.go | 54 +--- lib/auth/grpcserver.go | 3 +- lib/authz/permissions.go | 24 ++ lib/web/command.go | 20 +- lib/web/command_test.go | 25 +- 8 files changed, 611 insertions(+), 53 deletions(-) create mode 100644 lib/ai/testutils/http.go create mode 100644 lib/auth/assist/assistv1/service_test.go diff --git a/lib/ai/testutils/http.go b/lib/ai/testutils/http.go new file mode 100644 index 0000000000000..98cf6a8ff21b2 --- /dev/null +++ b/lib/ai/testutils/http.go @@ -0,0 +1,129 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package testutils + +import ( + "encoding/json" + "net/http" + "strconv" + "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" +) + +// GetTestHandlerFn returns a handler function that can be used to OpenAI API used by +// the chat API. It takes a list of responses that will be returned in order. +func GetTestHandlerFn(t *testing.T, responses []string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || !(r.URL.Path == "/chat/completions") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return + } + + switch r.Header.Get("Accept") { + case "application/json; charset=utf-8", "application/json": + responses = messageResponse(w, r, t, responses) + case "text/event-stream": + responses = streamResponse(w, t, responses) + default: + http.Error(w, "Unexpected request", http.StatusBadRequest) + } + } +} + +func streamResponse(w http.ResponseWriter, t *testing.T, responses []string) []string { + w.Header().Set("Content-Type", "text/event-stream") + + if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return responses + } + + resp := &openai.ChatCompletionStreamResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "completion", + Created: time.Now().Unix(), + Model: openai.GPT4, + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: responses[0], + Role: openai.ChatMessageRoleAssistant, + }, + FinishReason: "", + }, + }, + } + + respBytes, err := json.Marshal(resp) + assert.NoError(t, err, "Marshal error") + + _, err = w.Write([]byte("data: ")) + assert.NoError(t, err, "Write error") + _, err = w.Write(respBytes) + assert.NoError(t, err, "Write error") + _, err = w.Write([]byte("\n\nevent: done\ndata: [DONE]\n\n")) + assert.NoError(t, err, "Write error") + + return responses[1:] +} + +func messageResponse(w http.ResponseWriter, r *http.Request, t *testing.T, responses []string) []string { + w.Header().Set("Content-Type", "application/json") + + req := &openai.ChatCompletionRequest{} + err := json.NewDecoder(r.Body).Decode(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + + // Use assert as require doesn't work when called from a goroutine + if !assert.GreaterOrEqual(t, len(responses), 1, "Unexpected request") { + http.Error(w, "Unexpected request", http.StatusBadRequest) + return responses + } + + dataBytes := responses[0] + + resp := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + Model: req.Model, + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: dataBytes, + Name: "", + }, + }, + }, + Usage: openai.Usage{}, + } + + respBytes, err := json.Marshal(resp) + assert.NoError(t, err, "Marshal error") + + _, err = w.Write(respBytes) + assert.NoError(t, err, "Write error") + + return responses[1:] +} diff --git a/lib/auth/assist/assistv1/service.go b/lib/auth/assist/assistv1/service.go index dfc38666eaa0c..29a71d815875b 100644 --- a/lib/auth/assist/assistv1/service.go +++ b/lib/auth/assist/assistv1/service.go @@ -22,23 +22,31 @@ import ( "context" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/emptypb" + "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/services" ) // ServiceConfig holds configuration options for // the assist gRPC service. type ServiceConfig struct { - Backend services.Assistant + Backend services.Assistant + Authorizer authz.Authorizer + Logger *logrus.Entry } // Service implements the teleport.assist.v1.AssistService RPC service. type Service struct { assist.UnimplementedAssistServiceServer - backend services.Assistant + backend services.Assistant + authorizer authz.Authorizer + log *logrus.Entry } // NewService returns a new assist gRPC service. @@ -46,22 +54,46 @@ func NewService(cfg *ServiceConfig) (*Service, error) { switch { case cfg.Backend == nil: return nil, trace.BadParameter("backend is required") + case cfg.Authorizer == nil: + return nil, trace.BadParameter("authorizer is required") + case cfg.Logger == nil: + cfg.Logger = logrus.WithField(trace.Component, "assist.service") } return &Service{ - backend: cfg.Backend, + backend: cfg.Backend, + authorizer: cfg.Authorizer, + log: cfg.Logger, }, nil } // CreateAssistantConversation creates a new conversation entry in the backend. func (a *Service) CreateAssistantConversation(ctx context.Context, req *assist.CreateAssistantConversationRequest) (*assist.CreateAssistantConversationResponse, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbCreate) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to create conversation for user %q", authCtx.User.GetName(), req.Username) + } + resp, err := a.backend.CreateAssistantConversation(ctx, req) return resp, trace.Wrap(err) } // UpdateAssistantConversationInfo updates the conversation info for a conversation. -func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, request *assist.UpdateAssistantConversationInfoRequest) (*emptypb.Empty, error) { - err := a.backend.UpdateAssistantConversationInfo(ctx, request) +func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, req *assist.UpdateAssistantConversationInfoRequest) (*emptypb.Empty, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbUpdate) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to update conversation for user %q", authCtx.User.GetName(), req.Username) + } + + err = a.backend.UpdateAssistantConversationInfo(ctx, req) if err != nil { return &emptypb.Empty{}, trace.Wrap(err) } @@ -71,22 +103,79 @@ func (a *Service) UpdateAssistantConversationInfo(ctx context.Context, request * // GetAssistantConversations returns all conversations started by a user. func (a *Service) GetAssistantConversations(ctx context.Context, req *assist.GetAssistantConversationsRequest) (*assist.GetAssistantConversationsResponse, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbList) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to list conversations for user %q", authCtx.User.GetName(), req.GetUsername()) + } + resp, err := a.backend.GetAssistantConversations(ctx, req) return resp, trace.Wrap(err) } // GetAssistantMessages returns all messages with given conversation ID. func (a *Service) GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbRead) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to get messages for user %q", authCtx.User.GetName(), req.GetUsername()) + } + resp, err := a.backend.GetAssistantMessages(ctx, req) return resp, trace.Wrap(err) } // CreateAssistantMessage adds the message to the backend. func (a *Service) CreateAssistantMessage(ctx context.Context, req *assist.CreateAssistantMessageRequest) (*emptypb.Empty, error) { + authCtx, err := authz.AuthorizeWithVerbs(ctx, a.log, a.authorizer, true, types.KindAssistant, types.VerbCreate) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + if userHasAccess(authCtx, req) { + return nil, trace.AccessDenied("user %q is not allowed to create message for user %q", authCtx.User.GetName(), req.GetUsername()) + } + return &emptypb.Empty{}, trace.Wrap(a.backend.CreateAssistantMessage(ctx, req)) } // IsAssistEnabled returns true if the assist is enabled or not on the auth level. -func (a *Service) IsAssistEnabled(ctx context.Context, req *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) { +func (a *Service) IsAssistEnabled(ctx context.Context, _ *assist.IsAssistEnabledRequest) (*assist.IsAssistEnabledResponse, error) { + authCtx, err := a.authorizer.Authorize(ctx) + if err != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + + // Check if this endpoint is called by a user or Proxy. + if authz.IsLocalUser(*authCtx) { + checkErr := authCtx.Checker.CheckAccessToRule( + &services.Context{User: authCtx.User}, + defaults.Namespace, types.KindAssistant, types.VerbRead, + false, /* silent */ + ) + if checkErr != nil { + return nil, authz.ConvertAuthorizerError(ctx, a.log, err) + } + } else { + // This endpoint is called from Proxy to check if the assist is enabled. + // Proxy credentials are used instead of the user credentials. + requestedByProxy := authz.HasBuiltinRole(*authCtx, string(types.RoleProxy)) + if !requestedByProxy { + return nil, trace.AccessDenied("only proxy is allowed to call IsAssistEnabled endpoint") + } + } + + // Check if assist can use the backend. return a.backend.IsAssistEnabled(ctx) } + +// userHasAccess returns true if the user should have access to the resource. +func userHasAccess(authCtx *authz.Context, req interface{ GetUsername() string }) bool { + return !authz.IsCurrentUser(*authCtx, req.GetUsername()) && !authz.HasBuiltinRole(*authCtx, string(types.RoleAdmin)) +} diff --git a/lib/auth/assist/assistv1/service_test.go b/lib/auth/assist/assistv1/service_test.go new file mode 100644 index 0000000000000..5a837243fb02a --- /dev/null +++ b/lib/auth/assist/assistv1/service_test.go @@ -0,0 +1,308 @@ +/* + * Copyright 2023 Gravitational, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package assistv1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/tlsca" +) + +const ( + defaultUser = "test-user" + noAccessUser = "user-no-access" +) + +func TestService_CreateAssistantConversation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + username string + req *assistpb.CreateAssistantConversationRequest + wantErr assert.ErrorAssertionFunc + assertResponse func(t *testing.T, resp *assistpb.CreateAssistantConversationResponse) + }{ + { + name: "success", + username: defaultUser, + req: &assistpb.CreateAssistantConversationRequest{ + Username: defaultUser, + CreatedTime: timestamppb.Now(), + }, + wantErr: assert.NoError, + assertResponse: func(t *testing.T, resp *assistpb.CreateAssistantConversationResponse) { + require.NotEmpty(t, resp.GetId()) + }, + }, + { + name: "access denies - RBAC", + username: noAccessUser, + req: &assistpb.CreateAssistantConversationRequest{ + Username: noAccessUser, + CreatedTime: timestamppb.Now(), + }, + wantErr: assert.Error, + }, + { + name: "access denied - different user", + username: defaultUser, + req: &assistpb.CreateAssistantConversationRequest{ + Username: noAccessUser, + CreatedTime: timestamppb.Now(), + }, + wantErr: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxs, svc := initSvc(t) + + got, err := svc.CreateAssistantConversation(ctxs[tt.username], tt.req) + tt.wantErr(t, err) + + if tt.assertResponse != nil { + tt.assertResponse(t, got) + } + }) + } +} + +func TestService_GetAssistantConversations(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + username string + req *assistpb.GetAssistantConversationsRequest + wantErr assert.ErrorAssertionFunc + assertResponse func(t *testing.T, resp *assistpb.CreateAssistantConversationResponse) + }{ + { + name: "success", + username: defaultUser, + req: &assistpb.GetAssistantConversationsRequest{ + Username: defaultUser, + }, + wantErr: assert.NoError, + }, + { + name: "access denies - RBAC", + username: noAccessUser, + req: &assistpb.GetAssistantConversationsRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + { + name: "access denied - different user", + username: defaultUser, + req: &assistpb.GetAssistantConversationsRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxs, svc := initSvc(t) + + _, err := svc.GetAssistantConversations(ctxs[tt.username], tt.req) + tt.wantErr(t, err) + }) + } +} + +func TestService_InsertAssistantMessage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + username string + req *assistpb.CreateAssistantMessageRequest + wantErr assert.ErrorAssertionFunc + }{ + { + name: "success", + username: defaultUser, + req: &assistpb.CreateAssistantMessageRequest{ + Username: defaultUser, + Message: &assistpb.AssistantMessage{ + Type: "CHAT_MESSAGE_ASSISTANT", + CreatedTime: timestamppb.Now(), + Payload: "Blah", + }, + }, + wantErr: assert.NoError, + }, + { + name: "access denies - RBAC", + username: noAccessUser, + req: &assistpb.CreateAssistantMessageRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + { + name: "access denied - different user", + username: defaultUser, + req: &assistpb.CreateAssistantMessageRequest{ + Username: noAccessUser, + }, + wantErr: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctxs, svc := initSvc(t) + + // Create a conversation that we can remove, so we don't hit "conversation doesn't exist" error + convMsg, err := svc.backend.CreateAssistantConversation(ctxs[tt.username], &assistpb.CreateAssistantConversationRequest{ + Username: tt.username, + CreatedTime: timestamppb.Now(), + }) + require.NoError(t, err) + + conversationID := convMsg.GetId() + + tt.req.ConversationId = conversationID + + _, err = svc.CreateAssistantMessage(ctxs[tt.username], tt.req) + tt.wantErr(t, err) + }) + } +} + +func initSvc(t *testing.T) (map[string]context.Context, *Service) { + ctx := context.Background() + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + + clusterConfigSvc, err := local.NewClusterConfigurationService(backend) + require.NoError(t, err) + trustSvc := local.NewCAService(backend) + roleSvc := local.NewAccessService(backend) + userSvc := local.NewIdentityService(backend) + + require.NoError(t, clusterConfigSvc.SetAuthPreference(ctx, types.DefaultAuthPreference())) + require.NoError(t, clusterConfigSvc.SetClusterAuditConfig(ctx, types.DefaultClusterAuditConfig())) + require.NoError(t, clusterConfigSvc.SetClusterNetworkingConfig(ctx, types.DefaultClusterNetworkingConfig())) + require.NoError(t, clusterConfigSvc.SetSessionRecordingConfig(ctx, types.DefaultSessionRecordingConfig())) + + accessPoint := struct { + services.ClusterConfiguration + services.Trust + services.RoleGetter + services.UserGetter + }{ + ClusterConfiguration: clusterConfigSvc, + Trust: trustSvc, + RoleGetter: roleSvc, + UserGetter: userSvc, + } + + accessService := local.NewAccessService(backend) + eventService := local.NewEventsService(backend) + lockWatcher, err := services.NewLockWatcher(ctx, services.LockWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Client: eventService, + Component: "test", + }, + LockGetter: accessService, + }) + require.NoError(t, err) + + authorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{ + ClusterName: "test-cluster", + AccessPoint: accessPoint, + LockWatcher: lockWatcher, + }) + require.NoError(t, err) + + roles := map[string]types.Role{} + + role, err := types.NewRole("allow-rules", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + { + Resources: []string{types.KindAssistant}, + Verbs: []string{types.VerbList, types.VerbRead, types.VerbUpdate, types.VerbCreate, types.VerbDelete}, + }, + }, + }, + }) + require.NoError(t, err) + + roles[defaultUser] = role + + roleNoAccess, err := types.NewRole("no-rules", types.RoleSpecV6{ + Allow: types.RoleConditions{}, + }) + require.NoError(t, err) + roles["user-no-access"] = roleNoAccess + + ctxs := make(map[string]context.Context, len(roles)) + for username, role := range roles { + err = roleSvc.CreateRole(ctx, role) + require.NoError(t, err) + + user, err := types.NewUser(username) + user.AddRole(role.GetName()) + require.NoError(t, err) + + err = userSvc.CreateUser(user) + require.NoError(t, err) + + ctx = authz.ContextWithUser(ctx, authz.LocalUser{ + Username: user.GetName(), + Identity: tlsca.Identity{ + Username: user.GetName(), + Groups: []string{role.GetName()}, + }, + }) + ctxs[user.GetName()] = ctx + } + + svc, err := NewService(&ServiceConfig{ + Backend: local.NewAssistService(backend), + Authorizer: authorizer, + }) + require.NoError(t, err) + + return ctxs, svc +} + +type nodeGetterFake struct { +} + +func (g *nodeGetterFake) GetNode(ctx context.Context, namespace, name string) (types.Server, error) { + return nil, nil +} diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 6f8e805043fb7..fa83d048139bb 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -212,7 +212,7 @@ func (a *ServerWithRoles) serverAction() error { // whether any of the given roles match the role set. func (a *ServerWithRoles) hasBuiltinRole(roles ...types.SystemRole) bool { for _, role := range roles { - if HasBuiltinRole(a.context, string(role)) { + if authz.HasBuiltinRole(a.context, string(role)) { return true } } @@ -221,15 +221,11 @@ func (a *ServerWithRoles) hasBuiltinRole(roles ...types.SystemRole) bool { // HasBuiltinRole checks if the identity is a builtin role with the matching // name. +// Deprecated: use authz.HasBuiltinRole instead. func HasBuiltinRole(authContext authz.Context, name string) bool { - if _, ok := authContext.Identity.(authz.BuiltinRole); !ok { - return false - } - if !authContext.Checker.HasRole(name) { - return false - } - - return true + // TODO(jakule): This function can be removed once teleport.e is updated + // to use authz.HasBuiltinRole. + return authz.HasBuiltinRole(authContext, name) } // HasRemoteBuiltinRole checks if the identity is a remote builtin role with the @@ -5255,8 +5251,8 @@ func (a *ServerWithRoles) checkAccessToNode(node types.Server) error { // In addition, allow proxy (and remote proxy) to access all nodes for its // smart resolution address resolution. Once the smart resolution logic is // moved to the auth server, this logic can be removed. - builtinRole := HasBuiltinRole(a.context, string(types.RoleAdmin)) || - HasBuiltinRole(a.context, string(types.RoleProxy)) || + builtinRole := authz.HasBuiltinRole(a.context, string(types.RoleAdmin)) || + authz.HasBuiltinRole(a.context, string(types.RoleProxy)) || HasRemoteBuiltinRole(a.context, string(types.RoleRemoteProxy)) if builtinRole { @@ -6114,56 +6110,32 @@ func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, // CreateAssistantConversation creates a new conversation entry in the backend. func (a *ServerWithRoles) CreateAssistantConversation(ctx context.Context, req *assist.CreateAssistantConversationRequest) (*assist.CreateAssistantConversationResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbCreate); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.CreateAssistantConversation(ctx, req) + return nil, trace.NotImplemented("CreateAssistantConversation must not be called on auth.ServerWithRoles") } // GetAssistantConversations returns all conversations started by a user. func (a *ServerWithRoles) GetAssistantConversations(ctx context.Context, request *assist.GetAssistantConversationsRequest) (*assist.GetAssistantConversationsResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbList); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.GetAssistantConversations(ctx, request) + return nil, trace.NotImplemented("GetAssistantConversations must not be called on auth.ServerWithRoles") } // GetAssistantMessages returns all messages with given conversation ID. func (a *ServerWithRoles) GetAssistantMessages(ctx context.Context, req *assist.GetAssistantMessagesRequest) (*assist.GetAssistantMessagesResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbRead); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.GetAssistantMessages(ctx, req) + return nil, trace.NotImplemented("GetAssistantMessages must not be called on auth.ServerWithRoles") } // IsAssistEnabled returns true if the assist is enabled or not on the auth level. func (a *ServerWithRoles) IsAssistEnabled(ctx context.Context) (*assist.IsAssistEnabledResponse, error) { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbRead); err != nil { - return nil, trace.Wrap(err) - } - - return a.authServer.IsAssistEnabled(ctx) + return nil, trace.NotImplemented("IsAssistEnabled must not be called on auth.ServerWithRoles") } // CreateAssistantMessage adds the message to the backend. func (a *ServerWithRoles) CreateAssistantMessage(ctx context.Context, msg *assist.CreateAssistantMessageRequest) error { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbUpdate); err != nil { - return trace.Wrap(err) - } - - return a.authServer.CreateAssistantMessage(ctx, msg) + return trace.NotImplemented("CreateAssistantMessage must not be called on auth.ServerWithRoles") } // UpdateAssistantConversationInfo updates the conversation info. func (a *ServerWithRoles) UpdateAssistantConversationInfo(ctx context.Context, msg *assist.UpdateAssistantConversationInfoRequest) error { - if err := a.action(apidefaults.Namespace, types.KindAssistant, types.VerbUpdate); err != nil { - return trace.Wrap(err) - } - - return a.authServer.UpdateAssistantConversationInfo(ctx, msg) + return trace.NotImplemented("UpdateAssistantConversationInfo must not be called on auth.ServerWithRoles") } // CloneHTTPClient creates a new HTTP client with the same configuration. diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 9973b39db57fd..7e41b8560839b 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5144,7 +5144,8 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { // Initialize and register the assist service. assistSrv, err := assistv1.NewService(&assistv1.ServiceConfig{ - Backend: cfg.AuthServer.Services, + Backend: cfg.AuthServer.Services, + Authorizer: cfg.Authorizer, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index ec7d43b487d92..e3f784ce2984b 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -1294,3 +1294,27 @@ func UserFromContext(ctx context.Context) (IdentityGetter, error) { } return user, nil } + +// HasBuiltinRole checks if the identity is a builtin role with the matching +// name. +func HasBuiltinRole(authContext Context, name string) bool { + if _, ok := authContext.Identity.(BuiltinRole); !ok { + return false + } + if !authContext.Checker.HasRole(name) { + return false + } + + return true +} + +// IsLocalUser checks if the identity is a local user. +func IsLocalUser(authContext Context) bool { + _, ok := authContext.Identity.(LocalUser) + return ok +} + +// IsCurrentUser checks if the identity is a local user matching the given username +func IsCurrentUser(authContext Context, username string) bool { + return IsLocalUser(authContext) && authContext.User.GetName() == username +} diff --git a/lib/web/command.go b/lib/web/command.go index aba2f3729da2a..78209ea3675a4 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -64,6 +64,24 @@ type CommandRequest struct { ExecutionID string `json:"execution_id"` } +// commandExecResult is a result of a command execution. +type commandExecResult struct { + // NodeID is the ID of the node where the command was executed. + NodeID string `json:"node_id"` + // NodeName is the name of the node where the command was executed. + NodeName string `json:"node_name"` + // ExecutionID is a unique ID used to identify the command execution. + ExecutionID string `json:"execution_id"` + // SessionID is the ID of the session where the command was executed. + SessionID string `json:"session_id"` +} + +// sessionEndEvent is an event that is sent when a session ends. +type sessionEndEvent struct { + // NodeID is the ID of the server where the session was created. + NodeID string `json:"node_id"` +} + // Check checks if the request is valid. func (c *CommandRequest) Check() error { if c.Command == "" { @@ -443,7 +461,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl return nil, trace.NotImplemented("MFA is not supported for command execution") } - //TODO(jakule): Implement MFA support + // TODO(jakule): Implement MFA support nc, err := t.connectToHost(ctx, t.ws, tc, mfaAuth) if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failure connecting to host") diff --git a/lib/web/command_test.go b/lib/web/command_test.go index 4cb3d7b0a904d..46c9a9cb6c28f 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -35,15 +35,32 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" ) +const ( + testCommand = "echo txlxport | sed 's/x/e/g'" + testUser = "foo" +) + func TestExecuteCommand(t *testing.T) { t.Parallel() s := newWebSuite(t) - ws, _, err := s.makeCommand(t, s.authPack(t, "foo")) + assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + types.NewRule(types.KindAssistant, services.RW()), + }, + }, + }) + require.NoError(t, err) + require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) + + ws, _, err := s.makeCommand(t, s.authPack(t, testUser), uuid.New()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -52,13 +69,13 @@ func TestExecuteCommand(t *testing.T) { require.NoError(t, waitForCommandOutput(stream, "teleport")) } -func (s *WebSuite) makeCommand(t *testing.T, pack *authPack) (*websocket.Conn, *session.Session, error) { +func (s *WebSuite) makeCommand(t *testing.T, pack *authPack, conversationID uuid.UUID) (*websocket.Conn, *session.Session, error) { req := CommandRequest{ Query: fmt.Sprintf("name == \"%s\"", s.srvID), Login: pack.login, - ConversationID: uuid.New().String(), + ConversationID: conversationID.String(), ExecutionID: uuid.New().String(), - Command: "echo txlxport | sed 's/x/e/g'", + Command: testCommand, } u := url.URL{ From 10d17e5dad92ee59b065d820d6807ff828d9d21d Mon Sep 17 00:00:00 2001 From: Mike Jensen Date: Tue, 31 Oct 2023 12:23:43 -0600 Subject: [PATCH 2/3] Fix errors after cherry-pick --- lib/auth/assist/assistv1/service_test.go | 7 ------- lib/web/assistant_test.go | 17 +++++++++++++++++ lib/web/command.go | 18 ------------------ 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/lib/auth/assist/assistv1/service_test.go b/lib/auth/assist/assistv1/service_test.go index 5a837243fb02a..4857dfac142b8 100644 --- a/lib/auth/assist/assistv1/service_test.go +++ b/lib/auth/assist/assistv1/service_test.go @@ -299,10 +299,3 @@ func initSvc(t *testing.T) (map[string]context.Context, *Service) { return ctxs, svc } - -type nodeGetterFake struct { -} - -func (g *nodeGetterFake) GetNode(ctx context.Context, namespace, name string) (types.Server, error) { - return nil, nil -} diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 8f15b6a82bf51..3d220fd313c69 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -35,8 +35,10 @@ import ( "golang.org/x/time/rate" authproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/services" ) func Test_runAssistant(t *testing.T) { @@ -168,6 +170,8 @@ func Test_runAssistant(t *testing.T) { tc.setup(t, s) } + allowAssistAccess(t, s) + ws, err := s.makeAssistant(t, s.authPack(t, "foo")) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -186,7 +190,20 @@ func Test_runAssistant(t *testing.T) { tc.act(t, ws) }) } +} + +func allowAssistAccess(t *testing.T, s *WebSuite) types.Role { + assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + types.NewRule(types.KindAssistant, services.RW()), + }, + }, + }) + require.NoError(t, err) + require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) + return assistRole } func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) { diff --git a/lib/web/command.go b/lib/web/command.go index 78209ea3675a4..f378c4fb27699 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -64,24 +64,6 @@ type CommandRequest struct { ExecutionID string `json:"execution_id"` } -// commandExecResult is a result of a command execution. -type commandExecResult struct { - // NodeID is the ID of the node where the command was executed. - NodeID string `json:"node_id"` - // NodeName is the name of the node where the command was executed. - NodeName string `json:"node_name"` - // ExecutionID is a unique ID used to identify the command execution. - ExecutionID string `json:"execution_id"` - // SessionID is the ID of the session where the command was executed. - SessionID string `json:"session_id"` -} - -// sessionEndEvent is an event that is sent when a session ends. -type sessionEndEvent struct { - // NodeID is the ID of the server where the session was created. - NodeID string `json:"node_id"` -} - // Check checks if the request is valid. func (c *CommandRequest) Check() error { if c.Command == "" { From 12e27b2a279aa40ace5f4fefa460bc8044e179d5 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Fri, 3 Nov 2023 16:41:49 -0400 Subject: [PATCH 3/3] Fix UT --- lib/web/apiserver_test.go | 10 +++++++--- lib/web/assistant_test.go | 26 ++++++++++---------------- lib/web/command_test.go | 2 +- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5af2877d01148..3a07d1ce13fa3 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -630,7 +630,7 @@ type authPack struct { // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. -func (s *WebSuite) authPack(t *testing.T, user string) *authPack { +func (s *WebSuite) authPack(t *testing.T, user string, roles ...string) *authPack { login := s.user pass := "abc123" rawSecret := "def456" @@ -644,7 +644,7 @@ func (s *WebSuite) authPack(t *testing.T, user string) *authPack { err = s.server.Auth().SetAuthPreference(s.ctx, ap) require.NoError(t, err) - s.createUser(t, user, login, pass, otpSecret) + s.createUser(t, user, login, pass, otpSecret, roles...) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) @@ -683,7 +683,7 @@ func (s *WebSuite) authPack(t *testing.T, user string) *authPack { } } -func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string) { +func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string, roles ...string) { teleUser, err := types.NewUser(user) require.NoError(t, err) role := services.RoleForUser(teleUser) @@ -695,6 +695,10 @@ func (s *WebSuite) createUser(t *testing.T, user string, login string, pass stri require.NoError(t, err) teleUser.AddRole(role.GetName()) + for _, r := range roles { + teleUser.AddRole(r) + } + teleUser.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "some-auth-user"}, }) diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 3d220fd313c69..bb836a182f5cc 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -170,9 +170,17 @@ func Test_runAssistant(t *testing.T) { tc.setup(t, s) } - allowAssistAccess(t, s) + assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Rules: []types.Rule{ + types.NewRule(types.KindAssistant, services.RW()), + }, + }, + }) + require.NoError(t, err) + require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) - ws, err := s.makeAssistant(t, s.authPack(t, "foo")) + ws, err := s.makeAssistant(t, s.authPack(t, "foo", assistRole.GetName())) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -192,20 +200,6 @@ func Test_runAssistant(t *testing.T) { } } -func allowAssistAccess(t *testing.T, s *WebSuite) types.Role { - assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{ - types.NewRule(types.KindAssistant, services.RW()), - }, - }, - }) - require.NoError(t, err) - require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) - - return assistRole -} - func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) { u := url.URL{ Host: s.url().Host, diff --git a/lib/web/command_test.go b/lib/web/command_test.go index 46c9a9cb6c28f..c7d5a40f39fa1 100644 --- a/lib/web/command_test.go +++ b/lib/web/command_test.go @@ -60,7 +60,7 @@ func TestExecuteCommand(t *testing.T) { require.NoError(t, err) require.NoError(t, s.server.Auth().UpsertRole(s.ctx, assistRole)) - ws, _, err := s.makeCommand(t, s.authPack(t, testUser), uuid.New()) + ws, _, err := s.makeCommand(t, s.authPack(t, testUser, assistRole.GetName()), uuid.New()) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) })