Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 65 additions & 45 deletions agent/consul/rate/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ func NewHandlerWithLimiter(
logger: logger,
}
h.globalCfg.Store(&cfg)
h.ipCfg.Store(&IPLimitConfig{})

return h
}
Expand Down Expand Up @@ -248,45 +249,39 @@ func (h *Handler) Allow(op Operation) error {
return nil
}

for _, l := range h.limits(op) {
if l.mode == ModeDisabled {
continue
}

if h.limiter.Allow(l.ent) {
continue
}

// TODO(NET-1382): is this the correct log-level?

enforced := l.mode == ModeEnforcing
h.logger.Debug("RPC exceeded allowed rate limit",
"rpc", op.Name,
"source_addr", op.SourceAddr,
"limit_type", l.desc,
"limit_enforced", enforced,
)

metrics.IncrCounterWithLabels([]string{"rpc", "rate_limit", "exceeded"}, 1, []metrics.Label{
{
Name: "limit_type",
Value: l.desc,
},
{
Name: "op",
Value: op.Name,
},
{
Name: "mode",
Value: l.mode.String(),
},
})

if enforced {
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
return ErrRetryLater
allow, throttledLimits := h.allowAllLimits(h.limits(op))

if !allow {
for _, l := range throttledLimits {
enforced := l.mode == ModeEnforcing
h.logger.Debug("RPC exceeded allowed rate limit",
"rpc", op.Name,
"source_addr", op.SourceAddr,
"limit_type", l.desc,
"limit_enforced", enforced,
)

metrics.IncrCounterWithLabels([]string{"rpc", "rate_limit", "exceeded"}, 1, []metrics.Label{
{
Name: "limit_type",
Value: l.desc,
},
{
Name: "op",
Value: op.Name,
},
{
Name: "mode",
Value: l.mode.String(),
},
})

if enforced {
if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite {
return ErrRetryLater
}
return ErrRetryElsewhere
}
return ErrRetryElsewhere
}
}
return nil
Expand Down Expand Up @@ -320,6 +315,23 @@ type limit struct {
desc string
}

func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) {
allow := true
throttledLimits := make([]limit, 0)

for _, l := range limits {
if l.mode == ModeDisabled {
continue
}

if !h.limiter.Allow(l.ent) {
throttledLimits = append(throttledLimits, l)
allow = false
}
}
return allow, throttledLimits
}

// limits returns the limits to check for the given operation (e.g. global +
// ip-based + tenant-based).
func (h *Handler) limits(op Operation) []limit {
Expand All @@ -329,6 +341,14 @@ func (h *Handler) limits(op Operation) []limit {
limits = append(limits, *global)
}

if ipGlobal := h.ipGlobalLimit(op); ipGlobal != nil {
limits = append(limits, *ipGlobal)
}

if ipCategory := h.ipCategoryLimit(op); ipCategory != nil {
limits = append(limits, *ipCategory)
}

return limits
}

Expand All @@ -354,23 +374,23 @@ func (h *Handler) globalLimit(op Operation) *limit {

var (
// globalWrite identifies the global rate limit applied to write operations.
globalWrite = globalLimit("global.write")
globalWrite = limitedEntity("global.write")

// globalRead identifies the global rate limit applied to read operations.
globalRead = globalLimit("global.read")
globalRead = limitedEntity("global.read")

// globalIPRead identifies the global rate limit applied to read operations.
globalIPRead = globalLimit("global.ip.read")
globalIPRead = limitedEntity("global.ip.read")

// globalIPWrite identifies the global rate limit applied to read operations.
globalIPWrite = globalLimit("global.ip.write")
globalIPWrite = limitedEntity("global.ip.write")
)

// globalLimit represents a limit that applies to all writes or reads.
type globalLimit []byte
// limitedEntity convert the string type to Multilimiter.LimitedEntity
type limitedEntity []byte

// Key satisfies the multilimiter.LimitedEntity interface.
func (prefix globalLimit) Key() multilimiter.KeyType {
func (prefix limitedEntity) Key() multilimiter.KeyType {
return multilimiter.Key(prefix, nil)
}

Expand Down
11 changes: 9 additions & 2 deletions agent/consul/rate/handler_oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@

package rate

type IPLimitConfig struct {
}
type IPLimitConfig struct{}

func (h *Handler) UpdateIPConfig(cfg IPLimitConfig) {
// noop
}

func (h *Handler) ipGlobalLimit(op Operation) *limit {
return nil
}

func (h *Handler) ipCategoryLimit(op Operation) *limit {
return nil
}
48 changes: 11 additions & 37 deletions agent/consul/rate/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@ package rate

import (
"bytes"
"context"
"github.com/hashicorp/consul/agent/metrics"
"github.com/stretchr/testify/require"
"net"
"net/netip"
"testing"

"golang.org/x/time/rate"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/mock"

"github.com/hashicorp/consul/agent/consul/multilimiter"
"github.com/hashicorp/consul/agent/metrics"
)

//
Expand Down Expand Up @@ -226,10 +224,10 @@ func TestHandler(t *testing.T) {
for desc, tc := range testCases {
t.Run(desc, func(t *testing.T) {
sink := metrics.TestSetupMetrics(t, "")
limiter := newMockLimiter(t)
limiter := multilimiter.NewMockRateLimiter(t)
limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
for _, c := range tc.checks {
limiter.On("Allow", c.limit).Return(c.allow)
limiter.On("Allow", mock.Anything).Return(c.allow)
}

leaderStatusProvider := NewMockLeaderStatusProvider(t)
Expand Down Expand Up @@ -376,7 +374,7 @@ func TestAllow(t *testing.T) {
type testCase struct {
description string
cfg *HandlerConfig
expectedAllowCalls int
expectedAllowCalls bool
}
testCases := []testCase{
{
Expand All @@ -390,7 +388,7 @@ func TestAllow(t *testing.T) {
},
},
},
expectedAllowCalls: 0,
expectedAllowCalls: false,
},
{
description: "RateLimiter gets called when mode is permissive.",
Expand All @@ -403,7 +401,7 @@ func TestAllow(t *testing.T) {
},
},
},
expectedAllowCalls: 1,
expectedAllowCalls: true,
},
{
description: "RateLimiter gets called when mode is enforcing.",
Expand All @@ -416,14 +414,14 @@ func TestAllow(t *testing.T) {
},
},
},
expectedAllowCalls: 1,
expectedAllowCalls: true,
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
mockRateLimiter := multilimiter.NewMockRateLimiter(t)
if tc.expectedAllowCalls > 0 {
if tc.expectedAllowCalls {
mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true })
}
mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return()
Expand All @@ -435,31 +433,7 @@ func TestAllow(t *testing.T) {
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234"))
mockRateLimiter.Calls = nil
handler.Allow(Operation{Name: "test", SourceAddr: addr})
mockRateLimiter.AssertNumberOfCalls(t, "Allow", tc.expectedAllowCalls)
mockRateLimiter.AssertExpectations(t)
})
}
}

var _ multilimiter.RateLimiter = (*mockLimiter)(nil)

func newMockLimiter(t *testing.T) *mockLimiter {
l := &mockLimiter{}
l.Mock.Test(t)

t.Cleanup(func() { l.AssertExpectations(t) })

return l
}

type mockLimiter struct {
mock.Mock
}

func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) }
func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) }
func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) {
m.Called(cfg, prefix)
}
func (m *mockLimiter) DeleteConfig(prefix []byte) {
m.Called(prefix)
}