Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 43 additions & 28 deletions lib/srv/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,11 @@ type Server struct {

proxyPort string

cache *sessionChunkCache
// cache holds sessionChunk objects for in-flight app sessions.
cache *utils.FnCache
// cacheCloseWg prevents closing the app server until all app
// sessions have been removed from the cache and closed.
cacheCloseWg sync.WaitGroup

awsHandler http.Handler
azureHandler http.Handler
Expand Down Expand Up @@ -338,18 +342,40 @@ func New(ctx context.Context, c *Config) (*Server, error) {

// Create a new session cache, this holds sessions that can be used to
// forward requests.
s.cache, err = s.newSessionChunkCache()
s.cache, err = utils.NewFnCache(utils.FnCacheConfig{
TTL: 5 * time.Minute,
Context: s.closeContext,
Clock: s.c.Clock,
CleanupInterval: time.Second,
OnExpiry: s.onSessionExpired,
})
if err != nil {
return nil, trace.Wrap(err)
}

go s.expireSessions()

// Figure out the port the proxy is running on.
s.proxyPort = s.getProxyPort()

callClose = false
return s, nil
}

func (s *Server) expireSessions() {
ticker := time.NewTicker(time.Second)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why one second? Seems pretty frequent.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defer ticker.Stop()

for {
select {
case <-ticker.C:
s.cache.RemoveExpired()
case <-s.closeContext.Done():
return
}
}
}

// startApp registers the specified application.
func (s *Server) startApp(ctx context.Context, app types.Application) error {
// Start a goroutine that will be updating apps's command labels (if any)
Expand Down Expand Up @@ -620,10 +646,12 @@ func (s *Server) close(ctx context.Context) error {
errs = append(errs, err)
}

// Close the session cache and its remaining sessions. Sessions
// use server.closeContext to complete cleanup, so we must wait
// for sessions to finish closing before closing the context.
s.cache.closeAllSessions()
// Close the session cache and its remaining sessions.
s.cache.Shutdown(s.closeContext)
// Any sessions still in the cache during shutdown are closed in
// background goroutines. We must wait for sessions to finish closing
// before proceeding any further.
s.cacheCloseWg.Wait()

// Signal to any blocking go routine that it should exit.
s.closeFunc()
Expand Down Expand Up @@ -879,10 +907,18 @@ func (s *Server) serveAWSWebConsole(w http.ResponseWriter, r *http.Request, iden
func (s *Server) serveSession(w http.ResponseWriter, r *http.Request, identity *tlsca.Identity, app types.Application, opts ...sessionOpt) error {
// Fetch a cached request forwarder (or create one) that lives about 5
// minutes. Used to stream session chunks to the Audit Log.
session, err := s.getSession(r.Context(), identity, app, opts...)
ttl := min(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Minute)
session, err := utils.FnCacheGetWithTTL(r.Context(), s.cache, identity.RouteToApp.SessionID, ttl, func(ctx context.Context) (*sessionChunk, error) {
session, err := s.newSessionChunk(ctx, identity, app, opts...)
return session, trace.Wrap(err)
})
if err != nil {
return trace.Wrap(err)
}

if err := session.acquire(); err != nil {
return trace.Wrap(err)
}
defer session.release()

// Create session context.
Expand Down Expand Up @@ -995,27 +1031,6 @@ func (s *Server) authorizeContext(ctx context.Context) (*authz.Context, types.Ap
return authContext, app, nil
}

// getSession returns a request session used to proxy the request to the
// target application. Always checks if the session is valid first and if so,
// will return a cached session, otherwise will create one.
// The in-flight request count is automatically incremented on the session.
// The caller must call session.release() after finishing its use
func (s *Server) getSession(ctx context.Context, identity *tlsca.Identity, app types.Application, opts ...sessionOpt) (*sessionChunk, error) {
session, err := s.cache.get(identity.RouteToApp.SessionID)
// If a cached forwarder exists, return it right away.
if err == nil && session.acquire() == nil {
return session, nil
}

// Create a new session with a recorder and forwarder in it.
session, err = s.newSessionChunk(ctx, identity, app, opts...)
if err != nil {
return nil, trace.Wrap(err)
}

return session, nil
}

// getApp returns an application matching the public address. If multiple
// matching applications exist, the first one is returned. Random selection
// (or round robin) does not need to occur here because they will all point
Expand Down
135 changes: 19 additions & 116 deletions lib/srv/app/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,12 @@ import (

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/gravitational/ttlmap"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/recorder"
"github.com/gravitational/teleport/lib/httplib/reverseproxy"
Expand Down Expand Up @@ -99,12 +97,11 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity,
id: uuid.New().String(),
closeC: make(chan struct{}),
inflightCond: sync.NewCond(&sync.Mutex{}),
inflight: 1,
closeTimeout: sessionChunkCloseTimeout,
log: s.log,
}

sess.log.Debugf("Created app session chunk %s", sess.id)
sess.log.Debugf("Creating app session chunk %s", sess.id)

// Create a session tracker so that other services, such as the
// session upload completer, can track the session chunk's lifetime.
Expand Down Expand Up @@ -136,17 +133,12 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity,
}
}

// Put the session chunk in the cache so that upcoming requests can use it for
// 5 minutes or the time until the certificate expires, whichever comes first.
ttl := min(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Minute)
if err = s.cache.set(identity.RouteToApp.SessionID, sess, ttl); err != nil {
return nil, trace.Wrap(err)
}

// only emit a session chunk if we didnt get an error making the new session chunk
// only emit a session chunk if we didn't get an error making the new session chunk
if err := sess.audit.OnSessionChunk(ctx, s.c.HostID, sess.id, identity, app); err != nil {
return nil, trace.Wrap(err)
}

sess.log.Debugf("Created app session chunk %s", sess.id)
return sess, nil
}

Expand Down Expand Up @@ -282,10 +274,22 @@ func (s *sessionChunk) close(ctx context.Context) error {
return trace.Wrap(s.streamCloser.Close(ctx))
}

func (s *Server) closeSession(sess *sessionChunk) {
if err := sess.close(s.closeContext); err != nil {
s.log.WithError(err).Debugf("Error closing session %v", sess.id)
func (s *Server) onSessionExpired(ctx context.Context, key, expired any) {
sess, ok := expired.(*sessionChunk)
if !ok {
return
}

// Closing the session stream writer may trigger a flush operation which could
// be time-consuming. Launch in another goroutine to prevent interfering with
// cache operations.
s.cacheCloseWg.Add(1)
go func() {
defer s.cacheCloseWg.Done()
if err := sess.close(ctx); err != nil {
s.log.WithError(err).Debugf("Error closing session %v", sess.id)
}
}()
}

// newSessionRecorder creates a session stream that will be used to record
Expand Down Expand Up @@ -355,104 +359,3 @@ func (s *Server) createTracker(sess *sessionChunk, identity *tlsca.Identity, app

return nil
}

// sessionChunkCache holds a cache of session chunks.
type sessionChunkCache struct {
srv *Server

mu sync.Mutex
cache *ttlmap.TTLMap
}

// newSessionChunkCache creates a new session chunk cache.
func (s *Server) newSessionChunkCache() (*sessionChunkCache, error) {
sessionCache := &sessionChunkCache{srv: s}

// Cache of session chunks. Set an expire function that can be used
// to close and upload the stream of events to the Audit Log.
var err error
sessionCache.cache, err = ttlmap.New(defaults.ClientCacheSize, ttlmap.CallOnExpire(sessionCache.expire), ttlmap.Clock(s.c.Clock))
if err != nil {
return nil, trace.Wrap(err)
}

go sessionCache.expireSessions()

return sessionCache, nil
}

// get will fetch the session chunk from the cache.
func (s *sessionChunkCache) get(key string) (*sessionChunk, error) {
s.mu.Lock()
defer s.mu.Unlock()

if f, ok := s.cache.Get(key); ok {
if fwd, fok := f.(*sessionChunk); fok {
return fwd, nil
}
return nil, trace.BadParameter("invalid type stored in cache: %T", f)
}
return nil, trace.NotFound("session not found")
}

// set will add the session chunk to the cache.
func (s *sessionChunkCache) set(sessionID string, sess *sessionChunk, ttl time.Duration) error {
s.mu.Lock()
defer s.mu.Unlock()

if err := s.cache.Set(sessionID, sess, ttl); err != nil {
return trace.Wrap(err)
}
return nil
}

// expire will close the stream writer.
func (s *sessionChunkCache) expire(key string, el interface{}) {
// Closing the session stream writer may trigger a flush operation which could
// be time-consuming. Launch in another goroutine since this occurs under a
// lock and expire can get called during a "get" operation on the ttlmap.
go s.closeSession(el)
s.srv.log.Debugf("Closing expired stream %v.", key)
}

func (s *sessionChunkCache) closeSession(el interface{}) {
switch sess := el.(type) {
case *sessionChunk:
s.srv.closeSession(sess)
default:
s.srv.log.Debugf("Invalid type stored in cache: %T.", el)
}
}

// expireSessions ticks every second trying to close expired sessions.
func (s *sessionChunkCache) expireSessions() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()

for {
select {
case <-ticker.C:
s.expiredSessions()
case <-s.srv.closeContext.Done():
return
}
}
}

// expiredSession tries to expire sessions in the cache.
func (s *sessionChunkCache) expiredSessions() {
s.mu.Lock()
defer s.mu.Unlock()

s.cache.RemoveExpired(10)
}

// closeAllSessions will remove and close all sessions in the cache.
func (s *sessionChunkCache) closeAllSessions() {
s.mu.Lock()
defer s.mu.Unlock()

for _, session, ok := s.cache.Pop(); ok; _, session, ok = s.cache.Pop() {
s.closeSession(session)
}
}
Loading