Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions api/breaker/breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,18 @@ type Config struct {
// StateStandby to StateTripped. This is required to be supplied, failure to do so will result in an error
// creating the CircuitBreaker.
Trip TripFn
// OnTripped will be called when the CircuitBreaker enters the StateTripped state
// OnTripped will be called when the CircuitBreaker enters the StateTripped
// state; this callback is called while holding a lock, so it should return
// quickly.
OnTripped func()
// OnStandby will be called when the CircuitBreaker returns to the StateStandby state
// OnStandBy will be called when the CircuitBreaker returns to the
// StateStandby state; this callback is called while holding a lock, so it
// should return quickly.
OnStandBy func()
// OnExecute will be called once for each execution, and given the result
// and the current state of the breaker; this callback is called while
// holding a lock, so it should return quickly.
OnExecute func(success bool, state State)
// IsSuccessful is used by the CircuitBreaker to determine if the executed function was successful or not
IsSuccessful func(v interface{}, err error) bool
// Logger is the logger
Expand All @@ -139,6 +147,12 @@ type Config struct {
TrippedErrorMessage string
}

// Clone produces a copy of the Config by value. A plain struct copy is
// sufficient because the Config, as it currently stands, can be copied
// without issues.
func (c *Config) Clone() Config {
	duplicate := *c
	return duplicate
}

// TripFn determines if the CircuitBreaker should be tripped based
// on the state of the provided Metrics. A return value of true will
// cause the CircuitBreaker to transition into the StateTripped state
Expand Down Expand Up @@ -256,6 +270,10 @@ func (c *Config) CheckAndSetDefaults() error {
c.OnStandBy = func() {}
}

if c.OnExecute == nil {
c.OnExecute = func(bool, State) {}
}

if c.IsSuccessful == nil {
c.IsSuccessful = NonNilErrorIsSuccess
}
Expand Down Expand Up @@ -332,8 +350,9 @@ func (c *CircuitBreaker) beforeExecution() (uint64, error) {

generation, state := c.currentState(now)

switch {
case state == StateTripped:
if state == StateTripped {
c.cfg.OnExecute(false, StateTripped)

if c.cfg.TrippedErrorMessage != "" {
return generation, trace.ConnectionProblem(nil, c.cfg.TrippedErrorMessage)
}
Expand All @@ -359,21 +378,21 @@ func (c *CircuitBreaker) afterExecution(prior uint64, v interface{}, err error)
}

if c.cfg.IsSuccessful(v, err) {
c.cfg.Logger.Debugf("successful execution, %s", c.metrics.String())
c.success(state, now)
c.successLocked(state, now)
} else {
c.cfg.Logger.Debugf("failed execution, %s", c.metrics.String())
c.failure(state, now)
c.failureLocked(state, now)
}
}

// success tallies a successful execution and migrates to StateStandby
// successLocked tallies a successful execution and migrates to StateStandby
// if in another state and criteria has been met to transition
func (c *CircuitBreaker) success(state State, t time.Time) {
func (c *CircuitBreaker) successLocked(state State, t time.Time) {
switch state {
case StateStandby:
c.cfg.OnExecute(true, StateStandby)
c.metrics.success()
case StateRecovering:
c.cfg.OnExecute(true, StateRecovering)
c.metrics.success()
if c.metrics.ConsecutiveSuccesses >= c.cfg.RecoveryLimit {
c.setState(StateStandby, t)
Expand All @@ -382,17 +401,19 @@ func (c *CircuitBreaker) success(state State, t time.Time) {
}
}

// failure tallies a failed execution and migrate to StateTripped
// failureLocked tallies a failed execution and migrates to StateTripped
// if in another state and criteria has been met to transition
func (c *CircuitBreaker) failure(state State, t time.Time) {
func (c *CircuitBreaker) failureLocked(state State, t time.Time) {
c.metrics.failure()

switch state {
case StateRecovering:
c.cfg.OnExecute(false, StateRecovering)
if c.cfg.Recover(c.metrics) {
c.setState(StateTripped, t)
}
case StateStandby:
c.cfg.OnExecute(false, StateStandby)
if c.cfg.Trip(c.metrics) {
c.setState(StateTripped, t)
go c.cfg.OnTripped()
Expand Down
4 changes: 2 additions & 2 deletions api/breaker/breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func TestCircuitBreaker_success(t *testing.T) {
cb.state = tt.initialState

generation, state := cb.currentState(clock.Now())
cb.success(tt.successState, clock.Now())
cb.successLocked(tt.successState, clock.Now())
require.Equal(t, tt.expectedState, cb.state)
if tt.expectedState != state {
require.NotEqual(t, generation, cb.generation)
Expand Down Expand Up @@ -341,7 +341,7 @@ func TestCircuitBreaker_failure(t *testing.T) {
cb.state = tt.initialState

generation, state := cb.currentState(clock.Now())
cb.failure(tt.failureState, clock.Now())
cb.failureLocked(tt.failureState, clock.Now())
require.Equal(t, tt.expectedState, cb.state)
if tt.expectedState != state {
require.NotEqual(t, generation, cb.generation)
Expand Down
24 changes: 21 additions & 3 deletions lib/auth/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"

"github.com/gravitational/teleport/api/client/proto"
Expand Down Expand Up @@ -103,14 +104,31 @@ type joinAttributeSourcer interface {
//
// If the token includes a specific join method, the rules for that join method
// will be checked.
func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
log.Infof("Node %q [%v] is trying to join with role: %v.", req.NodeName, req.HostID, req.Role)
func (a *Server) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (_ *proto.Certs, err error) {
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}

method := a.tokenJoinMethod(ctx, req.Token)
defer func() {
if err == nil {
return
}
level := logrus.WarnLevel
if trace.IsAccessDenied(err) {
level = logrus.DebugLevel
}
log.WithFields(logrus.Fields{
"node_name": req.NodeName,
"host_id": req.HostID,
"role": req.Role,
"method": method,
logrus.ErrorKey: err,
}).Log(level, "Agent has failed to join the cluster.")
}()

var joinAttributeSrc joinAttributeSourcer
switch method := a.tokenJoinMethod(ctx, req.Token); method {
switch method {
case types.JoinMethodEC2:
if err := a.checkEC2JoinRequest(ctx, req); err != nil {
return nil, trace.Wrap(err)
Expand Down
22 changes: 21 additions & 1 deletion lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/exp/slices"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -338,7 +339,7 @@ func withFips(fips bool) iamRegisterOption {
// The caller must provide a ChallengeResponseFunc which returns a
// *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request
// including the challenge as a signed header.
func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc, opts ...iamRegisterOption) (*proto.Certs, error) {
func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc, opts ...iamRegisterOption) (_ *proto.Certs, err error) {
cfg := defaultIAMRegisterConfig(a.fips)
for _, opt := range opts {
opt(cfg)
Expand All @@ -365,11 +366,30 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c
return nil, trace.Wrap(err)
}

var method types.JoinMethod = "unknown"
defer func() {
if err == nil {
return
}
level := logrus.WarnLevel
if trace.IsAccessDenied(err) {
level = logrus.DebugLevel
}
log.WithFields(logrus.Fields{
"node_name": req.RegisterUsingTokenRequest.NodeName,
"host_id": req.RegisterUsingTokenRequest.HostID,
"role": req.RegisterUsingTokenRequest.Role,
"method": method,
logrus.ErrorKey: err,
}).Log(level, "Agent has failed to join the cluster.")
}()

// perform common token checks
provisionToken, err := a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest)
if err != nil {
return nil, trace.Wrap(err)
}
method = provisionToken.GetJoinMethod()

// check that the GetCallerIdentity request is valid and matches the token
if err := a.checkIAMRequest(ctx, challenge, req, cfg); err != nil {
Expand Down
32 changes: 14 additions & 18 deletions lib/joinserver/joinserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,37 +65,35 @@ func NewJoinServiceGRPCServer(joinServiceClient joinServiceClient) *JoinServiceG
// sts:GetCallerIdentity request with the challenge string. Finally, the signed
// cluster certs are sent on the server stream.
func (s *JoinServiceGRPCServer) RegisterUsingIAMMethod(srv proto.JoinService_RegisterUsingIAMMethodServer) error {
ctx := srv.Context()

// Enforce a timeout on the entire RPC so that misbehaving clients cannot
// hold connections open indefinitely.
timeout := s.clock.After(iamJoinRequestTimeout)
timeout := s.clock.NewTimer(iamJoinRequestTimeout)
defer timeout.Stop()

// The only way to cancel a blocked Send or Recv on the server side without
// adding an interceptor to the entire gRPC service is to return from the
// handler https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474
errCh := make(chan error, 1)
go func() {
errCh <- s.registerUsingIAMMethod(ctx, srv)
errCh <- s.registerUsingIAMMethod(srv)
}()
select {
case err := <-errCh:
// Completed before the deadline, return the error (may be nil).
return trace.Wrap(err)
case <-timeout:
case <-timeout.Chan():
nodeAddr := ""
if peerInfo, ok := peer.FromContext(ctx); ok {
if peerInfo, ok := peer.FromContext(srv.Context()); ok {
nodeAddr = peerInfo.Addr.String()
}
logrus.Warnf("IAM join attempt timed out, node at (%s) is misbehaving or did not close the connection after encountering an error.", nodeAddr)
// Returning here should cancel any blocked Send or Recv operations.
return trace.LimitExceeded("RegisterUsingIAMMethod timed out after %s, terminating the stream on the server", iamJoinRequestTimeout)
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}

func (s *JoinServiceGRPCServer) registerUsingIAMMethod(ctx context.Context, srv proto.JoinService_RegisterUsingIAMMethodServer) error {
func (s *JoinServiceGRPCServer) registerUsingIAMMethod(srv proto.JoinService_RegisterUsingIAMMethodServer) error {
ctx := srv.Context()
// Call RegisterUsingIAMMethod with a callback to get the challenge response
// from the gRPC client.
certs, err := s.joinServiceClient.RegisterUsingIAMMethod(ctx, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
Expand Down Expand Up @@ -130,37 +128,35 @@ func (s *JoinServiceGRPCServer) registerUsingIAMMethod(ctx context.Context, srv
// attested data document with the challenge string. Finally, the signed
// cluster certs are sent on the server stream.
func (s *JoinServiceGRPCServer) RegisterUsingAzureMethod(srv proto.JoinService_RegisterUsingAzureMethodServer) error {
ctx := srv.Context()

// Enforce a timeout on the entire RPC so that misbehaving clients cannot
// hold connections open indefinitely.
timeout := s.clock.After(azureJoinRequestTimeout)
timeout := s.clock.NewTimer(azureJoinRequestTimeout)
defer timeout.Stop()

// The only way to cancel a blocked Send or Recv on the server side without
// adding an interceptor to the entire gRPC service is to return from the
// handler https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474
errCh := make(chan error, 1)
go func() {
errCh <- s.registerUsingAzureMethod(ctx, srv)
errCh <- s.registerUsingAzureMethod(srv)
}()
select {
case err := <-errCh:
// Completed before the deadline, return the error (may be nil).
return trace.Wrap(err)
case <-timeout:
case <-timeout.Chan():
nodeAddr := ""
if peerInfo, ok := peer.FromContext(ctx); ok {
if peerInfo, ok := peer.FromContext(srv.Context()); ok {
nodeAddr = peerInfo.Addr.String()
}
logrus.Warnf("Azure join attempt timed out, node at (%s) is misbehaving or did not close the connection after encountering an error.", nodeAddr)
// Returning here should cancel any blocked Send or Recv operations.
return trace.LimitExceeded("RegisterUsingAzureMethod timed out after %s, terminating the stream on the server", azureJoinRequestTimeout)
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}

func (s *JoinServiceGRPCServer) registerUsingAzureMethod(ctx context.Context, srv proto.JoinService_RegisterUsingAzureMethodServer) error {
func (s *JoinServiceGRPCServer) registerUsingAzureMethod(srv proto.JoinService_RegisterUsingAzureMethodServer) error {
ctx := srv.Context()
certs, err := s.joinServiceClient.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
err := srv.Send(&proto.RegisterUsingAzureMethodResponse{
Challenge: challenge,
Expand Down
56 changes: 56 additions & 0 deletions lib/service/breaker/breaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package breaker

import (
"strconv"
"sync"

"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/types"
)

// connectorExecutions is a counter vector for client executions that pass
// through an instrumented breaker. Each sample is labeled with the system
// role, the breaker state, and whether the breaker considered the execution
// successful.
var connectorExecutions = prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: teleport.MetricNamespace,
Subsystem: "breaker",
Name: "connector_executions_total",
Help: "Client requests per system role, state of the breaker and success as interpreted by the breaker.",
}, []string{"role", "state", "success"})

// registerOnce guards the one-time registration of connectorExecutions.
var registerOnce sync.Once

// ensureRegistered registers the connectorExecutions collector the first
// time it is called; subsequent calls do nothing.
func ensureRegistered() {
	registerOnce.Do(registerConnectorExecutions)
}

// registerConnectorExecutions performs the actual registration of
// connectorExecutions through prometheus.MustRegister. It is only meant to
// be invoked via registerOnce.
func registerConnectorExecutions() {
	prometheus.MustRegister(connectorExecutions)
}

// InstrumentBreakerForConnector takes a [breaker.Config] and hands back a
// clone of it whose OnExecute callback increments the connectorExecutions
// counter for every client "execution" (i.e. request or stream) that goes
// through the breaker. Counts are attributed to the given system role, the
// breaker state reported to the callback, and the success flag.
//
// The OnExecute callback of the returned Config is set unconditionally, so
// any OnExecute present on the supplied cfg is replaced in the returned copy.
func InstrumentBreakerForConnector(role types.SystemRole, cfg breaker.Config) breaker.Config {
	// Make sure the counter is registered before it can be incremented.
	ensureRegistered()

	instrumented := cfg.Clone()
	instrumented.OnExecute = func(success bool, state breaker.State) {
		roleLabel := role.String()
		stateLabel := state.String()
		successLabel := strconv.FormatBool(success)
		connectorExecutions.WithLabelValues(roleLabel, stateLabel, successLabel).Inc()
	}
	return instrumented
}
Loading