diff --git a/lib/srv/desktop/windows_server.go b/lib/srv/desktop/windows_server.go index ddde9418d5398..b7feba3783cb9 100644 --- a/lib/srv/desktop/windows_server.go +++ b/lib/srv/desktop/windows_server.go @@ -902,10 +902,7 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger, TeleportUser: identity.Username, ServerID: s.cfg.Heartbeat.HostUUID, IdleTimeoutMessage: netConfig.GetClientIdleTimeoutMessage(), - MessageWriter: &monitorErrorSender{ - log: log, - tdpConn: tdpConn, - }, + MessageWriter: &monitorErrorSender{tdpConn: tdpConn}, } // UpdateClientActivity before starting monitor to @@ -1303,15 +1300,12 @@ func (s *WindowsService) trackSession(ctx context.Context, id *tlsca.Identity, w // monitor disconnect messages back to the frontend // over the tdp.Conn type monitorErrorSender struct { - log logrus.FieldLogger tdpConn *tdp.Conn } func (m *monitorErrorSender) WriteString(s string) (n int, err error) { if err := m.tdpConn.SendNotification(s, tdp.SeverityError); err != nil { - errMsg := fmt.Sprintf("Failed to send TDP error message %v: %v", s, err) - m.log.Error(errMsg) - return 0, trace.Errorf(errMsg) + return 0, trace.Wrap(err, "sending TDP error message") } return len(s), nil diff --git a/lib/srv/monitor.go b/lib/srv/monitor.go index 6867f52b3e245..6247627e19c61 100644 --- a/lib/srv/monitor.go +++ b/lib/srv/monitor.go @@ -356,13 +356,16 @@ func (w *Monitor) start(lockWatch types.Watcher) { clientLastActive := w.Tracker.GetClientLastActive() since := w.Clock.Since(clientLastActive) if since >= w.ClientIdleTimeout { - reason := "client reported no activity" + reason := "Client reported no activity" if !clientLastActive.IsZero() { - reason = fmt.Sprintf("client is idle for %v, exceeded idle timeout of %v", - since, w.ClientIdleTimeout) + reason = fmt.Sprintf("Client exceeded idle timeout of %v", w.ClientIdleTimeout) } - if w.MessageWriter != nil && w.IdleTimeoutMessage != "" { - if _, err := w.MessageWriter.WriteString(w.IdleTimeoutMessage); err != nil { + if w.MessageWriter != nil { + msg := w.IdleTimeoutMessage + if msg == "" { + msg = reason + } + if _, err := w.MessageWriter.WriteString(msg); err != nil { w.Entry.WithError(err).Warn("Failed to send idle timeout message.") } } diff --git a/lib/srv/monitor_test.go b/lib/srv/monitor_test.go index cf2c43f524d42..40d3f5304d8c7 100644 --- a/lib/srv/monitor_test.go +++ b/lib/srv/monitor_test.go @@ -22,6 +22,7 @@ import ( "context" "io" "net" + "strings" "testing" "time" @@ -41,6 +42,9 @@ import ( ) func newTestMonitor(ctx context.Context, t *testing.T, asrv *auth.TestAuthServer, mut ...func(*MonitorConfig)) (*mockTrackingConn, *eventstest.ChannelEmitter, MonitorConfig) { + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + conn := &mockTrackingConn{closedC: make(chan struct{})} emitter := eventstest.NewChannelEmitter(1) cfg := MonitorConfig{ @@ -252,6 +256,29 @@ func TestMonitorStaleLocks(t *testing.T) { require.Equal(t, services.StrictLockingModeAccessDenied.Error(), (<-emitter.C()).(*apievents.ClientDisconnect).Reason) } +func TestWritesDisconnectMessage(t *testing.T) { + asrv, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ + Dir: t.TempDir(), + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, asrv.Close()) }) + + var sw strings.Builder + + ctx := context.Background() + clock := clockwork.NewFakeClock() + conn, _, _ := newTestMonitor(ctx, t, asrv, func(cfg *MonitorConfig) { + cfg.ClientIdleTimeout = 1 * time.Second + cfg.Clock = clock + cfg.MessageWriter = &sw + }) + clock.BlockUntil(1) + clock.Advance(2 * time.Second) + <-conn.closedC + require.Contains(t, sw.String(), "exceeded idle timeout") +} + type mockTrackingConn struct { net.Conn closedC chan struct{}