Restart entire node on tunnel collapse (#8102)
Fixes #7606, where a node doesn't notice when the tunnel port changes. 

Imagine you have a cluster with a node connected via a reverse tunnel through a proxy at `proxy.example.com` on port `3024`.

Now change the proxy config so that `tunnel_public_address` is `proxy.example.com:4024`. You either restart the proxy, or reload the proxy config with a `SIGHUP`.

...and then the node
  a) loses its connection to auth (because the tunnel is gone), and
  b) _doesn't reconnect_, because even though the proxy address hasn't changed,
     the node has cached the old `tunnel_public_address` and keeps trying to
     connect to that.

You can always manually restart the node to have it reconnect, but that would be a pain if you have thousands of nodes.

To avoid manually restarting every node, this change implements a check for connection failures to the auth server and restarts the node if there are multiple failures within a given period of time. The check as implemented piggybacks on the node's `common.rotate` service, which can already restart the node in certain circumstances, and uses the success of the periodic rotation sync as a proxy for the health of the node's connection to the auth server.
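
The failure threshold and retry interval are surfaced as two new knobs on the service configuration (`RotationConnectionInterval` and `RestartThreshold`, added in `lib/service/cfg.go` below). A minimal sketch of tuning them when building a config programmatically (the `nodeConfig` helper and the values are illustrative only, mirroring what the new integration test does):

```go
package example

import (
	"time"

	"github.com/gravitational/teleport/lib/service"
)

// nodeConfig builds a node config that tolerates at most three rotation-sync
// failures per minute before requesting a graceful restart.
func nodeConfig() *service.Config {
	cfg := service.MakeDefaultConfig()
	cfg.SSH.Enabled = true

	// Retry the rotation-state sync every five seconds instead of the
	// default high-resolution polling period.
	cfg.RotationConnectionInterval = 5 * time.Second

	// Ask for a graceful restart once more than three sync failures have
	// been seen within a single minute.
	cfg.RestartThreshold = service.Rate{Amount: 3, Time: time.Minute}

	return cfg
}
```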

See-Also: #7606
tcsc authored Nov 17, 2021
1 parent d67e9b3 commit 97c18fa
Showing 9 changed files with 385 additions and 18 deletions.
111 changes: 111 additions & 0 deletions integration/restart_test.go
@@ -0,0 +1,111 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package integration

import (
"context"
"testing"
"time"

"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/service"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)

// TestLostConnectionToAuthCausesReload tests that a lost connection to the auth server
// will eventually restart a node
func TestLostConnectionToAuthCausesReload(t *testing.T) {
// Because testing that the node does a full restart is a bit tricky when
// running a cluster from inside a test runner (i.e. we don't want to
// SIGTERM the test runner), we will watch for the node emitting a
// `TeleportReload` event instead. In a proper Teleport instance, this
// event would be picked up at the Supervisor level and would eventually
// cause the instance to gracefully restart.

require := require.New(t)
log := log.StandardLogger()

log.Info(">>> Entering Test")

// InsecureDevMode needed for SSH connections
// TODO(tcsc): surface this as per-server config (see also issue #8913)
lib.SetInsecureDevMode(true)
defer lib.SetInsecureDevMode(false)

// GIVEN a cluster with a running auth+proxy instance....
log.Info(">>> Creating cluster")
keygen := testauthority.New()
privateKey, publicKey, err := keygen.GenerateKeyPair("")
require.NoError(err)
auth := NewInstance(InstanceConfig{
ClusterName: "test-tunnel-collapse",
HostID: "auth",
Priv: privateKey,
Pub: publicKey,
Ports: standardPortSetup(),
log: log,
})

log.Info(">>> Creating auth-proxy...")
authConf := service.MakeDefaultConfig()
authConf.Hostname = Host
authConf.Auth.Enabled = true
authConf.Proxy.Enabled = true
authConf.SSH.Enabled = false
authConf.Proxy.DisableWebInterface = true
authConf.Proxy.DisableDatabaseProxy = true
require.NoError(auth.CreateEx(t, nil, authConf))
t.Cleanup(func() { require.NoError(auth.StopAll()) })

log.Info(">>> Start auth-proxy...")
require.NoError(auth.Start())

// ... and an SSH node connected via a reverse tunnel configured to
// reload after only a few failed connection attempts per minute
log.Info(">>> Creating and starting node...")
nodeCfg := service.MakeDefaultConfig()
nodeCfg.Hostname = Host
nodeCfg.SSH.Enabled = true
nodeCfg.RotationConnectionInterval = 1 * time.Second
nodeCfg.RestartThreshold = service.Rate{Amount: 3, Time: 1 * time.Minute}
node, err := auth.StartReverseTunnelNode(nodeCfg)
require.NoError(err)

// WHEN I stop the auth node (and, by implication, disrupt the ssh node's
// connection to it)
log.Info(">>> Stopping auth node")
auth.StopAuth(false)

// EXPECT THAT the ssh node will eventually issue a reload request
log.Info(">>> Waiting for node restart request.")
waitCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()

eventCh := make(chan service.Event)
node.WaitForEvent(waitCtx, service.TeleportReloadEvent, eventCh)
select {
case e := <-eventCh:
log.Infof(">>> Received Reload event: %v. Test passed.", e)

case <-waitCtx.Done():
require.FailNow("Timed out", "Timed out waiting for reload event")
}

log.Info(">>> TEST COMPLETE")
}
12 changes: 12 additions & 0 deletions lib/defaults/defaults.go
@@ -410,6 +410,18 @@ var (

// AsyncBufferSize is a default buffer size for async emitters
AsyncBufferSize = 1024

// ConnectionErrorMeasurementPeriod is the maximum age of a connection error
// to be considered when deciding to restart the process. The process will
// restart if there have been more than `MaxConnectionErrorsBeforeRestart`
// errors in the preceding `ConnectionErrorMeasurementPeriod`
ConnectionErrorMeasurementPeriod = 2 * time.Minute

// MaxConnectionErrorsBeforeRestart is the number of allowable network errors
// in the previous `ConnectionErrorMeasurementPeriod`. The process will
// restart if there have been more than `MaxConnectionErrorsBeforeRestart`
// errors in the preceding `ConnectionErrorMeasurementPeriod`
MaxConnectionErrorsBeforeRestart = 5
)

// Default connection limits, they can be applied separately on any of the Teleport
19 changes: 17 additions & 2 deletions lib/reversetunnel/agentpool.go
@@ -50,8 +50,11 @@ type AgentPool struct {
log *log.Entry
cfg AgentPoolConfig
proxyTracker *track.Tracker
- ctx context.Context
- cancel context.CancelFunc

// ctx controls the lifespan of the agent pool, and is used to control
// all of the sub-processes it spawns.
ctx context.Context
cancel context.CancelFunc
// spawnLimiter limits agent spawn rate
spawnLimiter utils.Retry

@@ -186,16 +189,25 @@ func (m *AgentPool) Wait() {
<-m.ctx.Done()
}

// processSeekEvents receives acquisition messages from the ProxyTracker
// (i.e. "I've found a proxy that you may not know about") and routes the
// new proxy address to the AgentPool, which will manage the connection
// to that address.
func (m *AgentPool) processSeekEvents() {
limiter := m.spawnLimiter.Clone()
for {
select {
case <-m.ctx.Done():
m.log.Debugf("Halting seek event processing (pool closing)")
return

// The proxy tracker has given us permission to act on a given
// tunnel address
case lease := <-m.proxyTracker.Acquire():
m.log.Debugf("Seeking: %+v.", lease.Key())
m.withLock(func() {
// Note that ownership of the lease is transferred to agent
// pool for the lifetime of the connection
if err := m.addAgent(lease); err != nil {
m.log.WithError(err).Errorf("Failed to add agent.")
}
@@ -267,6 +279,9 @@ func (m *AgentPool) getReverseTunnelDetails(addr utils.NetAddr) *reverseTunnelDe
return agents[0].reverseTunnelDetails
}

// addAgent adds a new agent to the pool. Note that ownership of the lease
// transfers into the AgentPool, and will be released when the AgentPool
// is done with it.
func (m *AgentPool) addAgent(lease track.Lease) error {
addr := lease.Key().(utils.NetAddr)
agent, err := NewAgent(AgentConfig{
22 changes: 22 additions & 0 deletions lib/service/cfg.go
@@ -62,6 +62,13 @@ import (
"github.com/jonboulle/clockwork"
)

// Rate describes a rate ratio, i.e. the number of "events" that happen over
// some unit time period
type Rate struct {
Amount int
Time time.Duration
}

// Config structure is used to initialize _all_ services Teleport can run.
// Some settings are global (like DataDir) while others are grouped into
// sections, like AuthConfig
@@ -231,6 +238,15 @@ type Config struct {

// PluginRegistry allows adding enterprise logic to Teleport services
PluginRegistry plugin.Registry

// RotationConnectionInterval is the interval between connection
// attempts as used by the rotation state service
RotationConnectionInterval time.Duration

// RestartThreshold describes the number of connection failures per
// unit time that the node can sustain before restarting itself, as
// measured by the rotation state service.
RestartThreshold Rate
}

// ApplyToken assigns a given token to all internal services but only if token
@@ -1021,6 +1037,12 @@ func ApplyDefaults(cfg *Config) {
// Windows desktop service is disabled by default.
cfg.WindowsDesktop.Enabled = false
defaults.ConfigureLimiter(&cfg.WindowsDesktop.ConnLimiter)

cfg.RotationConnectionInterval = defaults.HighResPollingPeriod
cfg.RestartThreshold = Rate{
Amount: defaults.MaxConnectionErrorsBeforeRestart,
Time: defaults.ConnectionErrorMeasurementPeriod,
}
}

// ApplyFIPSDefaults updates default configuration to be FedRAMP/FIPS 140-2
5 changes: 5 additions & 0 deletions lib/service/cfg_test.go
@@ -89,6 +89,11 @@ func TestDefaultConfig(t *testing.T) {
proxy := config.Proxy
require.Equal(t, proxy.Limiter.MaxConnections, int64(defaults.LimiterMaxConnections))
require.Equal(t, proxy.Limiter.MaxNumberOfUsers, defaults.LimiterMaxConcurrentUsers)

// Misc levers and dials
require.Equal(t, config.RotationConnectionInterval, defaults.HighResPollingPeriod)
require.Equal(t, config.RestartThreshold.Amount, defaults.MaxConnectionErrorsBeforeRestart)
require.Equal(t, config.RestartThreshold.Time, defaults.ConnectionErrorMeasurementPeriod)
}

// TestCheckApp validates application configuration.
25 changes: 21 additions & 4 deletions lib/service/connect.go
@@ -437,18 +437,35 @@ func (process *TeleportProcess) periodicSyncRotationState() error {
}

periodic := interval.New(interval.Config{
- Duration: defaults.HighResPollingPeriod,
- FirstDuration: utils.HalfJitter(defaults.HighResPollingPeriod),
Duration: process.Config.RotationConnectionInterval,
FirstDuration: utils.HalfJitter(process.Config.RotationConnectionInterval),
Jitter: utils.NewSeventhJitter(),
})
defer periodic.Stop()

errors := utils.NewTimedCounter(process.Clock, process.Config.RestartThreshold.Time)

for {
err := process.syncRotationStateCycle()
if err == nil {
return nil
}
process.log.Warningf("Sync rotation state cycle failed: %v, going to retry after ~%v.", err, defaults.HighResPollingPeriod)
process.log.WithError(err).Warning("Sync rotation state cycle failed")

// If we have had a *lot* of failures very recently, then it's likely that our
// route to the auth server is gone. If we're using a tunnel then it's possible
// that the proxy has been reconfigured and the tunnel address has moved.
count := errors.Increment()
process.log.Warnf("%d connection errors in last %v.", count, process.Config.RestartThreshold.Time)
if count > process.Config.RestartThreshold.Amount {
// signal quit
process.log.Error("Connection error threshold exceeded. Asking for a graceful restart.")
process.BroadcastEvent(Event{Name: TeleportReloadEvent})
return nil
}

process.log.Warningf("Retrying in ~%v", process.Config.RotationConnectionInterval)

select {
case <-periodic.Next():
case <-process.GracefulExitContext().Done():
@@ -460,7 +477,7 @@ func (process *TeleportProcess) periodicSyncRotationState() error {
// syncRotationCycle executes a rotation cycle that returns:
//
// * nil whenever rotation state leads to teleport reload event
- // * error whenever rotation sycle has to be restarted
// * error whenever rotation cycle has to be restarted
//
// the function accepts extra delay timer extraDelay in case if parent
// function needs a
71 changes: 71 additions & 0 deletions lib/utils/timed_counter.go
@@ -0,0 +1,71 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package utils

import (
"time"

"github.com/jonboulle/clockwork"
)

// TimedCounter is essentially a lightweight rate calculator. It counts events
// that happen over a period of time, e.g. have there been more than 4 errors
// in the last 30 seconds. Automatically expires old events so they are not
// included in the count. Not safe for concurrent use.
type TimedCounter struct {
clock clockwork.Clock
timeout time.Duration
events []time.Time
}

// NewTimedCounter creates a new timed counter with the specified timeout
func NewTimedCounter(clock clockwork.Clock, timeout time.Duration) *TimedCounter {
return &TimedCounter{
clock: clock,
timeout: timeout,
events: nil,
}
}

// Increment adds a new item into the counter, returning the current count.
func (c *TimedCounter) Increment() int {
c.trim()
c.events = append(c.events, c.clock.Now())
return len(c.events)
}

// Count fetches the number of recorded events currently in the measurement
// time window.
func (c *TimedCounter) Count() int {
c.trim()
return len(c.events)
}

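// trim drops events that have aged out of the measurement window, i.e. any
// event older than the counter's timeout.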
func (c *TimedCounter) trim() {
deadline := c.clock.Now().Add(-c.timeout)
lastExpiredEvent := -1
for i := range c.events {
if c.events[i].After(deadline) {
break
}
lastExpiredEvent = i
}

if lastExpiredEvent > -1 {
c.events = c.events[lastExpiredEvent+1:]
}
}
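
Not part of the commit itself, but as a rough usage sketch of the new counter (driven by clockwork's fake clock so the window can be advanced deterministically; the example is illustrative only):

```go
package utils_test

import (
	"fmt"
	"time"

	"github.com/gravitational/teleport/lib/utils"
	"github.com/jonboulle/clockwork"
)

func ExampleTimedCounter() {
	clock := clockwork.NewFakeClock()

	// Count events over a 30-second sliding window.
	counter := utils.NewTimedCounter(clock, 30*time.Second)

	// Three failures in quick succession all land inside the window.
	counter.Increment()
	counter.Increment()
	fmt.Println(counter.Increment())

	// Once the window has passed, the old events age out of the count.
	clock.Advance(time.Minute)
	fmt.Println(counter.Count())

	// Output:
	// 3
	// 0
}
```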