Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request Limiter reloadable config #25095

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog/25095.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
limits: Introduce a reloadable disable configuration for the Request Limiter.
```
12 changes: 12 additions & 0 deletions command/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,12 @@ func (c *ServerCommand) Run(args []string) int {
infoKeys = append(infoKeys, "administrative namespace")
info["administrative namespace"] = config.AdministrativeNamespacePath

infoKeys = append(infoKeys, "request limiter")
info["request limiter"] = "enabled"
if config.RequestLimiter != nil && config.RequestLimiter.Disable {
info["request limiter"] = "disabled"
}

sort.Strings(infoKeys)
c.UI.Output("==> Vault server configuration:\n")

Expand Down Expand Up @@ -1661,6 +1667,8 @@ func (c *ServerCommand) Run(args []string) int {
// Setting log request with the new value in the config after reload
core.ReloadLogRequestsLevel()

core.ReloadRequestLimiter()

// reloading HCP link
hcpLink, err = c.reloadHCPLink(hcpLink, config, core, hcpLogger)
if err != nil {
Expand Down Expand Up @@ -3095,6 +3103,10 @@ func createCoreConfig(c *ServerCommand, config *server.Config, backend physical.
AdministrativeNamespacePath: config.AdministrativeNamespacePath,
}

if config.RequestLimiter != nil {
coreConfig.DisableRequestLimiter = config.RequestLimiter.Disable
}

if c.flagDev {
coreConfig.EnableRaw = true
coreConfig.EnableIntrospection = true
Expand Down
53 changes: 53 additions & 0 deletions command/server/config_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package server

import (
"fmt"
"testing"

"github.com/hashicorp/vault/internalshared/configutil"
Expand Down Expand Up @@ -86,3 +87,55 @@ func TestCheckSealConfig(t *testing.T) {
})
}
}

// TestRequestLimiterConfig verifies that the census config is correctly instantiated from HCL
func TestRequestLimiterConfig(t *testing.T) {
testCases := []struct {
name string
inConfig string
outErr bool
outRequestLimiter *configutil.RequestLimiter
}{
{
name: "empty",
outRequestLimiter: nil,
},
{
name: "disabled",
inConfig: `
request_limiter {
disable = true
}`,
outRequestLimiter: &configutil.RequestLimiter{Disable: true},
},
{
name: "invalid disable",
inConfig: `
request_limiter {
disable = "whywouldyoudothis"
}`,
outErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
config := fmt.Sprintf(`
ui = false
storage "file" {
path = "/tmp/test"
}

listener "tcp" {
address = "0.0.0.0:8200"
}
%s`, tc.inConfig)
gotConfig, err := ParseConfig(config, "")
if tc.outErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.outRequestLimiter, gotConfig.RequestLimiter)
}
})
}
}
16 changes: 16 additions & 0 deletions internalshared/configutil/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ type SharedConfig struct {
ClusterName string `hcl:"cluster_name"`

AdministrativeNamespacePath string `hcl:"administrative_namespace_path"`

RequestLimiter *RequestLimiter `hcl:"request_limiter"`
}

func ParseConfig(d string) (*SharedConfig, error) {
Expand Down Expand Up @@ -156,6 +158,13 @@ func ParseConfig(d string) (*SharedConfig, error) {
}
}

if o := list.Filter("request_limiter"); len(o.Items) > 0 {
result.found("request_limiter", "RequestLimiter")
if err := parseRequestLimiter(&result, o); err != nil {
return nil, fmt.Errorf("error parsing 'request_limiter': %w", err)
}
}

entConfig := &(result.EntSharedConfig)
if err := entConfig.ParseConfig(list); err != nil {
return nil, fmt.Errorf("error parsing enterprise config: %w", err)
Expand Down Expand Up @@ -284,6 +293,13 @@ func (c *SharedConfig) Sanitized() map[string]interface{} {
result["telemetry"] = sanitizedTelemetry
}

if c.RequestLimiter != nil {
sanitizedRequestLimiter := map[string]interface{}{
"disable": c.RequestLimiter.Disable,
}
result["request_limiter"] = sanitizedRequestLimiter
}

return result
}

Expand Down
5 changes: 5 additions & 0 deletions internalshared/configutil/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,10 @@ func (c *SharedConfig) Merge(c2 *SharedConfig) *SharedConfig {
result.ClusterName = c2.ClusterName
}

result.RequestLimiter = c.RequestLimiter
if c2.RequestLimiter != nil {
result.RequestLimiter = c2.RequestLimiter
}

return result
}
59 changes: 59 additions & 0 deletions internalshared/configutil/request_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package configutil

import (
"fmt"

"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast"
)

type RequestLimiter struct {
UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"`

