Skip to content

Commit

Permalink
Make agent channel setup lazy.
Browse files Browse the repository at this point in the history
Changes agent channel setup behavior to be consistent
openssh by having servers lazily request agent channels
when they are needed, rather than immediately starting a
single connection-wide channel as soon as forwarding is
requested.  Fixes an issue introduced in #3613 which
caused openssh clients to hang on exit due to persistent
agent channel.
  • Loading branch information
fspmarshall committed Jun 10, 2020
1 parent 6471a0f commit acde213
Show file tree
Hide file tree
Showing 20 changed files with 553 additions and 477 deletions.
9 changes: 6 additions & 3 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ func (s *discardServer) Stop() {
s.sshServer.Close()
}

func (s *discardServer) HandleNewChan(ccx *sshutils.ConnectionContext, newChannel ssh.NewChannel) {
func (s *discardServer) HandleNewChan(_ context.Context, ccx *sshutils.ConnectionContext, newChannel ssh.NewChannel) {
channel, reqs, err := newChannel.Accept()
if err != nil {
ccx.ServerConn.Close()
Expand Down Expand Up @@ -1400,10 +1400,13 @@ func createAgent(me *user.User, privateKeyByte []byte, certificateBytes []byte)
}

// create a (unstarted) agent and add the key to it
teleAgent := teleagent.NewServer()
if err := teleAgent.Add(agentKey); err != nil {
keyring := agent.NewKeyring()
if err := keyring.Add(agentKey); err != nil {
return nil, "", "", trace.Wrap(err)
}
teleAgent := teleagent.NewServer(func() (teleagent.Agent, error) {
return teleagent.NopCloser(keyring), nil
})

// start the SSH agent
err = teleAgent.ListenUnixSocket(sockPath, uid, gid, 0600)
Expand Down
5 changes: 3 additions & 2 deletions lib/multiplexer/multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package multiplexer

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -80,7 +81,7 @@ func (s *MuxSuite) TestMultiplexing(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
err := nch.Reject(ssh.Prohibited, "nothing to see here")
c.Assert(err, check.IsNil)
Expand Down Expand Up @@ -380,7 +381,7 @@ func (s *MuxSuite) TestDisableTLS(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
err := nch.Reject(ssh.Prohibited, "nothing to see here")
c.Assert(err, check.IsNil)
Expand Down
7 changes: 3 additions & 4 deletions lib/reversetunnel/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ import (
"net"
"time"

"golang.org/x/crypto/ssh/agent"

"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/teleagent"
)

// DialParams is a list of parameters used to Dial to a node within a cluster.
Expand All @@ -35,9 +34,9 @@ type DialParams struct {
// To is the destination address.
To net.Addr

// UserAgent is SSH agent used to connect to the remote host. Used by the
// GetUserAgent gets an SSH agent for use in connecting to the remote host. Used by the
// forwarding proxy.
UserAgent agent.Agent
GetUserAgent teleagent.Getter

// Address is used by the forwarding proxy to generate a host certificate for
// the target node. This is needed because while dialing occurs via IP
Expand Down
16 changes: 12 additions & 4 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ func (s *localSite) Dial(params DialParams) (net.Conn, error) {
return nil, trace.Wrap(err)
}
if clusterConfig.GetSessionRecording() == services.RecordAtProxy {
if params.UserAgent == nil {
return nil, trace.BadParameter("user agent missing")
}
return s.dialWithAgent(params)
}

Expand All @@ -195,18 +192,29 @@ func (s *localSite) DialTCP(params DialParams) (net.Conn, error) {
func (s *localSite) IsClosed() bool { return false }

func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) {
if params.GetUserAgent == nil {
return nil, trace.BadParameter("user agent getter missing")
}
s.log.Debugf("Dialing with an agent from %v to %v.", params.From, params.To)

// request user agent connection
userAgent, err := params.GetUserAgent()
if err != nil {
return nil, trace.Wrap(err)
}

// If server ID matches a node that has self registered itself over the tunnel,
// return a connection to that node. Otherwise net.Dial to the target host.
targetConn, useTunnel, err := s.getConn(params)
if err != nil {
userAgent.Close()
return nil, trace.Wrap(err)
}

// Get a host certificate for the forwarding node from the cache.
hostCertificate, err := s.certificateCache.GetHostCertificate(params.Address, params.Principals)
if err != nil {
userAgent.Close()
return nil, trace.Wrap(err)
}

Expand All @@ -215,7 +223,7 @@ func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) {
// once conn is closed.
serverConfig := forward.ServerConfig{
AuthClient: s.client,
UserAgent: params.UserAgent,
UserAgent: userAgent,
TargetConn: targetConn,
SrcAddr: params.From,
DstAddr: params.To,
Expand Down
16 changes: 12 additions & 4 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,6 @@ func (s *remoteSite) Dial(params DialParams) (net.Conn, error) {
// If the proxy is in recording mode use the agent to dial and build a
// in-memory forwarding server.
if clusterConfig.GetSessionRecording() == services.RecordAtProxy {
if params.UserAgent == nil {
return nil, trace.BadParameter("user agent missing")
}
return s.dialWithAgent(params)
}
return s.DialTCP(params)
Expand All @@ -532,11 +529,21 @@ func (s *remoteSite) DialTCP(params DialParams) (net.Conn, error) {
}

func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) {
if params.GetUserAgent == nil {
return nil, trace.BadParameter("user agent getter missing")
}
s.Debugf("Dialing with an agent from %v to %v.", params.From, params.To)

// request user agent connection
userAgent, err := params.GetUserAgent()
if err != nil {
return nil, trace.Wrap(err)
}

// Get a host certificate for the forwarding node from the cache.
hostCertificate, err := s.certificateCache.GetHostCertificate(params.Address, params.Principals)
if err != nil {
userAgent.Close()
return nil, trace.Wrap(err)
}

Expand All @@ -545,6 +552,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) {
ServerID: params.ServerID,
})
if err != nil {
userAgent.Close()
return nil, trace.Wrap(err)
}

Expand All @@ -556,7 +564,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) {
// session gets recorded in the local cluster instead of the remote cluster.
serverConfig := forward.ServerConfig{
AuthClient: s.localClient,
UserAgent: params.UserAgent,
UserAgent: userAgent,
TargetConn: targetConn,
SrcAddr: params.From,
DstAddr: params.To,
Expand Down
2 changes: 1 addition & 1 deletion lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func (s *server) Shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}

func (s *server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) {
func (s *server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) {
// Apply read/write timeouts to the server connection.
conn := utils.ObeyIdleTimeout(ccx.NetConn,
s.offlineThreshold,
Expand Down
23 changes: 17 additions & 6 deletions lib/reversetunnel/track/tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,15 @@ func (s *simpleTestProxies) GetRandProxy() (p testProxy, ok bool) {
}

func (s *simpleTestProxies) Discover(tracker *Tracker, lease Lease) (ok bool) {
defer lease.Release()
proxy, ok := s.GetRandProxy()
if !ok {
panic("discovery called with no available proxies")
}
return s.ProxyLoop(tracker, lease, proxy)
}

func (s *simpleTestProxies) ProxyLoop(tracker *Tracker, lease Lease, proxy testProxy) (ok bool) {
defer lease.Release()
timeout := time.After(proxy.life)
ok = tracker.WithProxy(func() {
ticker := time.NewTicker(jitter(time.Millisecond * 100))
Expand Down Expand Up @@ -165,7 +169,7 @@ Discover:
break Discover
}
case <-timeoutC:
panic("timeout")
c.Fatal("timeout")
}
}
}
Expand Down Expand Up @@ -193,15 +197,22 @@ Loop0:
select {
case lease := <-tracker.Acquire():
c.Assert(lease.Key().(Key), check.DeepEquals, key)
go proxies.Discover(tracker, lease)
// get our "discovered" proxy in the foreground
// to prevent race with the call to RemoveRandProxies
// that comes after this loop.
proxy, ok := proxies.GetRandProxy()
if !ok {
c.Fatal("failed to get test proxy")
}
go proxies.ProxyLoop(tracker, lease, proxy)
case <-ticker.C:
counts := tracker.wp.Get(key)
c.Logf("Counts0: %+v", counts)
if counts.Active == proxyCount {
break Loop0
}
case <-timeoutC:
panic("timeout")
c.Fatal("timeout")
}
}
proxies.RemoveRandProxies(proxyCount)
Expand All @@ -215,7 +226,7 @@ Loop1:
break Loop1
}
case <-timeoutC:
panic("timeout")
c.Fatal("timeout")
}
}
proxies.AddRandProxies(proxyCount, minConnB, maxConnB)
Expand All @@ -231,7 +242,7 @@ Loop2:
break Loop2
}
case <-timeoutC:
panic("timeout")
c.Fatal("timeout")
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/authhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ func (h *AuthHandlers) CheckPortForward(addr string, ctx *ServerContext) error {
events.PortForwardErr: systemErrorMessage,
events.EventLogin: ctx.Identity.Login,
events.EventUser: ctx.Identity.TeleportUser,
events.LocalAddr: ctx.Conn.LocalAddr().String(),
events.RemoteAddr: ctx.Conn.RemoteAddr().String(),
events.LocalAddr: ctx.ServerConn.LocalAddr().String(),
events.RemoteAddr: ctx.ServerConn.RemoteAddr().String(),
}); err != nil {
h.Warnf("Failed to emit port forward deny audit event: %v", err)
}
Expand Down
Loading

0 comments on commit acde213

Please sign in to comment.