diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index 4359c6bdd88f8..b7a157a43afd2 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -214,7 +214,11 @@ type Server struct { proxyPort string - cache *sessionChunkCache + // cache holds sessionChunk objects for in-flight app sessions. + cache *utils.FnCache + // cacheCloseWg prevents closing the app server until all app + // sessions have been removed from the cache and closed. + cacheCloseWg sync.WaitGroup awsHandler http.Handler azureHandler http.Handler @@ -336,11 +340,19 @@ func New(ctx context.Context, c *Config) (*Server, error) { // Create a new session cache, this holds sessions that can be used to // forward requests. - s.cache, err = s.newSessionChunkCache() + s.cache, err = utils.NewFnCache(utils.FnCacheConfig{ + TTL: 5 * time.Minute, + Context: s.closeContext, + Clock: s.c.Clock, + CleanupInterval: time.Second, + OnExpiry: s.onSessionExpired, + }) if err != nil { return nil, trace.Wrap(err) } + go s.expireSessions() + // Figure out the port the proxy is running on. s.proxyPort = s.getProxyPort() @@ -348,6 +360,20 @@ func New(ctx context.Context, c *Config) (*Server, error) { return s, nil } +func (s *Server) expireSessions() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + s.cache.RemoveExpired() + case <-s.closeContext.Done(): + return + } + } +} + // startApp registers the specified application. func (s *Server) startApp(ctx context.Context, app types.Application) error { // Start a goroutine that will be updating apps's command labels (if any) @@ -618,10 +644,12 @@ func (s *Server) close(ctx context.Context) error { errs = append(errs, err) } - // Close the session cache and its remaining sessions. Sessions - // use server.closeContext to complete cleanup, so we must wait - // for sessions to finish closing before closing the context. - s.cache.closeAllSessions() + // Close the session cache and its remaining sessions. + s.cache.Shutdown(s.closeContext) + // Any sessions still in the cache during shutdown are closed in + // background goroutines. We must wait for sessions to finish closing + // before proceeding any further. + s.cacheCloseWg.Wait() // Signal to any blocking go routine that it should exit. s.closeFunc() @@ -868,10 +896,22 @@ func (s *Server) serveAWSWebConsole(w http.ResponseWriter, r *http.Request, iden func (s *Server) serveSession(w http.ResponseWriter, r *http.Request, identity *tlsca.Identity, app types.Application, opts ...sessionOpt) error { // Fetch a cached request forwarder (or create one) that lives about 5 // minutes. Used to stream session chunks to the Audit Log. - session, err := s.getSession(r.Context(), identity, app, opts...) + expiry := identity.Expires.Sub(s.c.Clock.Now()) + ttl := 5 * time.Minute + if expiry < ttl { + ttl = expiry + } + session, err := utils.FnCacheGetWithTTL(r.Context(), s.cache, identity.RouteToApp.SessionID, ttl, func(ctx context.Context) (*sessionChunk, error) { + session, err := s.newSessionChunk(ctx, identity, app, opts...) + return session, trace.Wrap(err) + }) if err != nil { return trace.Wrap(err) } + + if err := session.acquire(); err != nil { + return trace.Wrap(err) + } defer session.release() // Create session context. @@ -981,27 +1021,6 @@ func (s *Server) authorizeContext(ctx context.Context) (*authz.Context, types.Ap return authContext, app, nil } -// getSession returns a request session used to proxy the request to the -// target application. Always checks if the session is valid first and if so, -// will return a cached session, otherwise will create one. -// The in-flight request count is automatically incremented on the session. -// The caller must call session.release() after finishing its use -func (s *Server) getSession(ctx context.Context, identity *tlsca.Identity, app types.Application, opts ...sessionOpt) (*sessionChunk, error) { - session, err := s.cache.get(identity.RouteToApp.SessionID) - // If a cached forwarder exists, return it right away. - if err == nil && session.acquire() == nil { - return session, nil - } - - // Create a new session with a recorder and forwarder in it. - session, err = s.newSessionChunk(ctx, identity, app, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - - return session, nil -} - // getApp returns an application matching the public address. If multiple // matching applications exist, the first one is returned. Random selection // (or round robin) does not need to occur here because they will all point diff --git a/lib/srv/app/session.go b/lib/srv/app/session.go index 9ddf393fe90b4..263d8cb71e138 100644 --- a/lib/srv/app/session.go +++ b/lib/srv/app/session.go @@ -27,14 +27,12 @@ import ( "github.com/google/uuid" "github.com/gravitational/oxy/forward" "github.com/gravitational/trace" - "github.com/gravitational/ttlmap" "github.com/sirupsen/logrus" "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/wrappers" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/filesessions" "github.com/gravitational/teleport/lib/services" @@ -99,12 +97,11 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity, id: uuid.New().String(), closeC: make(chan struct{}), inflightCond: sync.NewCond(&sync.Mutex{}), - inflight: 1, closeTimeout: sessionChunkCloseTimeout, log: s.log, } - sess.log.Debugf("Created app session chunk %s", sess.id) + sess.log.Debugf("Creating app session chunk %s", sess.id) // Create a session tracker so that other services, such as the // session upload completer, can track the session chunk's lifetime. @@ -133,17 +130,12 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity, } } - // Put the session chunk in the cache so that upcoming requests can use it for - // 5 minutes or the time until the certificate expires, whichever comes first. - ttl := utils.MinTTL(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Minute) - if err = s.cache.set(identity.RouteToApp.SessionID, sess, ttl); err != nil { - return nil, trace.Wrap(err) - } - - // only emit a session chunk if we didnt get an error making the new session chunk + // only emit a session chunk if we didn't get an error making the new session chunk if err := sess.audit.OnSessionChunk(ctx, s.c.HostID, sess.id, identity, app); err != nil { return nil, trace.Wrap(err) } + + sess.log.Debugf("Created app session chunk %s", sess.id) return sess, nil } @@ -281,10 +273,22 @@ func (s *sessionChunk) close(ctx context.Context) error { return trace.Wrap(s.streamCloser.Close(ctx)) } -func (s *Server) closeSession(sess *sessionChunk) { - if err := sess.close(s.closeContext); err != nil { - s.log.WithError(err).Debugf("Error closing session %v", sess.id) +func (s *Server) onSessionExpired(ctx context.Context, key, expired any) { + sess, ok := expired.(*sessionChunk) + if !ok { + return } + + // Closing the session stream writer may trigger a flush operation which could + // be time-consuming. Launch in another goroutine to prevent interfering with + // cache operations. + s.cacheCloseWg.Add(1) + go func() { + defer s.cacheCloseWg.Done() + if err := sess.close(ctx); err != nil { + s.log.WithError(err).Debugf("Error closing session %v", sess.id) + } + }() } // newStreamWriter creates a session stream that will be used to record @@ -381,104 +385,3 @@ func (s *Server) createTracker(sess *sessionChunk, identity *tlsca.Identity, app return nil } - -// sessionChunkCache holds a cache of session chunks. -type sessionChunkCache struct { - srv *Server - - mu sync.Mutex - cache *ttlmap.TTLMap -} - -// newSessionChunkCache creates a new session chunk cache. -func (s *Server) newSessionChunkCache() (*sessionChunkCache, error) { - sessionCache := &sessionChunkCache{srv: s} - - // Cache of session chunks. Set an expire function that can be used - // to close and upload the stream of events to the Audit Log. - var err error - sessionCache.cache, err = ttlmap.New(defaults.ClientCacheSize, ttlmap.CallOnExpire(sessionCache.expire), ttlmap.Clock(s.c.Clock)) - if err != nil { - return nil, trace.Wrap(err) - } - - go sessionCache.expireSessions() - - return sessionCache, nil -} - -// get will fetch the session chunk from the cache. -func (s *sessionChunkCache) get(key string) (*sessionChunk, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if f, ok := s.cache.Get(key); ok { - if fwd, fok := f.(*sessionChunk); fok { - return fwd, nil - } - return nil, trace.BadParameter("invalid type stored in cache: %T", f) - } - return nil, trace.NotFound("session not found") -} - -// set will add the session chunk to the cache. -func (s *sessionChunkCache) set(sessionID string, sess *sessionChunk, ttl time.Duration) error { - s.mu.Lock() - defer s.mu.Unlock() - - if err := s.cache.Set(sessionID, sess, ttl); err != nil { - return trace.Wrap(err) - } - return nil -} - -// expire will close the stream writer. -func (s *sessionChunkCache) expire(key string, el interface{}) { - // Closing the session stream writer may trigger a flush operation which could - // be time-consuming. Launch in another goroutine since this occurs under a - // lock and expire can get called during a "get" operation on the ttlmap. - go s.closeSession(el) - s.srv.log.Debugf("Closing expired stream %v.", key) -} - -func (s *sessionChunkCache) closeSession(el interface{}) { - switch sess := el.(type) { - case *sessionChunk: - s.srv.closeSession(sess) - default: - s.srv.log.Debugf("Invalid type stored in cache: %T.", el) - } -} - -// expireSessions ticks every second trying to close expired sessions. -func (s *sessionChunkCache) expireSessions() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - s.expiredSessions() - case <-s.srv.closeContext.Done(): - return - } - } -} - -// expiredSession tries to expire sessions in the cache. -func (s *sessionChunkCache) expiredSessions() { - s.mu.Lock() - defer s.mu.Unlock() - - s.cache.RemoveExpired(10) -} - -// closeAllSessions will remove and close all sessions in the cache. -func (s *sessionChunkCache) closeAllSessions() { - s.mu.Lock() - defer s.mu.Unlock() - - for _, session, ok := s.cache.Pop(); ok; _, session, ok = s.cache.Pop() { - s.closeSession(session) - } -}