Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix agent forwarding for multi-session connections #3613

Merged
merged 1 commit into from
Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions integration/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1265,11 +1265,11 @@ func (s *discardServer) Stop() {
s.sshServer.Close()
}

func (s *discardServer) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, newChannel ssh.NewChannel) {
func (s *discardServer) HandleNewChan(ccx *sshutils.ConnectionContext, newChannel ssh.NewChannel) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ctx?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deliberately assumed the convention of referring to sshutils.ConnectionContext as ccx to handle the fact that it is frequently in-scope at the same time as srv.ServerContext, which is already conventionally called ctx.

channel, reqs, err := newChannel.Accept()
if err != nil {
sconn.Close()
conn.Close()
ccx.ServerConn.Close()
ccx.NetConn.Close()
return
}

Expand Down
4 changes: 2 additions & 2 deletions lib/multiplexer/multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (s *MuxSuite) TestMultiplexing(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
nch.Reject(ssh.Prohibited, "nothing to see here")
})
Expand Down Expand Up @@ -381,7 +381,7 @@ func (s *MuxSuite) TestDisableTLS(c *check.C) {
defer backend1.Close()

called := false
sshHandler := sshutils.NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
sshHandler := sshutils.NewChanHandlerFunc(func(_ *sshutils.ConnectionContext, nch ssh.NewChannel) {
called = true
nch.Reject(ssh.Prohibited, "nothing to see here")
})
Expand Down
5 changes: 3 additions & 2 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,11 +527,12 @@ func (s *server) Shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}

