13 changes: 13 additions & 0 deletions integration/utmp_integration_test.go
@@ -40,6 +40,7 @@ import (
"github.com/gravitational/teleport/lib/pam"
restricted "github.com/gravitational/teleport/lib/restrictedsession"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv"
"github.com/gravitational/teleport/lib/srv/regular"
"github.com/gravitational/teleport/lib/srv/uacc"
"github.com/gravitational/teleport/lib/sshutils"
@@ -257,8 +258,19 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx {
require.NoError(t, err)
t.Cleanup(lockWatcher.Close)

nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
Semaphores: s.nodeClient,
AccessPoint: s.nodeClient,
LockEnforcer: lockWatcher,
Emitter: s.nodeClient,
Component: teleport.ComponentNode,
ServerID: s.nodeID,
})
require.NoError(t, err)

nodeDir := t.TempDir()
srv, err := regular.New(
ctx,
utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"},
s.server.ClusterName(),
[]ssh.Signer{s.signer},
@@ -286,6 +298,7 @@ func newSrvCtx(ctx context.Context, t *testing.T) *SrvCtx {
regular.SetClock(s.clock),
regular.SetUtmpPath(utmpPath, utmpPath),
regular.SetLockWatcher(lockWatcher),
regular.SetSessionController(nodeSessionController),
)
require.NoError(t, err)
s.srv = srv
52 changes: 50 additions & 2 deletions lib/service/service.go
@@ -2288,7 +2288,29 @@ func (process *TeleportProcess) initSSH() error {

storagePresence := local.NewPresenceService(process.storage.BackendStorage)

s, err := regular.New(cfg.SSH.Addr,
// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}

sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
Semaphores: authClient,
AccessPoint: authClient,
LockEnforcer: lockWatcher,
Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer},
Component: teleport.ComponentNode,
Logger: process.log.WithField(trace.Component, "sessionctrl"),
TracerProvider: process.TracingProvider,
ServerID: serverID,
})
if err != nil {
return trace.Wrap(err)
}

s, err := regular.New(
process.ExitContext(),
cfg.SSH.Addr,
cfg.Hostname,
[]ssh.Signer{conn.ServerIdentity.KeySigner},
authClient,
@@ -2320,6 +2342,8 @@ func (process *TeleportProcess) initSSH() error {
regular.SetCreateHostUser(!cfg.SSH.DisableCreateHostUser),
regular.SetStoragePresenceService(storagePresence),
regular.SetInventoryControlHandle(process.inventoryHandle),
regular.SetTracerProvider(process.TracingProvider),
regular.SetSessionController(sessionController),
)
if err != nil {
return trace.Wrap(err)
@@ -3627,7 +3651,29 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
proxyRouter = router
}

sshProxy, err := regular.New(cfg.Proxy.SSHAddr,
// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}

sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
Semaphores: accessPoint,
AccessPoint: accessPoint,
LockEnforcer: lockWatcher,
Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer},
Component: teleport.ComponentProxy,
Logger: process.log.WithField(trace.Component, "sessionctrl"),
TracerProvider: process.TracingProvider,
ServerID: serverID,
})
if err != nil {
return trace.Wrap(err)
}

sshProxy, err := regular.New(
process.ExitContext(),
cfg.Proxy.SSHAddr,
cfg.Hostname,
[]ssh.Signer{conn.ServerIdentity.KeySigner},
accessPoint,
@@ -3651,6 +3697,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
// accurately checked later when an SCP/SFTP request hits the
// destination Node.
regular.SetAllowFileCopying(true),
regular.SetTracerProvider(process.TracingProvider),
regular.SetSessionController(sessionController),
)
if err != nil {
return trace.Wrap(err)
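Both call sites above repeat the same wiring: read the host UUID from the data directory, then build a `srv.SessionController` that differs only in its `Component` and emitter. A minimal sketch of how the duplication could be factored into a helper — the helper itself, its parameter types, and the use of a single `auth.ClientI` for the `Semaphores`, `AccessPoint`, and `Emitter` roles (as the integration test above does with `s.nodeClient`) are illustrative assumptions, not part of this PR:

```go
package example

import (
	"github.com/gravitational/trace"

	"github.com/gravitational/teleport/lib/auth"
	"github.com/gravitational/teleport/lib/services"
	"github.com/gravitational/teleport/lib/srv"
	"github.com/gravitational/teleport/lib/utils"
)

// newSessionController mirrors the wiring performed in initSSH and
// initProxyEndpoint. component is teleport.ComponentNode or
// teleport.ComponentProxy; client stands in for the semaphore service,
// access point, and audit emitter, as in the integration test.
func newSessionController(dataDir, component string, client auth.ClientI, lockWatcher *services.LockWatcher) (*srv.SessionController, error) {
	// Both call sites derive ServerID from the host UUID on disk.
	serverID, err := utils.ReadOrMakeHostUUID(dataDir)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return srv.NewSessionController(srv.SessionControllerConfig{
		Semaphores:   client,
		AccessPoint:  client,
		LockEnforcer: lockWatcher,
		Emitter:      client,
		Component:    component,
		ServerID:     serverID,
	})
}
```

The integration test omits `Logger` and `TracerProvider` entirely, which suggests both fields are optional and defaulted inside `NewSessionController`.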
20 changes: 14 additions & 6 deletions lib/services/semaphore.go
@@ -20,6 +20,7 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/types"
@@ -39,10 +40,16 @@ type SemaphoreLockConfig struct {
TickRate time.Duration
// Params holds the semaphore lease acquisition parameters.
Params types.AcquireSemaphoreRequest
// Clock is used to alter time in tests.
Clock clockwork.Clock
}

// CheckAndSetDefaults checks and sets default parameters
func (l *SemaphoreLockConfig) CheckAndSetDefaults() error {
if l.Clock == nil {
l.Clock = clockwork.NewRealClock()
}

if l.Service == nil {
return trace.BadParameter("missing semaphore service")
}
@@ -59,7 +66,7 @@ func (l *SemaphoreLockConfig) CheckAndSetDefaults() error {
return trace.BadParameter("tick-rate must be less than expiry")
}
if l.Params.Expires.IsZero() {
l.Params.Expires = time.Now().UTC().Add(l.Expiry)
l.Params.Expires = l.Clock.Now().UTC().Add(l.Expiry)
}
if err := l.Params.Check(); err != nil {
return trace.Wrap(err)
@@ -73,7 +80,7 @@ type SemaphoreLock struct {
cfg SemaphoreLockConfig
lease0 types.SemaphoreLease
retry retryutils.Retry
ticker *time.Ticker
ticker clockwork.Ticker
doneC chan struct{}
closeOnce sync.Once
renewalC chan struct{}
@@ -140,7 +147,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) {
// cancellation/expiry.
return
}
if lease.Expires.After(time.Now().UTC()) {
if lease.Expires.After(l.cfg.Clock.Now().UTC()) {
// parent context is closed. create orphan context with generous
// timeout for lease cancellation scope. this will not block any
// caller that is not explicitly waiting on the final error value.
@@ -157,7 +164,7 @@ func (l *SemaphoreLock) keepAlive(ctx context.Context) {
Outer:
for {
select {
case tick := <-l.ticker.C:
case tick := <-l.ticker.Chan():
leaseContext, leaseCancel := context.WithDeadline(ctx, lease.Expires)
nextLease := lease
nextLease.Expires = tick.Add(l.cfg.Expiry)
@@ -185,7 +192,7 @@ Outer:
l.retry.Inc()
select {
case <-l.retry.After():
case tick = <-l.ticker.C:
case tick = <-l.ticker.Chan():
// check to make sure that we still have some time on the lease. the default tick rate would have
// us waking _as_ the lease expires here, but if we're working with a higher tick rate, it's worth
// retrying again.
@@ -247,6 +254,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph
Max: cfg.Expiry / 4,
Step: cfg.Expiry / 16,
Jitter: retryutils.NewJitter(),
Clock: cfg.Clock,
})
if err != nil {
return nil, trace.Wrap(err)
@@ -259,7 +267,7 @@ func AcquireSemaphoreLock(ctx context.Context, cfg SemaphoreLockConfig) (*Semaph
cfg: cfg,
lease0: *lease,
retry: retry,
ticker: time.NewTicker(cfg.TickRate),
ticker: cfg.Clock.NewTicker(cfg.TickRate),
doneC: make(chan struct{}),
renewalC: make(chan struct{}),
cond: sync.NewCond(&sync.Mutex{}),
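Threading a `clockwork.Clock` through the config and replacing `time.Ticker` with `clockwork.Ticker` makes the renewal loop testable without real waiting. A sketch of such a test, assuming an in-memory backend with `local.NewPresenceService` as the semaphore service and a `Renewed()` accessor for the `renewalC` channel shown above:

```go
package services_test

import (
	"context"
	"testing"
	"time"

	"github.com/jonboulle/clockwork"
	"github.com/stretchr/testify/require"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/backend/memory"
	"github.com/gravitational/teleport/lib/services"
	"github.com/gravitational/teleport/lib/services/local"
)

func TestSemaphoreLockRenewal(t *testing.T) {
	clock := clockwork.NewFakeClockAt(time.Now())

	// An in-memory backend stands in for the real semaphore service.
	bk, err := memory.New(memory.Config{})
	require.NoError(t, err)
	semaphores := local.NewPresenceService(bk)

	lock, err := services.AcquireSemaphoreLock(context.Background(), services.SemaphoreLockConfig{
		Service:  semaphores,
		Expiry:   time.Minute,
		TickRate: 10 * time.Second,
		Clock:    clock,
		Params: types.AcquireSemaphoreRequest{
			SemaphoreKind: types.SemaphoreKindConnection,
			SemaphoreName: "test-sem",
			MaxLeases:     1,
			Holder:        "node-1",
		},
	})
	require.NoError(t, err)
	defer lock.Stop()

	// Wait for keepAlive's ticker to register with the fake clock, then
	// advance past one tick to trigger a renewal without sleeping.
	clock.BlockUntil(1)
	clock.Advance(11 * time.Second)
	select {
	case <-lock.Renewed():
	case <-time.After(10 * time.Second):
		t.Fatal("semaphore lease was not renewed")
	}
}
```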
31 changes: 12 additions & 19 deletions lib/srv/ctx.go
@@ -435,13 +435,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
trace.ComponentFields: fields,
})

lockTargets, err := ComputeLockTargets(srv, identityContext)
clusterName, err := srv.GetAccessPoint().GetClusterName()
if err != nil {
return nil, nil, trace.Wrap(err)
childErr := child.Close()
return nil, nil, trace.NewAggregate(err, childErr)
}

monitorConfig := MonitorConfig{
LockWatcher: child.srv.GetLockWatcher(),
LockTargets: lockTargets,
LockTargets: ComputeLockTargets(clusterName.GetClusterName(), srv.HostUUID(), identityContext),
LockingMode: identityContext.AccessChecker.LockingMode(authPref.GetLockingMode()),
DisconnectExpiredCert: child.disconnectExpiredCert,
ClientIdleTimeout: child.clientIdleTimeout,
@@ -1147,28 +1149,19 @@ func newUaccMetadata(c *ServerContext) (*UaccMetadata, error) {
}, nil
}

// ComputeLockTargets computes lock targets inferred from a Server
// and an IdentityContext.
func ComputeLockTargets(s Server, id IdentityContext) ([]types.LockTarget, error) {
clusterName, err := s.GetAccessPoint().GetClusterName()
if err != nil {
return nil, trace.Wrap(err)
}
// ComputeLockTargets computes lock targets inferred from the clusterName, serverID and IdentityContext.
func ComputeLockTargets(clusterName, serverID string, id IdentityContext) []types.LockTarget {
lockTargets := []types.LockTarget{
{User: id.TeleportUser},
{Login: id.Login},
{Node: s.HostUUID()},
{Node: auth.HostFQDN(s.HostUUID(), clusterName.GetClusterName())},
{Node: serverID},
{Node: auth.HostFQDN(serverID, clusterName)},
{MFADevice: id.Certificate.Extensions[teleport.CertExtensionMFAVerified]},
}
roles := apiutils.Deduplicate(append(id.AccessChecker.RoleNames(), id.UnmappedRoles...))
lockTargets = append(lockTargets,
services.RolesToLockTargets(roles)...,
)
lockTargets = append(lockTargets,
services.AccessRequestsToLockTargets(id.ActiveRequests)...,
)
return lockTargets, nil
lockTargets = append(lockTargets, services.RolesToLockTargets(roles)...)
lockTargets = append(lockTargets, services.AccessRequestsToLockTargets(id.ActiveRequests)...)
return lockTargets
}

// SetRequest sets the ssh request that was issued by the client.
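With the `Server` dependency gone, `ComputeLockTargets` no longer performs I/O or returns an error: callers resolve the cluster name themselves (as `NewServerContext` now does above), and the function becomes pure and easy to unit-test. A sketch of such a test — note the function still dereferences `id.AccessChecker` and `id.Certificate`, so both must be populated; the embedded-interface stub and the field values are illustrative assumptions:

```go
package srv_test

import (
	"testing"

	"github.com/stretchr/testify/require"
	"golang.org/x/crypto/ssh"

	"github.com/gravitational/teleport/api/types"
	"github.com/gravitational/teleport/lib/services"
	"github.com/gravitational/teleport/lib/srv"
)

// fakeChecker embeds the AccessChecker interface; ComputeLockTargets
// only calls RoleNames, so that is the one method implemented here.
type fakeChecker struct {
	services.AccessChecker
	roles []string
}

func (f fakeChecker) RoleNames() []string { return f.roles }

func TestComputeLockTargets(t *testing.T) {
	id := srv.IdentityContext{
		TeleportUser:  "alice",
		Login:         "root",
		Certificate:   &ssh.Certificate{},
		AccessChecker: fakeChecker{roles: []string{"editor"}},
	}
	targets := srv.ComputeLockTargets("example.com", "node-uuid", id)

	require.Contains(t, targets, types.LockTarget{User: "alice"})
	require.Contains(t, targets, types.LockTarget{Login: "root"})
	require.Contains(t, targets, types.LockTarget{Node: "node-uuid"})
	// auth.HostFQDN joins the server ID and cluster name with a dot.
	require.Contains(t, targets, types.LockTarget{Node: "node-uuid.example.com"})
	require.Contains(t, targets, types.LockTarget{Role: "editor"})
}
```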