Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions lib/srv/desktop/windows_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -902,10 +902,7 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger,
TeleportUser: identity.Username,
ServerID: s.cfg.Heartbeat.HostUUID,
IdleTimeoutMessage: netConfig.GetClientIdleTimeoutMessage(),
MessageWriter: &monitorErrorSender{
log: log,
tdpConn: tdpConn,
},
MessageWriter: &monitorErrorSender{tdpConn: tdpConn},
}

// UpdateClientActivity before starting monitor to
Expand Down Expand Up @@ -1303,15 +1300,12 @@ func (s *WindowsService) trackSession(ctx context.Context, id *tlsca.Identity, w
// monitor disconnect messages back to the frontend
// over the tdp.Conn
type monitorErrorSender struct {
log logrus.FieldLogger
tdpConn *tdp.Conn
}

func (m *monitorErrorSender) WriteString(s string) (n int, err error) {
if err := m.tdpConn.SendNotification(s, tdp.SeverityError); err != nil {
errMsg := fmt.Sprintf("Failed to send TDP error message %v: %v", s, err)
m.log.Error(errMsg)
return 0, trace.Errorf(errMsg)
return 0, trace.Wrap(err, "sending TDP error message")
}

return len(s), nil
Expand Down
13 changes: 8 additions & 5 deletions lib/srv/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,16 @@ func (w *Monitor) start(lockWatch types.Watcher) {
clientLastActive := w.Tracker.GetClientLastActive()
since := w.Clock.Since(clientLastActive)
if since >= w.ClientIdleTimeout {
reason := "client reported no activity"
reason := "Client reported no activity"
if !clientLastActive.IsZero() {
reason = fmt.Sprintf("client is idle for %v, exceeded idle timeout of %v",
since, w.ClientIdleTimeout)
reason = fmt.Sprintf("Client exceeded idle timeout of %v", w.ClientIdleTimeout)
}
if w.MessageWriter != nil && w.IdleTimeoutMessage != "" {
if _, err := w.MessageWriter.WriteString(w.IdleTimeoutMessage); err != nil {
if w.MessageWriter != nil {
msg := w.IdleTimeoutMessage
if msg == "" {
msg = reason
}
if _, err := w.MessageWriter.WriteString(msg); err != nil {
w.Entry.WithError(err).Warn("Failed to send idle timeout message.")
}
}
Expand Down
27 changes: 27 additions & 0 deletions lib/srv/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"io"
"net"
"strings"
"testing"
"time"

Expand All @@ -41,6 +42,9 @@ import (
)

func newTestMonitor(ctx context.Context, t *testing.T, asrv *auth.TestAuthServer, mut ...func(*MonitorConfig)) (*mockTrackingConn, *eventstest.ChannelEmitter, MonitorConfig) {
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)

conn := &mockTrackingConn{closedC: make(chan struct{})}
emitter := eventstest.NewChannelEmitter(1)
cfg := MonitorConfig{
Expand Down Expand Up @@ -252,6 +256,29 @@ func TestMonitorStaleLocks(t *testing.T) {
require.Equal(t, services.StrictLockingModeAccessDenied.Error(), (<-emitter.C()).(*apievents.ClientDisconnect).Reason)
}

func TestWritesDisconnectMessage(t *testing.T) {
asrv, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
Dir: t.TempDir(),
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, asrv.Close()) })

var sw strings.Builder

ctx := context.Background()
clock := clockwork.NewFakeClock()
conn, _, _ := newTestMonitor(ctx, t, asrv, func(cfg *MonitorConfig) {
cfg.ClientIdleTimeout = 1 * time.Second
cfg.Clock = clock
cfg.MessageWriter = &sw
})
clock.BlockUntil(1)
clock.Advance(2 * time.Second)
<-conn.closedC
require.Contains(t, sw.String(), "exceeded idle timeout")
}

type mockTrackingConn struct {
net.Conn
closedC chan struct{}
Expand Down