Skip to content

Commit

Permalink
prototype agent-forwarding fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fspmarshall committed Apr 21, 2020
1 parent df535f5 commit d74855a
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 49 deletions.
4 changes: 2 additions & 2 deletions lib/multiplexer/multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (s *MuxSuite) TestMultiplexing(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, _ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
called = true
nch.Reject(ssh.Prohibited, "nothing to see here")
})
Expand Down Expand Up @@ -373,7 +373,7 @@ func (s *MuxSuite) TestDisableTLS(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, _ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
called = true
nch.Reject(ssh.Prohibited, "nothing to see here")
})
Expand Down
2 changes: 1 addition & 1 deletion lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func (s *server) Shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}

func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
func (s *server) HandleNewChan(ccx *sshutils.ConnectionContext, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
// Apply read/write timeouts to the server connection.
conn = utils.ObeyIdleTimeout(conn,
s.offlineThreshold,
Expand Down
47 changes: 32 additions & 15 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ type ServerContext struct {

sync.RWMutex

Parent *sshutils.ConnectionContext

// env is a list of environment variables passed to the session.
env map[string]string

Expand All @@ -187,10 +189,10 @@ type ServerContext struct {
term Terminal

// agent is a client to remote SSH agent.
agent agent.Agent
//agent agent.Agent

// agentCh is SSH channel using SSH agent protocol.
agentChannel ssh.Channel
//agentChannel ssh.Channel

// session holds the active session (if there's an active one).
session *session
Expand Down Expand Up @@ -291,7 +293,7 @@ type ServerContext struct {

// NewServerContext creates a new *ServerContext which is used to pass and
// manage resources.
func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext IdentityContext) (*ServerContext, error) {
func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext IdentityContext, parent *sshutils.ConnectionContext) (*ServerContext, error) {
clusterConfig, err := srv.GetAccessPoint().GetClusterConfig()
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -300,6 +302,7 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
cancelContext, cancel := context.WithCancel(context.TODO())

ctx := &ServerContext{
Parent: parent,
id: int(atomic.AddInt32(&ctxID, int32(1))),
env: make(map[string]string),
srv: srv,
Expand Down Expand Up @@ -374,6 +377,9 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
ctx.AddCloser(ctx.contr)
ctx.AddCloser(ctx.contw)

// gather environment variables from parent.
ctx.SyncParentEnv()

return ctx, nil
}

Expand Down Expand Up @@ -449,20 +455,22 @@ func (c *ServerContext) AddCloser(closer io.Closer) {

// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent.
func (c *ServerContext) GetAgent() agent.Agent {
c.RLock()
defer c.RUnlock()
return c.agent
if c.Parent == nil {
return nil
}
return c.Parent.GetAgent()
}

// GetAgentChannel returns the channel over which communication with the agent occurs.
func (c *ServerContext) GetAgentChannel() ssh.Channel {
c.RLock()
defer c.RUnlock()
return c.agentChannel
if c.Parent == nil {
return nil
}
return c.Parent.GetAgentChannel()
}

// SetAgent sets the agent and channel over which communication with the agent occurs.
func (c *ServerContext) SetAgent(a agent.Agent, channel ssh.Channel) {
/*func (c *ServerContext) SetAgent(a agent.Agent, channel ssh.Channel) {
c.Lock()
defer c.Unlock()
if c.agentChannel != nil {
Expand All @@ -471,7 +479,7 @@ func (c *ServerContext) SetAgent(a agent.Agent, channel ssh.Channel) {
}
c.agentChannel = channel
c.agent = a
}
}*/

// GetTerm returns a Terminal.
func (c *ServerContext) GetTerm() Terminal {
Expand Down Expand Up @@ -500,6 +508,15 @@ func (c *ServerContext) GetEnv(key string) (string, bool) {
return val, ok
}

// SyncParentEnv is used to re-synchronize env vars after
// parent context has been updated.
func (c *ServerContext) SyncParentEnv() {
if c.Parent == nil {
return
}
c.Parent.ApplyEnv(c.env)
}

// takeClosers returns all resources that should be closed and sets the properties to null
// we do this to avoid calling Close() under lock to avoid potential deadlocks
func (c *ServerContext) takeClosers() []io.Closer {
Expand All @@ -512,10 +529,10 @@ func (c *ServerContext) takeClosers() []io.Closer {
closers = append(closers, c.term)
c.term = nil
}
if c.agentChannel != nil {
closers = append(closers, c.agentChannel)
c.agentChannel = nil
}
//if c.agentChannel != nil {
// closers = append(closers, c.agentChannel)
// c.agentChannel = nil
//}
closers = append(closers, c.closers...)
c.closers = nil
return closers
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ func (s *Server) handleChannel(nch ssh.NewChannel) {
func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext, nil)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down Expand Up @@ -713,7 +713,7 @@ func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) {
// There is no need for the forwarding server to initiate disconnects,
// based on teleport business logic, because this logic is already
// done on the server's terminating side.
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext, nil)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down
38 changes: 20 additions & 18 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,15 +761,16 @@ func (s *Server) serveAgent(ctx *srv.ServerContext) error {
}

// start an agent on a unix socket
agentServer := &teleagent.AgentServer{Agent: ctx.GetAgent()}
agentServer := &teleagent.AgentServer{Agent: ctx.Parent.GetAgent()}
err = agentServer.ListenUnixSocket(socketPath, uid, gid, 0600)
if err != nil {
return trace.Wrap(err)
}
ctx.SetEnv(teleport.SSHAuthSock, socketPath)
ctx.SetEnv(teleport.SSHAgentPID, fmt.Sprintf("%v", pid))
ctx.AddCloser(agentServer)
ctx.AddCloser(dirCloser)
ctx.Parent.SetEnv(teleport.SSHAuthSock, socketPath)
ctx.Parent.SetEnv(teleport.SSHAgentPID, fmt.Sprintf("%v", pid))
ctx.Parent.AddCloser(agentServer)
ctx.Parent.AddCloser(dirCloser)
ctx.SyncParentEnv()
ctx.Debugf("Opened agent channel for Teleport user %v and socket %v.", ctx.Identity.TeleportUser, socketPath)
go agentServer.Serve()

Expand Down Expand Up @@ -816,7 +817,7 @@ func (s *Server) HandleRequest(r *ssh.Request) {
}

// HandleNewChan is called when new channel is opened
func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
func (s *Server) HandleNewChan(ccx *sshutils.ConnectionContext, wconn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
identityContext, err := s.authHandlers.CreateIdentityContext(sconn)
if err != nil {
nch.Reject(ssh.Prohibited, fmt.Sprintf("Unable to create identity from connection: %v", err))
Expand All @@ -841,7 +842,7 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleProxyJump(wconn, sconn, identityContext, ch, *req)
go s.handleProxyJump(ccx, wconn, sconn, identityContext, ch, *req)
return
// Channels of type "session" handle requests that are involved in running
// commands on a server. In the case of proxy mode subsystem and agent
Expand All @@ -853,7 +854,7 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests)
go s.handleSessionRequests(ccx, wconn, sconn, identityContext, ch, requests)
return
default:
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
Expand All @@ -871,7 +872,7 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleSessionRequests(wconn, sconn, identityContext, ch, requests)
go s.handleSessionRequests(ccx, wconn, sconn, identityContext, ch, requests)
// Channels of type "direct-tcpip" handles request for port forwarding.
case teleport.ChanDirectTCPIP:
req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData())
Expand All @@ -886,17 +887,17 @@ func (s *Server) HandleNewChan(wconn net.Conn, sconn *ssh.ServerConn, nch ssh.Ne
nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err))
return
}
go s.handleDirectTCPIPRequest(wconn, sconn, identityContext, ch, req)
go s.handleDirectTCPIPRequest(ccx, wconn, sconn, identityContext, ch, req)
default:
nch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType))
}
}

// handleDirectTCPIPRequest handles port forwarding requests.
func (s *Server) handleDirectTCPIPRequest(wconn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) {
func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, wconn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, err := srv.NewServerContext(s, sconn, identityContext)
ctx, err := srv.NewServerContext(s, sconn, identityContext, ccx)
if err != nil {
log.Errorf("Unable to create connection context: %v.", err)
channel.Stderr().Write([]byte("Unable to create connection context."))
Expand Down Expand Up @@ -996,10 +997,11 @@ func (s *Server) handleDirectTCPIPRequest(wconn net.Conn, sconn *ssh.ServerConn,
// handleSessionRequests handles out of band session requests once the session
// channel has been created this function's loop handles all the "exec",
// "subsystem" and "shell" requests.
func (s *Server) handleSessionRequests(conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
func (s *Server) handleSessionRequests(ccx *sshutils.ConnectionContext, conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) {
// Create context for this channel. This context will be closed when the
// session request is complete.
ctx, err := srv.NewServerContext(s, sconn, identityContext)
log.Infof("Starting session request handler for %v", conn.RemoteAddr())
ctx, err := srv.NewServerContext(s, sconn, identityContext, ccx)
if err != nil {
log.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down Expand Up @@ -1171,7 +1173,7 @@ func (s *Server) handleAgentForwardNode(req *ssh.Request, ctx *srv.ServerContext
}

// save the agent in the context so it can be used later
ctx.SetAgent(agent.NewClient(authChannel), authChannel)
ctx.Parent.SetAgent(agent.NewClient(authChannel), authChannel)

// serve an agent on a unix socket on this node
err = s.serveAgent(ctx)
Expand Down Expand Up @@ -1209,7 +1211,7 @@ func (s *Server) handleAgentForwardProxy(req *ssh.Request, ctx *srv.ServerContex
// Save the agent so it can be used when making a proxy subsystem request
// later. It will also be used when building a remote connection to the
// target node.
ctx.SetAgent(agent.NewClient(authChannel), authChannel)
ctx.Parent.SetAgent(agent.NewClient(authChannel), authChannel)

return nil
}
Expand Down Expand Up @@ -1304,10 +1306,10 @@ func (s *Server) handleVersionRequest(req *ssh.Request) {
}

// handleProxyJump handles ProxyJump request that is executed via direct tcp-ip dial on the proxy
func (s *Server) handleProxyJump(conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req sshutils.DirectTCPIPReq) {
func (s *Server) handleProxyJump(ccx *sshutils.ConnectionContext, conn net.Conn, sconn *ssh.ServerConn, identityContext srv.IdentityContext, ch ssh.Channel, req sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when the
// session request is complete.
ctx, err := srv.NewServerContext(s, sconn, identityContext)
ctx, err := srv.NewServerContext(s, sconn, identityContext, ccx)
if err != nil {
log.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down
15 changes: 13 additions & 2 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,19 @@ func (s *SrvSuite) TestAgentForward(c *C) {
err = client.Close()
c.Assert(err, IsNil)

// make sure the socket is gone after we closed the session
se.Close()
// make sure the socket persists after the session is closed.
// (agents are started from specific sessions, but apply to all
// sessions on the connection).
err = se.Close()
c.Assert(err, IsNil)
time.Sleep(150 * time.Millisecond)
_, err = net.Dial("unix", socketPath)
c.Assert(err, IsNil)

// make sure the socket is gone after we closed the connection.
err = s.clt.Close()
c.Assert(err, IsNil)
s.clt = nil
for i := 0; i < 4; i++ {
_, err = net.Dial("unix", socketPath)
if err != nil {
Expand Down
Loading

0 comments on commit d74855a

Please sign in to comment.