diff --git a/e b/e index 071f45f6b51ea..d14cd6a16c0b4 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit 071f45f6b51eae9503ea7ccaf72652e2b5dcdddf +Subproject commit d14cd6a16c0b46bd7c2a3cdc83f5c4aeb49ce6c2 diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 7350ef6f23b5b..433b555c93e67 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -45,6 +45,7 @@ import ( type APIConfig struct { PluginRegistry plugin.Registry AuthServer *Server + SessionService session.Service AuditLog events.IAuditLog Authorizer Authorizer Emitter apievents.Emitter @@ -149,6 +150,12 @@ func NewAPIServer(config *APIConfig) (http.Handler, error) { // Active sessions srv.GET("/:version/namespaces/:namespace/sessions/:id/stream", srv.withAuth(srv.getSessionChunk)) srv.GET("/:version/namespaces/:namespace/sessions/:id/events", srv.withAuth(srv.getSessionEvents)) + // DELETE IN 12.0.0 + srv.POST("/:version/namespaces/:namespace/sessions", srv.withAuth(srv.createSession)) + srv.PUT("/:version/namespaces/:namespace/sessions/:id", srv.withAuth(srv.updateSession)) + srv.DELETE("/:version/namespaces/:namespace/sessions/:id", srv.withAuth(srv.deleteSession)) + srv.GET("/:version/namespaces/:namespace/sessions", srv.withAuth(srv.getSessions)) + srv.GET("/:version/namespaces/:namespace/sessions/:id", srv.withAuth(srv.getSession)) // Namespaces srv.POST("/:version/namespaces", srv.withAuth(srv.upsertNamespace)) @@ -214,6 +221,7 @@ func (s *APIServer) withAuth(handler HandlerWithAuthFunc) httprouter.Handle { auth := &ServerWithRoles{ authServer: s.AuthServer, context: *authContext, + sessions: s.SessionService, alog: s.AuthServer, } version := p.ByName("version") @@ -784,6 +792,82 @@ func (s *APIServer) deleteCertAuthority(auth ClientI, w http.ResponseWriter, r * return message(fmt.Sprintf("cert '%v' deleted", id)), nil } +type createSessionReq struct { + Session session.Session `json:"session"` +} + +func (s *APIServer) createSession(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + var req *createSessionReq + if err := httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + namespace := p.ByName("namespace") + if !types.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + req.Session.Namespace = namespace + if err := auth.CreateSession(r.Context(), req.Session); err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + +type updateSessionReq struct { + Update session.UpdateRequest `json:"update"` +} + +func (s *APIServer) updateSession(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + var req *updateSessionReq + if err := httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + namespace := p.ByName("namespace") + if !types.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + req.Update.Namespace = namespace + if err := auth.UpdateSession(r.Context(), req.Update); err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + +func (s *APIServer) deleteSession(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + err := auth.DeleteSession(r.Context(), p.ByName("namespace"), session.ID(p.ByName("id"))) + if err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + +func (s *APIServer) getSessions(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + namespace := p.ByName("namespace") + if !types.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + sessions, err := auth.GetSessions(r.Context(), namespace) + if err != nil { + return nil, trace.Wrap(err) + } + return sessions, nil +} + +func (s *APIServer) getSession(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + sid, err := session.ParseID(p.ByName("id")) + if err != nil { + return nil, trace.Wrap(err) + } + namespace := p.ByName("namespace") + if !types.IsValidNamespace(namespace) { + return nil, trace.BadParameter("invalid namespace %q", namespace) + } + se, err := auth.GetSession(r.Context(), namespace, *sid) + if err != nil { + return nil, trace.Wrap(err) + } + return se, nil +} + type validateOIDCAuthCallbackReq struct { Query url.Values `json:"query"` } diff --git a/lib/auth/apiserver_active_sessions_test.go b/lib/auth/apiserver_active_sessions_test.go new file mode 100644 index 0000000000000..72a63f910df38 --- /dev/null +++ b/lib/auth/apiserver_active_sessions_test.go @@ -0,0 +1,244 @@ +// Copyright 2021 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "sort" + "testing" + "time" + + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" + + "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestAPIServer_activeSessions_whereConditions(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tlsServer := newTestTLSServer(t) + authServer := tlsServer.Auth() + + // - "admin" has permissions to access all active sessions + // - "alpaca" has permissions to access only their own active sessions + // Each user is assigned its corresponding role, plus whatever extra + // permissions are needed to run the scenario. + const admin = "admin" + const alpaca = "alpaca" + alpacaRole := services.RoleForUser(&types.UserV2{Metadata: types.Metadata{Name: alpaca}}) + alpacaRole.SetLogins(types.Allow, []string{alpaca}) + alpacaRole.SetRules(types.Allow, append(alpacaRole.GetRules(types.Allow), types.Rule{ + Resources: []string{"ssh_session"}, + // Allow all ssh_session verbs, deny rule below takes precedence. + Verbs: []string{"*"}, + })) + alpacaRole.SetRules(types.Deny, append(alpacaRole.GetRules(types.Deny), types.Rule{ + Resources: []string{"ssh_session"}, + Verbs: []string{"list", "read", "update", "delete"}, + Where: "!contains(ssh_session.participants, user.metadata.name)", + })) + _, err := CreateUser(authServer, alpaca, alpacaRole) + require.NoError(t, err) + + // Prepare clients. + adminClient, err := tlsServer.NewClient(TestAdmin()) + require.NoError(t, err) + alpacaClient, err := tlsServer.NewClient(TestUser(alpaca)) + require.NoError(t, err) + + // Prepare one session per user. + createSession := func(clt ClientI, user string) session.ID { + id := session.NewID() + now := time.Now() + + // Create initial session. + require.NoError(t, clt.CreateSession(ctx, session.Session{ + ID: id, + Namespace: apidefaults.Namespace, + TerminalParams: session.TerminalParams{ + W: 100, + H: 100, + }, + Login: user, + Created: now, + LastActive: now, + })) + + // Add parties, must be done via update. + // Usually the Node does this, in the test we are taking a shortcut and + // using admin due to its powerful permissions. + require.NoError(t, adminClient.UpdateSession(ctx, session.UpdateRequest{ + ID: id, + Namespace: apidefaults.Namespace, + Parties: &[]session.Party{ + {ID: session.NewID(), User: user}, + }, + })) + return id + } + adminSessionID := createSession(adminClient, admin) + alpacaSessionID := createSession(alpacaClient, alpaca) + + t.Run("GetSessions respects role conditions", func(t *testing.T) { + tests := []struct { + name string + clt ClientI + wantIDs []session.ID + }{ + { + name: admin, + clt: adminClient, + wantIDs: []session.ID{adminSessionID, alpacaSessionID}, + }, + { + name: alpaca, + clt: alpacaClient, + wantIDs: []session.ID{alpacaSessionID}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sessions, err := test.clt.GetSessions(ctx, apidefaults.Namespace) + require.NoError(t, err) + + got := make([]session.ID, len(sessions)) + for i, s := range sessions { + got[i] = s.ID + } + want := test.wantIDs + sort.Slice(got, func(i, j int) bool { return got[i] < got[j] }) + sort.Slice(want, func(i, j int) bool { return want[i] < want[j] }) + if diff := cmp.Diff(test.wantIDs, got); diff != "" { + t.Errorf("GetSessions() mismatch (-want +got):\n%s", diff) + } + }) + } + }) + + // Helper functions used by test cases below. + getSession := func(clt ClientI) func(id session.ID) error { + return func(id session.ID) error { + _, err := clt.GetSession(ctx, apidefaults.Namespace, id) + return err + } + } + updateSession := func(clt ClientI) func(id session.ID) error { + return func(id session.ID) error { + return clt.UpdateSession(ctx, session.UpdateRequest{ + ID: id, + Namespace: apidefaults.Namespace, + TerminalParams: &session.TerminalParams{W: 150, H: 150}, + }) + } + } + deleteSession := func(clt ClientI) func(id session.ID) error { + return func(id session.ID) error { + return clt.UpdateSession(ctx, session.UpdateRequest{ + ID: id, + Namespace: apidefaults.Namespace, + TerminalParams: &session.TerminalParams{W: 150, H: 150}, + }) + } + } + + t.Run("users can't interact with denied sessions", func(t *testing.T) { + clt := alpacaClient + sessionID := adminSessionID + tests := []struct { + name string + fn func(id session.ID) error + }{ + { + name: "GetSession", + fn: getSession(clt), + }, + { + name: "UpdateSession", + fn: updateSession(clt), + }, + { + name: "DeleteSession", + fn: deleteSession(clt), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := test.fn(sessionID) + require.True(t, trace.IsAccessDenied(err), "unexpected err: %v (want access denied)", err) + }) + } + }) + + t.Run("users can interact with allowed sessions", func(t *testing.T) { + tests := []struct { + name string + fn func(session.ID) error + sessionID session.ID + }{ + { + name: "admin reads own session", + fn: getSession(adminClient), + sessionID: adminSessionID, + }, + { + name: "admin updates own session", + fn: updateSession(adminClient), + sessionID: adminSessionID, + }, + { + name: "admin deletes own session", + fn: deleteSession(adminClient), + sessionID: adminSessionID, + }, + { + name: "admin reads alpaca session", + fn: getSession(adminClient), + sessionID: alpacaSessionID, + }, + { + name: "admin updates alpaca session", + fn: updateSession(adminClient), + sessionID: alpacaSessionID, + }, + + { + name: "alpaca reads own session", + fn: getSession(alpacaClient), + sessionID: alpacaSessionID, + }, + { + name: "alpaca updates own session", + fn: updateSession(alpacaClient), + sessionID: alpacaSessionID, + }, + { + name: "alpaca deletes own session", + fn: deleteSession(alpacaClient), + sessionID: alpacaSessionID, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require.NoError(t, test.fn(test.sessionID)) + }) + } + }) +} diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 5afb750d95b10..e8298a49442b1 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -52,6 +52,7 @@ import ( // methods that focuses on authorizing every request type ServerWithRoles struct { authServer *Server + sessions session.Service alog events.IAuditLog // context holds authorization context context Context @@ -185,6 +186,19 @@ func (a *ServerWithRoles) actionForKindSession(namespace, verb string, sid sessi return trace.Wrap(a.actionWithExtendedContext(namespace, types.KindSession, verb, extendContext)) } +// actionForKindSSHSession is a special checker that grants access to active SSH +// sessions. It can allow access to a specific session based on the `where` +// section of the user's access rule for kind `ssh_session`. +// DELETE IN 12.0.0 +func (a *ServerWithRoles) actionForKindSSHSession(ctx context.Context, namespace, verb string, sid session.ID) error { + extendContext := func(serviceContext *services.Context) error { + session, err := a.sessions.GetSession(ctx, namespace, sid) + serviceContext.SSHSession = session + return trace.Wrap(err) + } + return trace.Wrap(a.actionWithExtendedContext(namespace, types.KindSSHSession, verb, extendContext)) +} + // serverAction returns an access denied error if the role is not one of the builtin server roles. func (a *ServerWithRoles) serverAction() error { role, ok := a.context.Identity.(BuiltinRole) @@ -492,6 +506,67 @@ func (a *ServerWithRoles) AuthenticateSSHUser(ctx context.Context, req Authentic return a.authServer.AuthenticateSSHUser(ctx, req) } +// DELETE IN 12.0.0 +func (a *ServerWithRoles) GetSessions(ctx context.Context, namespace string) ([]session.Session, error) { + cond, err := a.actionForListWithCondition(namespace, types.KindSSHSession, services.SSHSessionIdentifier) + if err != nil { + return nil, trace.Wrap(err) + } + + sessions, err := a.sessions.GetSessions(ctx, namespace) + if err != nil { + return nil, trace.Wrap(err) + } + if cond == nil { + return sessions, nil + } + + // Filter sessions according to cond. + filteredSessions := make([]session.Session, 0, len(sessions)) + ruleCtx := &services.Context{User: a.context.User} + for _, s := range sessions { + ruleCtx.SSHSession = &s + if err := a.context.Checker.CheckAccessToRule(ruleCtx, namespace, types.KindSSHSession, types.VerbList, true /* silent */); err != nil { + continue + } + filteredSessions = append(filteredSessions, s) + } + return filteredSessions, nil +} + +// DELETE IN 12.0.0 +func (a *ServerWithRoles) GetSession(ctx context.Context, namespace string, id session.ID) (*session.Session, error) { + if err := a.actionForKindSSHSession(ctx, namespace, types.VerbRead, id); err != nil { + return nil, trace.Wrap(err) + } + return a.sessions.GetSession(ctx, namespace, id) +} + +// DELETE IN 12.0.0 +func (a *ServerWithRoles) CreateSession(ctx context.Context, s session.Session) error { + if err := a.action(s.Namespace, types.KindSSHSession, types.VerbCreate); err != nil { + return trace.Wrap(err) + } + return a.sessions.CreateSession(ctx, s) +} + +// DELETE IN 12.0.0 +func (a *ServerWithRoles) UpdateSession(ctx context.Context, req session.UpdateRequest) error { + if err := a.actionForKindSSHSession(ctx, req.Namespace, types.VerbUpdate, req.ID); err != nil { + return trace.Wrap(err) + } + return a.sessions.UpdateSession(ctx, req) +} + +// DeleteSession removes an active session from the backend. +// DELETE IN 12.0.0 +func (a *ServerWithRoles) DeleteSession(ctx context.Context, namespace string, id session.ID) error { + if err := a.actionForKindSSHSession(ctx, namespace, types.VerbDelete, id); err != nil { + return trace.Wrap(err) + } + return a.sessions.DeleteSession(ctx, namespace, id) +} + // CreateCertAuthority not implemented: can only be called locally. func (a *ServerWithRoles) CreateCertAuthority(ca types.CertAuthority) error { return trace.NotImplemented(notImplementedMessage) @@ -5150,7 +5225,7 @@ func (a *ServerWithRoles) MaintainSessionPresence(ctx context.Context) (proto.Au // NewAdminAuthServer returns auth server authorized as admin, // used for auth server cached access -func NewAdminAuthServer(authServer *Server, alog events.IAuditLog) (ClientI, error) { +func NewAdminAuthServer(authServer *Server, sessions session.Service, alog events.IAuditLog) (ClientI, error) { ctx, err := NewAdminContext() if err != nil { return nil, trace.Wrap(err) @@ -5159,6 +5234,7 @@ func NewAdminAuthServer(authServer *Server, alog events.IAuditLog) (ClientI, err authServer: authServer, context: *ctx, alog: alog, + sessions: sessions, }, nil } diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index d3b6a708d4a2d..0fe687cd36d3f 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -1677,6 +1677,7 @@ func serverWithAllowRules(t *testing.T, srv *TestAuthServer, allowRules []types. return &ServerWithRoles{ authServer: srv.AuthServer, + sessions: srv.SessionServer, alog: srv.AuditLog, context: *authContext, } @@ -2253,6 +2254,7 @@ func TestReplaceRemoteLocksRBAC(t *testing.T) { s := &ServerWithRoles{ authServer: srv.AuthServer, + sessions: srv.SessionServer, alog: srv.AuditLog, context: *authContext, } @@ -2446,6 +2448,7 @@ func TestKindClusterConfig(t *testing.T) { require.NoError(t, err, trace.DebugReport(err)) s := &ServerWithRoles{ authServer: srv.AuthServer, + sessions: srv.SessionServer, alog: srv.AuditLog, context: *authContext, } @@ -3047,6 +3050,7 @@ func TestListResources_KindKubernetesCluster(t *testing.T) { s := &ServerWithRoles{ authServer: srv.AuthServer, + sessions: srv.SessionServer, alog: srv.AuditLog, context: *authContext, } diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 067f5bab82949..b1571c84f55bf 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -291,6 +291,75 @@ func (c *Client) ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error) { return &re, nil } +// GetSessions returns a list of active sessions in the cluster as reported by +// the auth server. +// DELETE IN 12.0.0 +func (c *Client) GetSessions(ctx context.Context, namespace string) ([]session.Session, error) { + if namespace == "" { + return nil, trace.BadParameter(MissingNamespaceError) + } + out, err := c.Get(ctx, c.Endpoint("namespaces", namespace, "sessions"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var sessions []session.Session + if err := json.Unmarshal(out.Bytes(), &sessions); err != nil { + return nil, err + } + return sessions, nil +} + +// GetSession returns a session by ID +// DELETE IN 12.0.0 +func (c *Client) GetSession(ctx context.Context, namespace string, id session.ID) (*session.Session, error) { + if namespace == "" { + return nil, trace.BadParameter(MissingNamespaceError) + } + // saving extra round-trip + if err := id.Check(); err != nil { + return nil, trace.Wrap(err) + } + out, err := c.Get(ctx, c.Endpoint("namespaces", namespace, "sessions", string(id)), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var sess *session.Session + if err := json.Unmarshal(out.Bytes(), &sess); err != nil { + return nil, trace.Wrap(err) + } + return sess, nil +} + +// DeleteSession removes an active session from the backend. +// DELETE IN 12.0.0 +func (c *Client) DeleteSession(ctx context.Context, namespace string, id session.ID) error { + if namespace == "" { + return trace.BadParameter(MissingNamespaceError) + } + _, err := c.Delete(ctx, c.Endpoint("namespaces", namespace, "sessions", string(id))) + return trace.Wrap(err) +} + +// CreateSession creates new session +// DELETE IN 12.0.0 +func (c *Client) CreateSession(ctx context.Context, sess session.Session) error { + if sess.Namespace == "" { + return trace.BadParameter(MissingNamespaceError) + } + _, err := c.PostJSON(ctx, c.Endpoint("namespaces", sess.Namespace, "sessions"), createSessionReq{Session: sess}) + return trace.Wrap(err) +} + +// UpdateSession updates existing session +// DELETE IN 12.0.0 +func (c *Client) UpdateSession(ctx context.Context, req session.UpdateRequest) error { + if err := req.Check(); err != nil { + return trace.Wrap(err) + } + _, err := c.PutJSON(ctx, c.Endpoint("namespaces", req.Namespace, "sessions", string(req.ID)), updateSessionReq{Update: req}) + return trace.Wrap(err) +} + func (c *Client) Close() error { c.HTTPClient.Close() return c.APIClient.Close() @@ -1606,6 +1675,7 @@ type ClientI interface { services.WindowsDesktops WebService services.Status + session.Service services.ClusterConfiguration services.SessionTrackerService services.ConnectionsDiagnostic diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 99ae0e65feaf8..40af361b6f4ee 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -4381,6 +4381,7 @@ func serverWithNopRole(cfg GRPCServerConfig) (*ServerWithRoles, error) { return &ServerWithRoles{ authServer: cfg.AuthServer, context: *nopCtx, + sessions: cfg.SessionService, alog: cfg.AuthServer, }, nil } @@ -4420,6 +4421,7 @@ func (g *GRPCServer) authenticate(ctx context.Context) (*grpcContext, error) { ServerWithRoles: &ServerWithRoles{ authServer: g.AuthServer, context: *authContext, + sessions: g.SessionService, alog: g.AuthServer, }, }, nil diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index f7deae744b0ef..9788b1ff07bca 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -49,6 +49,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/services/suite" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -182,6 +183,8 @@ type TestAuthServer struct { AuthServer *Server // AuditLog is an event audit log AuditLog events.IAuditLog + // SessionServer is a session service + SessionServer session.Service // Backend is a backend for auth server Backend backend.Backend // Authorizer is an authorizer used in tests @@ -227,6 +230,11 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { srv.AuditLog = localLog } + srv.SessionServer, err = session.New(srv.Backend) + if err != nil { + return nil, trace.Wrap(err) + } + access := local.NewAccessService(srv.Backend) identity := local.NewIdentityService(srv.Backend) @@ -537,10 +545,11 @@ func (a *TestAuthServer) Trust(ctx context.Context, remote *TestAuthServer, role // NewTestTLSServer returns new test TLS server func (a *TestAuthServer) NewTestTLSServer() (*TestTLSServer, error) { apiConfig := &APIConfig{ - AuthServer: a.AuthServer, - Authorizer: a.Authorizer, - AuditLog: a.AuditLog, - Emitter: a.AuthServer.emitter, + AuthServer: a.AuthServer, + Authorizer: a.Authorizer, + SessionService: a.SessionServer, + AuditLog: a.AuditLog, + Emitter: a.AuthServer.emitter, } srv, err := NewTestTLSServer(TestTLSServerConfig{ APIConfig: apiConfig, @@ -653,7 +662,7 @@ func NewTestTLSServer(cfg TestTLSServerConfig) (*TestTLSServer, error) { } tlsConfig.Time = cfg.AuthServer.Clock().Now - accessPoint, err := NewAdminAuthServer(srv.AuthServer.AuthServer, srv.AuthServer.AuditLog) + accessPoint, err := NewAdminAuthServer(srv.AuthServer.AuthServer, srv.AuthServer.SessionServer, srv.AuthServer.AuditLog) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/service/service.go b/lib/service/service.go index 297f108962a78..2c6c328ee59ad 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -98,6 +98,7 @@ import ( "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/alpnproxy" alpnproxyauth "github.com/gravitational/teleport/lib/srv/alpnproxy/auth" @@ -1632,12 +1633,17 @@ func (process *TeleportProcess) initAuthService() error { // second, create the API Server: it's actually a collection of API servers, // each serving requests for a "role" which is assigned to every connected // client based on their certificate (user, server, admin, etc) + sessionService, err := session.New(b) + if err != nil { + return trace.Wrap(err) + } authorizer, err := auth.NewAuthorizer(clusterName, authServer, lockWatcher) if err != nil { return trace.Wrap(err) } apiConf := &auth.APIConfig{ AuthServer: authServer, + SessionService: sessionService, Authorizer: authorizer, AuditLog: process.auditLog, PluginRegistry: process.PluginRegistry, diff --git a/lib/session/session.go b/lib/session/session.go index f87c722345cc3..f863e8cd3f288 100644 --- a/lib/session/session.go +++ b/lib/session/session.go @@ -19,15 +19,21 @@ limitations under the License. package session import ( + "context" + "encoding/json" "fmt" + "sort" "strconv" "strings" "time" "github.com/google/uuid" + "github.com/jonboulle/clockwork" "github.com/moby/term" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/trace" ) @@ -193,10 +199,298 @@ func (p *TerminalParams) Winsize() *term.Winsize { } } +// UpdateRequest is a session update request +// DELETE IN 12.0.0 +type UpdateRequest struct { + ID ID `json:"id"` + Namespace string `json:"namespace"` + TerminalParams *TerminalParams `json:"terminal_params"` + + // Parties allows to update the list of session parties. nil means + // "do not update", empty list means "everybody is gone" + Parties *[]Party `json:"parties"` +} + +// Check returns nil if request is valid, error otherwize +func (u *UpdateRequest) Check() error { + if err := u.ID.Check(); err != nil { + return trace.Wrap(err) + } + if u.Namespace == "" { + return trace.BadParameter("missing parameter Namespace") + } + if u.TerminalParams != nil { + _, err := NewTerminalParamsFromInt(u.TerminalParams.W, u.TerminalParams.H) + if err != nil { + return trace.Wrap(err) + } + } + return nil +} + // MaxSessionSliceLength is the maximum number of sessions per time window // that the backend will return. const MaxSessionSliceLength = 1000 +// Service is a realtime SSH session service that has information about +// sessions that are in-flight in the cluster at the moment. +// DELETE IN 12.0.0 +type Service interface { + // GetSessions returns a list of currently active sessions matching + // the given condition. + GetSessions(ctx context.Context, namespace string) ([]Session, error) + + // GetSession returns a session with its parties by ID. + GetSession(ctx context.Context, namespace string, id ID) (*Session, error) + + // CreateSession creates a new active session and it's parameters if term is + // skipped, terminal size won't be recorded. + CreateSession(ctx context.Context, sess Session) error + + // UpdateSession updates certain session parameters (last_active, terminal + // parameters) other parameters will not be updated. + UpdateSession(ctx context.Context, req UpdateRequest) error + + // DeleteSession removes an active session from the backend. + DeleteSession(ctx context.Context, namespace string, id ID) error +} + +// DELETE IN 12.0.0 +type server struct { + bk backend.Backend + activeSessionTTL time.Duration + clock clockwork.Clock +} + +// New returns new session server that uses sqlite to manage +// active sessions +func New(bk backend.Backend) (Service, error) { + s := &server{ + bk: bk, + clock: clockwork.NewRealClock(), + } + if s.activeSessionTTL == 0 { + s.activeSessionTTL = defaults.ActiveSessionTTL + } + return s, nil +} + +func activePrefix(namespace string) []byte { + return backend.Key("namespaces", namespace, "sessions", "active") +} + +func activeKey(namespace string, key string) []byte { + return backend.Key("namespaces", namespace, "sessions", "active", key) +} + +// GetSessions returns a list of active sessions. +// Returns an empty slice if no sessions are active +func (s *server) GetSessions(ctx context.Context, namespace string) ([]Session, error) { + prefix := activePrefix(namespace) + result, err := s.bk.GetRange(ctx, prefix, backend.RangeEnd(prefix), MaxSessionSliceLength) + if err != nil { + return nil, trace.Wrap(err) + } + + sessions := make(Sessions, len(result.Items)) + for i, item := range result.Items { + if err := json.Unmarshal(item.Value, &sessions[i]); err != nil { + return nil, trace.Wrap(err) + } + } + + sort.Stable(sessions) + return sessions, nil +} + +// Sessions type is created over []Session to implement sort.Interface to +// be able to sort sessions by creation time +type Sessions []Session + +// Swap is part of sort.Interface implementation for []Session +func (slice Sessions) Swap(i, j int) { + s := slice[i] + slice[i] = slice[j] + slice[j] = s +} + +// Less is part of sort.Interface implementation for []Session +func (slice Sessions) Less(i, j int) bool { + return slice[i].Created.Before(slice[j].Created) +} + +// Len is part of sort.Interface implementation for []Session +func (slice Sessions) Len() int { + return len(slice) +} + +// GetSession returns the session by its id. Returns NotFound if a session +// is not found +func (s *server) GetSession(ctx context.Context, namespace string, id ID) (*Session, error) { + item, err := s.bk.Get(ctx, activeKey(namespace, string(id))) + if err != nil { + if trace.IsNotFound(err) { + return nil, trace.NotFound("session(%v, %v) is not found", namespace, id) + } + return nil, trace.Wrap(err) + } + var sess Session + if err := json.Unmarshal(item.Value, &sess); err != nil { + return nil, trace.Wrap(err) + } + return &sess, nil +} + +// CreateSession creates a new session if it does not exist, if the session +// exists the function will return AlreadyExists error +// The session will be marked as active for TTL period of time +func (s *server) CreateSession(ctx context.Context, sess Session) error { + if err := sess.ID.Check(); err != nil { + return trace.Wrap(err) + } + if sess.Namespace == "" { + return trace.BadParameter("session namespace can not be empty") + } + if sess.Login == "" { + return trace.BadParameter("session login can not be empty") + } + if sess.Created.IsZero() { + return trace.BadParameter("created can not be empty") + } + if sess.LastActive.IsZero() { + return trace.BadParameter("last_active can not be empty") + } + _, err := NewTerminalParamsFromInt(sess.TerminalParams.W, sess.TerminalParams.H) + if err != nil { + return trace.Wrap(err) + } + sess.Parties = nil + data, err := json.Marshal(sess) + if err != nil { + return trace.Wrap(err) + } + item := backend.Item{ + Key: activeKey(sess.Namespace, string(sess.ID)), + Value: data, + Expires: s.clock.Now().UTC().Add(s.activeSessionTTL), + } + _, err = s.bk.Create(ctx, item) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +const ( + sessionUpdateAttempts = 10 + sessionUpdateRetryPeriod = 20 * time.Millisecond +) + +// UpdateSession updates session parameters - can mark it as inactive and update its terminal parameters +func (s *server) UpdateSession(ctx context.Context, req UpdateRequest) error { + if err := req.Check(); err != nil { + return trace.Wrap(err) + } + + key := activeKey(req.Namespace, string(req.ID)) + + // Try several times, then give up + for i := 0; i < sessionUpdateAttempts; i++ { + item, err := s.bk.Get(ctx, key) + if err != nil { + return trace.Wrap(err) + } + + var session Session + if err := json.Unmarshal(item.Value, &session); err != nil { + return trace.Wrap(err) + } + + if req.TerminalParams != nil { + session.TerminalParams = *req.TerminalParams + } + if req.Parties != nil { + session.Parties = *req.Parties + } + newValue, err := json.Marshal(session) + if err != nil { + return trace.Wrap(err) + } + newItem := backend.Item{ + Key: key, + Value: newValue, + Expires: s.clock.Now().UTC().Add(s.activeSessionTTL), + } + + _, err = s.bk.CompareAndSwap(ctx, *item, newItem) + if err != nil { + if trace.IsCompareFailed(err) || trace.IsConnectionProblem(err) { + s.clock.Sleep(sessionUpdateRetryPeriod) + continue + } + return trace.Wrap(err) + } + return nil + } + return trace.ConnectionProblem(nil, "failed concurrently update the session") +} + +// DeleteSession removes an active session from the backend. +func (s *server) DeleteSession(ctx context.Context, namespace string, id ID) error { + if !types.IsValidNamespace(namespace) { + return trace.BadParameter("invalid namespace %q", namespace) + } + err := id.Check() + if err != nil { + return trace.Wrap(err) + } + + err = s.bk.Delete(ctx, activeKey(namespace, string(id))) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +// discardSessionServer discards all information about sessions given to it. +// DELETE IN 12.0.0 +type discardSessionServer struct { +} + +// NewDiscardSessionServer returns a new discarding session server. It's used +// with the recording proxy so that nodes don't register active sessions to +// the backend. +// DELETE IN 12.0.0 +func NewDiscardSessionServer() Service { + return &discardSessionServer{} +} + +// GetSessions returns an empty list of sessions. +func (d *discardSessionServer) GetSessions(ctx context.Context, namespace string) ([]Session, error) { + return []Session{}, nil +} + +// GetSession always returns a zero session. +func (d *discardSessionServer) GetSession(ctx context.Context, namespace string, id ID) (*Session, error) { + return &Session{}, nil +} + +// CreateSession always returns nil, does nothing. +func (d *discardSessionServer) CreateSession(ctx context.Context, sess Session) error { + return nil +} + +// UpdateSession always returns nil, does nothing. +func (d *discardSessionServer) UpdateSession(ctx context.Context, req UpdateRequest) error { + return nil +} + +// DeleteSession removes an active session from the backend. +func (d *discardSessionServer) DeleteSession(ctx context.Context, namespace string, id ID) error { + return nil +} + // NewTerminalParamsFromUint32 returns new terminal parameters from uint32 width and height func NewTerminalParamsFromUint32(w uint32, h uint32) (*TerminalParams, error) { if w > maxSize || w < minSize { diff --git a/lib/session/session_test.go b/lib/session/session_test.go new file mode 100644 index 0000000000000..83e7b7e4fd45e --- /dev/null +++ b/lib/session/session_test.go @@ -0,0 +1,227 @@ +/* +Copyright 2015 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package session + +import ( + "context" + "os" + "testing" + "time" + + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + "github.com/gravitational/trace" +) + +func TestMain(m *testing.M) { + utils.InitLoggerForTests() + os.Exit(m.Run()) +} + +func TestSessions(t *testing.T) { + s := newsessionSuite(t) + t.Cleanup(func() { s.TearDown(t) }) + + t.Run("TestID", s.TestID) + t.Run("TestSessionsCRUD", s.TestSessionsCRUD) + t.Run("TestSessionsInactivity", s.TestSessionsInactivity) + t.Run("TestPartiesCRUD", s.TestPartiesCRUD) +} + +type sessionSuite struct { + dir string + srv *server + bk backend.Backend + clock clockwork.FakeClock +} + +func newsessionSuite(t *testing.T) *sessionSuite { + var err error + s := &sessionSuite{} + + s.clock = clockwork.NewFakeClockAt(time.Date(2016, 9, 8, 7, 6, 5, 0, time.UTC)) + s.dir = t.TempDir() + s.bk, err = lite.NewWithConfig(context.TODO(), + lite.Config{ + Path: s.dir, + Clock: s.clock, + }, + ) + require.NoError(t, err) + + srv, err := New(s.bk) + require.NoError(t, err) + srv.(*server).clock = s.clock + s.srv = srv.(*server) + return s +} + +func (s *sessionSuite) TearDown(t *testing.T) { + require.NoError(t, s.bk.Close()) +} + +func (s *sessionSuite) TestID(t *testing.T) { + id := NewID() + id2, err := ParseID(id.String()) + require.NoError(t, err) + require.Equal(t, id, *id2) + + for _, val := range []string{"garbage", "", " ", string(id) + "extra"} { + id := ID(val) + require.Error(t, id.Check()) + } +} + +func (s *sessionSuite) TestSessionsCRUD(t *testing.T) { + ctx := context.Background() + out, err := s.srv.GetSessions(ctx, apidefaults.Namespace) + require.NoError(t, err) + require.Empty(t, out) + + // Create session. + sess := Session{ + ID: NewID(), + Namespace: apidefaults.Namespace, + TerminalParams: TerminalParams{W: 100, H: 100}, + Login: "bob", + LastActive: s.clock.Now().UTC(), + Created: s.clock.Now().UTC(), + } + require.NoError(t, s.srv.CreateSession(ctx, sess)) + + // Make sure only one session exists. + out, err = s.srv.GetSessions(ctx, apidefaults.Namespace) + require.NoError(t, err) + require.Equal(t, out, []Session{sess}) + + // Make sure the session is the one created above. + s2, err := s.srv.GetSession(ctx, apidefaults.Namespace, sess.ID) + require.NoError(t, err) + require.Equal(t, s2, &sess) + + // Update session terminal parameter + err = s.srv.UpdateSession(ctx, UpdateRequest{ + ID: sess.ID, + Namespace: apidefaults.Namespace, + TerminalParams: &TerminalParams{W: 101, H: 101}, + }) + require.NoError(t, err) + + // Verify update was applied. + sess.TerminalParams = TerminalParams{W: 101, H: 101} + s2, err = s.srv.GetSession(ctx, apidefaults.Namespace, sess.ID) + require.NoError(t, err) + require.Equal(t, s2, &sess) + + // Remove the session. + err = s.srv.DeleteSession(ctx, apidefaults.Namespace, sess.ID) + require.NoError(t, err) + + // Make sure session no longer exists. + _, err = s.srv.GetSession(ctx, apidefaults.Namespace, sess.ID) + require.Error(t, err) +} + +// TestSessionsInactivity makes sure that session will be marked +// as inactive after period of inactivity +func (s *sessionSuite) TestSessionsInactivity(t *testing.T) { + ctx := context.Background() + sess := Session{ + ID: NewID(), + Namespace: apidefaults.Namespace, + TerminalParams: TerminalParams{W: 100, H: 100}, + Login: "bob", + LastActive: s.clock.Now().UTC(), + Created: s.clock.Now().UTC(), + } + require.NoError(t, s.srv.CreateSession(ctx, sess)) + + // move forward in time: + s.clock.Advance(defaults.ActiveSessionTTL + time.Second) + + // should not be in active sessions: + s2, err := s.srv.GetSession(ctx, apidefaults.Namespace, sess.ID) + require.IsType(t, trace.NotFound(""), err) + require.Nil(t, s2) +} + +func (s *sessionSuite) TestPartiesCRUD(t *testing.T) { + ctx := context.Background() + + // create session: + sess := Session{ + ID: NewID(), + Namespace: apidefaults.Namespace, + TerminalParams: TerminalParams{W: 100, H: 100}, + Login: "vincent", + LastActive: s.clock.Now().UTC(), + Created: s.clock.Now().UTC(), + } + err := s.srv.CreateSession(ctx, sess) + require.NoError(t, err) + // add two people: + parties := []Party{ + { + ID: NewID(), + RemoteAddr: "1_remote_addr", + User: "first", + ServerID: "luna", + LastActive: s.clock.Now().UTC(), + }, + { + ID: NewID(), + RemoteAddr: "2_remote_addr", + User: "second", + ServerID: "luna", + LastActive: s.clock.Now().UTC(), + }, + } + err = s.srv.UpdateSession(ctx, UpdateRequest{ + ID: sess.ID, + Namespace: apidefaults.Namespace, + Parties: &parties, + }) + require.NoError(t, err) + // verify they're in the session: + copy, err := s.srv.GetSession(ctx, apidefaults.Namespace, sess.ID) + require.NoError(t, err) + require.Len(t, copy.Parties, 2) + + // empty update (list of parties must not change) + err = s.srv.UpdateSession(ctx, UpdateRequest{ID: sess.ID, Namespace: apidefaults.Namespace}) + require.NoError(t, err) + copy, _ = s.srv.GetSession(ctx, apidefaults.Namespace, sess.ID) + require.Len(t, copy.Parties, 2) + + // remove the 2nd party: + deleted := copy.RemoveParty(parties[1].ID) + require.True(t, deleted) + err = s.srv.UpdateSession(ctx, UpdateRequest{ID: copy.ID, Parties: ©.Parties, Namespace: apidefaults.Namespace}) + require.NoError(t, err) + copy, _ = s.srv.GetSession(ctx, apidefaults.Namespace, sess.ID) + require.Len(t, copy.Parties, 1) + + // we still have the 1st party in: + require.Equal(t, parties[0].ID, copy.Parties[0].ID) +} diff --git a/lib/srv/mock.go b/lib/srv/mock.go index d8ef398361b23..0e0a5712beb6e 100644 --- a/lib/srv/mock.go +++ b/lib/srv/mock.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/lib/pam" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" @@ -182,6 +183,11 @@ func (m *mockServer) GetAccessPoint() AccessPoint { return m.auth } +// GetSessionServer returns a session server. +func (m *mockServer) GetSessionServer() rsession.Service { + return rsession.NewDiscardSessionServer() +} + // GetDataDir returns data directory of the server func (m *mockServer) GetDataDir() string { return "testDataDir" diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 0ba28ca9f4d1b..482fb95a2c191 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -1651,6 +1651,7 @@ func (s *session) trackSession(ctx context.Context, teleportUser string, policyS } else { s.tracker, err = NewSessionTracker(ctx, trackerSpec, s.registry.SessionTrackerService) } + if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/sessiontracker.go b/lib/srv/sessiontracker.go index 872b0f9ae1137..dd3f1682af46f 100644 --- a/lib/srv/sessiontracker.go +++ b/lib/srv/sessiontracker.go @@ -21,11 +21,12 @@ import ( "sync" "time" + "github.com/jonboulle/clockwork" + "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" - "github.com/jonboulle/clockwork" "github.com/gravitational/trace" ) @@ -113,10 +114,8 @@ func (s *SessionTracker) UpdateExpiration(ctx context.Context, expiry time.Time) }, }, }) - return trace.Wrap(err) } - return nil } @@ -135,7 +134,6 @@ func (s *SessionTracker) AddParticipant(ctx context.Context, p *types.Participan }, }, }) - return trace.Wrap(err) } @@ -157,7 +155,6 @@ func (s *SessionTracker) RemoveParticipant(ctx context.Context, participantID st }, }, }) - return trace.Wrap(err) } @@ -179,7 +176,6 @@ func (s *SessionTracker) UpdateState(ctx context.Context, state types.SessionSta }, }, }) - return trace.Wrap(err) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 3e14d583c25d3..9f0956ea726f8 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -54,6 +54,8 @@ import ( "github.com/gorilla/websocket" "github.com/gravitational/roundtrip" + "github.com/gravitational/trace" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/breaker" authproto "github.com/gravitational/teleport/api/client/proto" @@ -93,7 +95,6 @@ import ( "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/ui" - "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/julienschmidt/httprouter"