Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ func TestNewSession(t *testing.T) {
ses, err = newSession(ctx, nc, nil, env, nil, nil, nil, true)
require.NoError(t, err)
require.NotNil(t, ses)
// the session ID must be unset from tne environ map, if we are not joining a session:
require.Empty(t, ses.id)
}

// TestProxyConnection verifies that client or server-side disconnect
Expand Down
7 changes: 1 addition & 6 deletions lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ const (
)

type NodeSession struct {
// id is the Teleport session ID
id session.ID

// env is the environment variables that need to be created
// on the server for this session
env map[string]string
Expand Down Expand Up @@ -141,8 +138,6 @@ func newSession(ctx context.Context,
return nil, trace.Wrap(err)
}

ns.id = session.ID(sessionID)

if ns.terminal.IsAttached() {
err = ns.terminal.Resize(int16(terminalSize.Width), int16(terminalSize.Height))
if err != nil {
Expand All @@ -151,7 +146,7 @@ func newSession(ctx context.Context,

}

ns.env[sshutils.SessionEnvVar] = string(ns.id)
ns.env[sshutils.SessionEnvVar] = sessionID
}

// Close the Terminal when finished.
Expand Down
79 changes: 40 additions & 39 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ import (
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/x11"
Expand Down Expand Up @@ -587,51 +586,48 @@ func (c *ServerContext) ID() int {
return c.id
}

// SessionID returns the ID of the session in the context.
func (c *ServerContext) SessionID() rsession.ID {
return c.ConnectionContext.GetSessionID()
}

// GetServer returns the underlying server which this context was created in.
func (c *ServerContext) GetServer() Server {
return c.srv
}
// GetJoinParams gets join params if they are set.
//
// These params (env vars) are set synchronously between the "session" channel request
// and the "shell" / "exec" channel request. Therefore, these params are only guaranteed
// to be accurately set during and after the "shell" / "exec" channel request.
//
// TODO(Joerger): Rather than relying on the out-of-band env var params, we should
// provide session params upfront as extra data in the session channel request.
func (c *ServerContext) GetJoinParams() (string, types.SessionParticipantMode) {
c.mu.RLock()
defer c.mu.RUnlock()

// CreateOrJoinSession will look in the SessionRegistry for the session ID. If
// no session is found, a new one is created. If one is found, it is returned.
func (c *ServerContext) CreateOrJoinSession(ctx context.Context, reg *SessionRegistry) error {
c.mu.Lock()
defer c.mu.Unlock()
// As SSH conversation progresses, at some point a session will be created and
// its ID will be added to the environment
ssid, found := c.getEnvLocked(sshutils.SessionEnvVar)
sid, found := c.getEnvLocked(sshutils.SessionEnvVar)
if !found {
return nil
return "", ""
}

// make sure whatever session is requested is a valid session
id, err := rsession.ParseID(ssid)
if err != nil {
return trace.BadParameter("invalid session ID %s", ssid)
mode := types.SessionPeerMode // default
if modeString, found := c.getEnvLocked(teleport.EnvSSHJoinMode); found {
mode = types.SessionParticipantMode(modeString)
}

// update ctx with the session if it exists
if sess, found := reg.findSession(*id); found {
c.session = sess
c.ConnectionContext.SetSessionID(*id)
c.Logger.DebugContext(ctx, "Joining active SSH session", "session_id", c.session.id)
} else {
// TODO(capnspacehook): DELETE IN 19.0.0 - by then all supported
// clients should only set TELEPORT_SESSION when they want to
// join a session. Always return an error instead of using a
// new ID.
//
// The session ID the client sent was not found, ignore it; the
// connection's ID will be used as the session ID later.
c.Logger.DebugContext(ctx, "Sent session ID not found, using connection ID")
return sid, mode
}

// SessionID returns the ID of the session in the context.
//
// This value is not set until during and after the "shell" / "exec" channel request.
func (c *ServerContext) SessionID() string {
c.mu.RLock()
defer c.mu.RUnlock()

if c.session != nil {
return string(c.session.id)
}

return nil
return ""
}

// GetServer returns the underlying server which this context was created in.
func (c *ServerContext) GetServer() Server {
return c.srv
}

// TrackActivity keeps track of all activity on ssh.Channel. The caller should
Expand Down Expand Up @@ -716,6 +712,11 @@ func (c *ServerContext) setSession(ctx context.Context, sess *session, ch ssh.Ch
}

// getSession returns the context's session
//
// The associated session is not set in the server context until a
// shell / exec channel has been initiated for the session, so out-of-band
// session requests that can occur before these channel requests should
// consider fallback mechanisms.
func (c *ServerContext) getSession() *session {
c.mu.RLock()
defer c.mu.RUnlock()
Expand Down Expand Up @@ -1257,7 +1258,7 @@ func (c *ServerContext) GetExecRequest() (Exec, error) {

func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata {
return apievents.SessionMetadata{
SessionID: string(c.SessionID()),
SessionID: c.SessionID(),
WithMFA: c.Identity.UnmappedIdentity.MFAVerified,
PrivateKeyPolicy: string(c.Identity.UnmappedIdentity.PrivateKeyPolicy),
}
Expand Down
92 changes: 0 additions & 92 deletions lib/srv/ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@
package srv

import (
"context"
"testing"

"github.com/gogo/protobuf/proto"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"

Expand All @@ -33,9 +31,7 @@ import (
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
)

func TestCheckSFTPAllowed(t *testing.T) {
Expand Down Expand Up @@ -285,91 +281,3 @@ func TestSSHAccessLockTargets(t *testing.T) {
}
})
}

func TestCreateOrJoinSession(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

srv := newMockServer(t)
registry, err := NewSessionRegistry(SessionRegistryConfig{
clock: srv.clock,
Srv: srv,
SessionTrackerService: srv.auth,
})
require.NoError(t, err)

runningSessionID := rsession.NewID()
sess, _, err := newSession(ctx, runningSessionID, registry, newTestServerContext(t, srv, nil, &decisionpb.SSHAccessPermit{}), newMockSSHChannel(), sessionTypeInteractive)
require.NoError(t, err)

t.Cleanup(sess.Stop)

registry.sessions[runningSessionID] = sess

tests := []struct {
name string
sessionID string
expectedErr bool
wantSameSessionID bool
}{
{
name: "no session ID",
},
{
name: "new session ID",
sessionID: string(rsession.NewID()),
wantSameSessionID: false,
},
{
name: "existing session ID",
sessionID: runningSessionID.String(),
wantSameSessionID: true,
},
{
name: "existing session ID in Windows format",
sessionID: "{" + runningSessionID.String() + "}",
wantSameSessionID: true,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(ctx)
defer cancel()

parsedSessionID := new(rsession.ID)
var err error
if tt.sessionID != "" {
parsedSessionID, err = rsession.ParseID(tt.sessionID)
require.NoError(t, err)
}

scx := newTestServerContext(t, srv, nil, nil)
if tt.sessionID != "" {
scx.SetEnv(sshutils.SessionEnvVar, tt.sessionID)
}

err = scx.CreateOrJoinSession(ctx, registry)
if tt.expectedErr {
require.True(t, trace.IsNotFound(err))
} else {
require.NoError(t, err)
}

sessID := scx.GetSessionID()
require.False(t, sessID.IsZero())
if tt.wantSameSessionID {
require.Equal(t, parsedSessionID.String(), sessID.String())
require.Equal(t, *parsedSessionID, scx.GetSessionID())
} else {
require.NotEqual(t, parsedSessionID.String(), sessID.String())
require.NotEqual(t, *parsedSessionID, scx.GetSessionID())
}
})
}
}
2 changes: 1 addition & 1 deletion lib/srv/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestEmitExecAuditEvent(t *testing.T) {
require.Equal(t, "abc", execEvent.ForwardedBy)
require.Equal(t, expectedHostname, execEvent.ServerHostname)
require.Equal(t, "testNamespace", execEvent.ServerNamespace)
require.NotEqual(t, "xxx", execEvent.SessionID)
require.Equal(t, "xxx", execEvent.SessionID)
require.Equal(t, "10.0.0.5:4817", execEvent.RemoteAddr)
require.Equal(t, "127.0.0.1:3022", execEvent.LocalAddr)
require.NotEmpty(t, events.EventID)
Expand Down
14 changes: 0 additions & 14 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1187,20 +1187,6 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) {
defer s.logger.DebugContext(ctx, "Closing session request", "target_addr", s.sconn.RemoteAddr(), "session_id", scx.ID())

for {
// Update the context with the session ID.
err := scx.CreateOrJoinSession(ctx, s.sessionRegistry)
if err != nil {
s.logger.ErrorContext(ctx, "unable create or join session", "error", err)

// Write the error to channel and close it.
s.stderrWrite(ctx, ch, fmt.Sprintf("unable to update context: %v", err))
_, err := ch.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: teleport.RemoteCommandFailure}))
if err != nil {
s.logger.ErrorContext(ctx, "Failed to send exit status", "error", err)
}
return
}

select {
case result := <-scx.SubsystemResultCh:
// Subsystem has finished executing, close the channel and session.
Expand Down
10 changes: 10 additions & 0 deletions lib/srv/git/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
Expand Down Expand Up @@ -451,6 +452,7 @@ type sessionContext struct {
channel ssh.Channel
remoteSession *tracessh.Session
waitExec chan error
sessionID rsession.ID
}

func newSessionContext(serverCtx *srv.ServerContext, ch ssh.Channel, remoteSession *tracessh.Session) *sessionContext {
Expand All @@ -459,9 +461,17 @@ func newSessionContext(serverCtx *srv.ServerContext, ch ssh.Channel, remoteSessi
channel: ch,
remoteSession: remoteSession,
waitExec: make(chan error, 1),
sessionID: rsession.NewID(),
}
}

func (c *sessionContext) GetSessionMetadata() apievents.SessionMetadata {
// Overwrite with our own session ID.
metadata := c.ServerContext.GetSessionMetadata()
metadata.SessionID = c.sessionID.String()
return metadata
}

// dispatch executes an incoming request. If successful, it returns the ok value
// for the reply. Otherwise, it returns the error it encountered.
func (s *ForwardServer) dispatch(ctx context.Context, sctx *sessionContext, req *ssh.Request) (bool, error) {
Expand Down
1 change: 1 addition & 0 deletions lib/srv/git/forward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func TestForwardServer(t *testing.T) {
require.True(t, ok)
assert.Equal(t, libevents.GitCommandEvent, gitEvent.Metadata.Type)
assert.Equal(t, libevents.GitCommandCode, gitEvent.Metadata.Code)
assert.NotEmpty(t, gitEvent.SessionID)
assert.Equal(t, "alice", gitEvent.User)
assert.Equal(t, "0", gitEvent.CommandMetadata.ExitCode)
assert.Equal(t, "git-upload-pack", gitEvent.Service)
Expand Down
15 changes: 0 additions & 15 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1616,21 +1616,6 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec
})

for {
// update scx with the session ID:
if !s.proxyMode {
err := scx.CreateOrJoinSession(ctx, s.reg)
if err != nil {
scx.Logger.ErrorContext(ctx, "Unable to update context", "error", err)

// write the error to channel and close it
s.writeStderr(ctx, trackingChan, fmt.Sprintf("unable to update context: %v", err))
_, err := trackingChan.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: teleport.RemoteCommandFailure}))
if err != nil {
scx.Logger.ErrorContext(ctx, "Failed to send exit status", "error", err)
}
return
}
}
select {
case creq := <-scx.SubsystemResultCh:
// this means that subsystem has finished executing and
Expand Down
Loading
Loading