Disable bool `hcl:"-"`
DisableRaw interface{} `hcl:"disable"`
}

func (r *RequestLimiter) Validate(source string) []ConfigError {
return ValidateUnusedFields(r.UnusedKeys, source)
}

func (r *RequestLimiter) GoString() string {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe for tests? Truthfully not sure, but other configs used this so I followed suit.

return fmt.Sprintf("*%#v", *r)
}

var DefaultRequestLimiter = &RequestLimiter{
Disable: false,
}

func parseRequestLimiter(result *SharedConfig, list *ast.ObjectList) error {
if len(list.Items) > 1 {
return fmt.Errorf("only one 'request_limiter' block is permitted")
}

result.RequestLimiter = DefaultRequestLimiter

// Get our one item
item := list.Items[0]

if err := hcl.DecodeObject(&result.RequestLimiter, item.Val); err != nil {
return multierror.Prefix(err, "request_limiter:")
}

if result.RequestLimiter.DisableRaw != nil {
var err error
if result.RequestLimiter.Disable, err = parseutil.ParseBool(result.RequestLimiter.DisableRaw); err != nil {
return err
}
result.RequestLimiter.DisableRaw = nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a common pattern, setting it to nil after parsing it? I haven't seen that before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, for some reason we clear the raw entry after parsing. I'm not sure I understand it, but copied it from other implementations.

} else {
result.RequestLimiter.Disable = false
}

return nil
}
17 changes: 17 additions & 0 deletions limits/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,23 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) {
r.Limiters[flags.Name] = limiter
}

// Disable drops its references to underlying limiters.
func (r *LimiterRegistry) Disable() {
r.Lock()

if !r.Enabled {
return
}

r.Logger.Info("disabling request limiters")
// Any outstanding tokens will be flushed when their request completes, as
// they've already acquired a listener. Just drop the limiter references
// here and the garbage-collector should take care of the rest.
r.Limiters = map[string]*RequestLimiter{}
r.Enabled = false
r.Unlock()
}

// GetLimiter looks up a RequestLimiter by key in the LimiterRegistry.
func (r *LimiterRegistry) GetLimiter(key string) *RequestLimiter {
r.RLock()
Expand Down
33 changes: 32 additions & 1 deletion vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,8 @@ type CoreConfig struct {

ClusterAddrBridge *raft.ClusterAddrBridge

LimiterRegistry *limits.LimiterRegistry
DisableRequestLimiter bool
LimiterRegistry *limits.LimiterRegistry
}

// GetServiceRegistration returns the config's ServiceRegistration, or nil if it does
Expand Down Expand Up @@ -1293,6 +1294,15 @@ func NewCore(conf *CoreConfig) (*Core, error) {
return nil, err
}

c.limiterRegistry = conf.LimiterRegistry
c.limiterRegistryLock.Lock()
if conf.DisableRequestLimiter {
c.limiterRegistry.Disable()
} else {
c.limiterRegistry.Enable()
}
c.limiterRegistryLock.Unlock()

err = c.adjustForSealMigration(conf.UnwrapSeal)
if err != nil {
return nil, err
Expand Down Expand Up @@ -4056,6 +4066,27 @@ func (c *Core) ReloadLogRequestsLevel() {
}
}

func (c *Core) ReloadRequestLimiter() {
c.limiterRegistry.Logger.Info("reloading request limiter config")
conf := c.rawConfig.Load()
if conf == nil {
return
}

disable := false
requestLimiterConfig := conf.(*server.Config).RequestLimiter
if requestLimiterConfig != nil {
disable = requestLimiterConfig.Disable
}

switch disable {
case true:
c.limiterRegistry.Disable()
default:
c.limiterRegistry.Enable()
}
}

func (c *Core) ReloadIntrospectionEndpointEnabled() {
conf := c.rawConfig.Load()
if conf == nil {
Expand Down
Loading