func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
func (s *server) HandleNewChan(ccx *sshutils.ConnectionContext, nch ssh.NewChannel) {
// Apply read/write timeouts to the server connection.
conn = utils.ObeyIdleTimeout(conn,
conn := utils.ObeyIdleTimeout(ccx.NetConn,
s.offlineThreshold,
"reverse tunnel server")
sconn := ccx.ServerConn

channelType := nch.ChannelType()
switch channelType {
Expand Down
65 changes: 30 additions & 35 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ type ServerContext struct {

sync.RWMutex

Parent *sshutils.ConnectionContext

// env is a list of environment variables passed to the session.
env map[string]string

Expand All @@ -186,12 +188,6 @@ type ServerContext struct {
// term holds PTY if it was requested by the session.
term Terminal

// agent is a client to remote SSH agent.
agent agent.Agent

// agentCh is SSH channel using SSH agent protocol.
agentChannel ssh.Channel

// session holds the active session (if there's an active one).
session *session

Expand Down Expand Up @@ -291,7 +287,7 @@ type ServerContext struct {

// NewServerContext creates a new *ServerContext which is used to pass and
// manage resources.
func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext IdentityContext) (*ServerContext, error) {
func NewServerContext(ccx *sshutils.ConnectionContext, srv Server, identityContext IdentityContext) (*ServerContext, error) {
clusterConfig, err := srv.GetAccessPoint().GetClusterConfig()
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -300,13 +296,15 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
cancelContext, cancel := context.WithCancel(context.TODO())

ctx := &ServerContext{
Parent: ccx,
id: int(atomic.AddInt32(&ctxID, int32(1))),
env: make(map[string]string),
srv: srv,
Conn: conn,
Connection: ccx.NetConn,
Conn: ccx.ServerConn,
ExecResultCh: make(chan ExecResult, 10),
SubsystemResultCh: make(chan SubsystemResult, 10),
ClusterName: conn.Permissions.Extensions[utils.CertTeleportClusterName],
ClusterName: ccx.ServerConn.Permissions.Extensions[utils.CertTeleportClusterName],
ClusterConfig: clusterConfig,
Identity: identityContext,
clientIdleTimeout: identityContext.RoleSet.AdjustClientIdleTimeout(clusterConfig.GetClientIdleTimeout()),
Expand All @@ -320,8 +318,8 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
}

fields := log.Fields{
"local": conn.LocalAddr(),
"remote": conn.RemoteAddr(),
"local": ctx.Conn.LocalAddr(),
"remote": ctx.Conn.RemoteAddr(),
"login": ctx.Identity.Login,
"teleportUser": ctx.Identity.TeleportUser,
"id": ctx.id,
Expand All @@ -343,7 +341,7 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
ClientIdleTimeout: ctx.clientIdleTimeout,
Clock: ctx.srv.GetClock(),
Tracker: ctx,
Conn: conn,
Conn: ctx.Conn,
Context: cancelContext,
TeleportUser: ctx.Identity.TeleportUser,
Login: ctx.Identity.Login,
Expand Down Expand Up @@ -374,6 +372,9 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity
ctx.AddCloser(ctx.contr)
ctx.AddCloser(ctx.contw)

// gather environment variables from parent.
ctx.ImportParentEnv()

return ctx, nil
}

Expand Down Expand Up @@ -447,30 +448,22 @@ func (c *ServerContext) AddCloser(closer io.Closer) {
c.closers = append(c.closers, closer)
}

// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent.
// GetAgent returns a agent.Agent which represents the capabilities of an SSH agent,
// or nil if no agent is available in this context.
func (c *ServerContext) GetAgent() agent.Agent {
c.RLock()
defer c.RUnlock()
return c.agent
if c.Parent == nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should GetAgent and GetAgentChannel return an error? Because it could potentially segfault here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact that agent might be nil seems to be well handled, so I've opted to update the method docs to indicate that it might be nil rather than changing the method to return an error.

return nil
}
Comment on lines +454 to +456
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't c.Parent always be set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed nil checks for methods which modify state. All the getters on the zero value of srv.ConnectionContext return zero values themselves instead of panicking (a fact that some tests rely on), so I kept the nil checks on the getters in order to preserve this behavior.

return c.Parent.GetAgent()
}

// GetAgentChannel returns the channel over which communication with the agent occurs.
// GetAgentChannel returns the channel over which communication with the agent occurs,
// or nil if no agent is available in this context.
func (c *ServerContext) GetAgentChannel() ssh.Channel {
c.RLock()
defer c.RUnlock()
return c.agentChannel
}

// SetAgent sets the agent and channel over which communication with the agent occurs.
func (c *ServerContext) SetAgent(a agent.Agent, channel ssh.Channel) {
c.Lock()
defer c.Unlock()
if c.agentChannel != nil {
c.Infof("closing previous agent channel")
c.agentChannel.Close()
if c.Parent == nil {
return nil
}
c.agentChannel = channel
c.agent = a
return c.Parent.GetAgentChannel()
}

// GetTerm returns a Terminal.
Expand Down Expand Up @@ -500,6 +493,12 @@ func (c *ServerContext) GetEnv(key string) (string, bool) {
return val, ok
}

// ImportParentEnv is used to re-synchronize env vars after
// parent context has been updated.
func (c *ServerContext) ImportParentEnv() {
c.Parent.ExportEnv(c.env)
}

// takeClosers returns all resources that should be closed and sets the properties to null
// we do this to avoid calling Close() under lock to avoid potential deadlocks
func (c *ServerContext) takeClosers() []io.Closer {
Expand All @@ -512,10 +511,6 @@ func (c *ServerContext) takeClosers() []io.Closer {
closers = append(closers, c.term)
c.term = nil
}
if c.agentChannel != nil {
closers = append(closers, c.agentChannel)
c.agentChannel = nil
}
closers = append(closers, c.closers...)
c.closers = nil
return closers
Expand Down
11 changes: 9 additions & 2 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ type Server struct {
// forwarding, subsystems.
remoteClient *ssh.Client

// connectionContext is used to construct ServerContext instances
// and supports registration of connection-scoped resource closers.
connectionContext *sshutils.ConnectionContext
fspmarshall marked this conversation as resolved.
Show resolved Hide resolved

// identityContext holds identity information about the user that has
// authenticated on sconn (like system login, Teleport username, roles).
identityContext srv.IdentityContext
Expand Down Expand Up @@ -435,6 +439,8 @@ func (s *Server) Serve() {
}
s.sconn = sconn

s.connectionContext = sshutils.NewConnectionContext(s.serverConn, s.sconn)

// Take connection and extract identity information for the user from it.
s.identityContext, err = s.authHandlers.CreateIdentityContext(sconn)
if err != nil {
Expand Down Expand Up @@ -488,6 +494,7 @@ func (s *Server) Close() error {
s.serverConn,
s.targetConn,
s.remoteClient,
s.connectionContext,
}

var errs []error
Expand Down Expand Up @@ -646,7 +653,7 @@ func (s *Server) handleChannel(nch ssh.NewChannel) {
func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
// forwarding is complete.
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down Expand Up @@ -713,7 +720,7 @@ func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) {
// There is no need for the forwarding server to initiate disconnects,
// based on teleport business logic, because this logic is already
// done on the server's terminating side.
ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext)
ctx, err := srv.NewServerContext(s.connectionContext, s, s.identityContext)
if err != nil {
ctx.Errorf("Unable to create connection context: %v.", err)
ch.Stderr().Write([]byte("Unable to create connection context."))
Expand Down
Loading