Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,24 @@ func init() {
metrics.RegisterPrometheusCollectors(proxiedSessions, failedConnectingToNode, connectingToNode)
}

// proxiedMetricConn wraps [net.Conn] opened by
// ProxiedMetricConn wraps [net.Conn] opened by
// the [Router] so that the proxiedSessions counter
// can be decremented when it is closed.
type proxiedMetricConn struct {
type ProxiedMetricConn struct {
// once ensures that proxiedSessions is only decremented
// a single time per [net.Conn]
once sync.Once
net.Conn
}

// newProxiedMetricConn increments proxiedSessions and creates
// a proxiedMetricConn that defers to the provided [net.Conn].
func newProxiedMetricConn(conn net.Conn) *proxiedMetricConn {
// NewProxiedMetricConn increments proxiedSessions and creates
// a ProxiedMetricConn that defers to the provided [net.Conn].
func NewProxiedMetricConn(conn net.Conn) *ProxiedMetricConn {
proxiedSessions.Inc()
return &proxiedMetricConn{Conn: conn}
return &ProxiedMetricConn{Conn: conn}
}

func (c *proxiedMetricConn) Close() error {
func (c *ProxiedMetricConn) Close() error {
c.once.Do(proxiedSessions.Dec)
return trace.Wrap(c.Conn.Close())
}
Expand Down Expand Up @@ -313,7 +313,7 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.
return nil, trace.Wrap(err)
}

return newProxiedMetricConn(conn), trace.Wrap(err)
return NewProxiedMetricConn(conn), trace.Wrap(err)
}

// getRemoteCluster looks up the provided clusterName to determine if a remote site exists with
Expand Down Expand Up @@ -475,7 +475,7 @@ func (r *Router) DialSite(ctx context.Context, clusterName string, clientSrcAddr
return nil, trace.Wrap(err)
}

return newProxiedMetricConn(conn), trace.Wrap(err)
return NewProxiedMetricConn(conn), trace.Wrap(err)
}

// GetSiteClient returns an auth client for the provided cluster.
Expand Down
13 changes: 12 additions & 1 deletion lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ func (p *AgentPool) getVersion(ctx context.Context) (string, error) {

// transport creates a new transport instance.
func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request, conn sshutils.Conn) *transport {
return &transport{
t := &transport{
closeContext: ctx,
component: p.Component,
localClusterName: p.LocalCluster,
Expand All @@ -531,6 +531,17 @@ func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests
proxySigner: p.PROXYSigner,
forwardClientAddress: true,
}

// If the AgentPool is being used for Proxy to Proxy communication between two clusters, then
// we check if the reverse tunnel server is capable of tracking user connections. This allows
// the leaf proxy to track sessions that are initiated via the root cluster. Without providing
// the user tracker the leaf cluster metrics will be incorrect and graceful shutdown will not
// wait for user sessions to be terminated prior to proceeding with the shutdown operation.
if p.IsRemoteCluster && p.ReverseTunnelServer != nil {
t.trackUserConnection = p.ReverseTunnelServer.TrackUserConnection
}

return t
}

// agentPoolRuntimeConfig contains configurations dynamically set and updated
Expand Down
7 changes: 7 additions & 0 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,13 @@ func (s *server) rejectRequest(ch ssh.NewChannel, reason ssh.RejectionReason, ms
}
}

// TrackUserConnection tracks a user connection that should prevent
// the server from being terminated if active. The returned function
// should be called when the connection is terminated.
func (s *server) TrackUserConnection() (release func()) {
return s.srv.TrackUserConnection()
}

// newRemoteSite helper creates and initializes 'remoteSite' instance
func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, error) {
connInfo, err := types.NewTunnelConnection(
Expand Down
13 changes: 13 additions & 0 deletions lib/reversetunnel/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/proxy"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/utils"
)
Expand Down Expand Up @@ -93,6 +94,9 @@ type transport struct {
// preventing users connecting to the proxy tunnel listener spoofing their address; but we are still able to
// correctly propagate client address in reverse tunnel agents of nodes/services.
forwardClientAddress bool

// trackUserConnection is an optional mechanism used to count active user sessions.
trackUserConnection func() (release func())
}

// start will start the transporting data over the tunnel. This function will
Expand Down Expand Up @@ -246,6 +250,10 @@ func (p *transport) start() {
// tunnel from the SSH node by dreq.ServerID. We'll need to forward
// dreq.Address as well.
directAddress = dreq.Address

if p.trackUserConnection != nil {
defer p.trackUserConnection()()
}
default:
// Not a special address; could be empty.
directAddress = dreq.Address
Expand Down Expand Up @@ -395,6 +403,11 @@ func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, e
}

p.log.Debugf("Returning connection dialed through tunnel with server ID %v.", r.ServerID)

if r.ConnType == types.NodeTunnel {
return proxy.NewProxiedMetricConn(conn), true, nil
}

return conn, true, nil
}

Expand Down
4 changes: 4 additions & 0 deletions lib/reversetunnelclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ type Server interface {
Wait(ctx context.Context)
// GetProxyPeerClient returns the proxy peer client
GetProxyPeerClient() *peer.Client
// TrackUserConnection tracks a user connection that should prevent
// the server from being terminated if active. The returned function
// should be called when the connection is terminated.
TrackUserConnection() (release func())
}

const (
Expand Down
3 changes: 2 additions & 1 deletion lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4657,12 +4657,12 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
// really guaranteed to be capable to serve new requests if we're
// halfway through a shutdown, and double closing a listener is fine.
listeners.Close()
rcWatcher.Close()
if payload == nil {
log.Infof("Shutting down immediately.")
if tsrv != nil {
warnOnErr(tsrv.Close(), log)
}
warnOnErr(rcWatcher.Close(), log)
if proxyServer != nil {
warnOnErr(proxyServer.Close(), log)
}
Expand Down Expand Up @@ -4709,6 +4709,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
if tsrv != nil {
warnOnErr(tsrv.Shutdown(ctx), log)
}
warnOnErr(rcWatcher.Close(), log)
if proxyServer != nil {
warnOnErr(proxyServer.Shutdown(), log)
}
Expand Down
11 changes: 11 additions & 0 deletions lib/sshutils/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,17 @@ func (s *Server) trackUserConnections(delta int32) int32 {
return atomic.AddInt32(&s.userConns, delta)
}

// TrackUserConnection tracks a user connection that should prevent
// the server from being terminated if active. The returned function
// should be called when the connection is terminated.
func (s *Server) TrackUserConnection() (release func()) {
s.trackUserConnections(1)

return sync.OnceFunc(func() {
s.trackUserConnections(-1)
})
}

// ActiveConnections returns the number of connections that are
// being served.
func (s *Server) ActiveConnections() int32 {
Expand Down