diff --git a/apisec/sampler.go b/apisec/sampler.go index 97147e7..76131f2 100644 --- a/apisec/sampler.go +++ b/apisec/sampler.go @@ -11,6 +11,7 @@ import ( "time" "github.com/DataDog/appsec-internal-go/apisec/internal/timed" + "github.com/DataDog/appsec-internal-go/limiter" ) type ( @@ -20,6 +21,12 @@ type ( timedSetSampler timed.LRU + proxySampler struct { + limiter limiter.Limiter + } + + nullSampler struct{} + SamplingKey struct { // Method is the value of the http.method span tag Method string @@ -32,6 +39,20 @@ type ( clockFunc = func() int64 ) +// NewProxySampler creates a new sampler suitable for proxy environments where the sampling decision +// is not based on the request's properties, but on a rate. +func NewProxySampler(rate int, interval time.Duration) Sampler { + if rate <= 0 { + return &nullSampler{} + } + r := int64(rate) + l := limiter.NewTokenTickerWithInterval(r, r, interval) + l.Start() + return &proxySampler{ + limiter: l, + } +} + // NewSamplerWithInterval returns a new [*Sampler] with the specified interval. func NewSamplerWithInterval(interval time.Duration) Sampler { return newSampler(interval, timed.UnixTime) @@ -53,6 +74,14 @@ func (s *timedSetSampler) DecisionFor(key SamplingKey) bool { return (*timed.LRU)(s).Hit(keyHash) } +func (s *proxySampler) DecisionFor(_ SamplingKey) bool { + return s.limiter.Allow() +} + +func (s *nullSampler) DecisionFor(_ SamplingKey) bool { + return false +} + // hash returns a hash of the key. Given the same seed, it always produces the // same output. If the seed changes, the output is likely to change as well. func (k SamplingKey) hash() uint64 { diff --git a/appsec/config.go b/appsec/config.go index a50e4e1..09cbbd4 100644 --- a/appsec/config.go +++ b/appsec/config.go @@ -24,6 +24,8 @@ const ( // EnvAPISecSampleRate is the env var used to set the sampling rate of API Security schema extraction. // Deprecated: a new [APISecConfig.Sampler] is now used instead of this. EnvAPISecSampleRate = "DD_API_SECURITY_REQUEST_SAMPLE_RATE" + // EnvAPISecProxySampleRate is the env var used to set the sampling rate of API Security schema extraction for proxies. + EnvAPISecProxySampleRate = "DD_API_SECURITY_PROXY_SAMPLE_RATE" // EnvObfuscatorKey is the env var used to provide the WAF key obfuscation regexp EnvObfuscatorKey = "DD_APPSEC_OBFUSCATION_PARAMETER_KEY_REGEXP" // EnvObfuscatorValue is the env var used to provide the WAF value obfuscation regexp @@ -48,6 +50,10 @@ const ( DefaultAPISecSampleRate = .1 // DefaultAPISecSampleInterval is the default interval between two samples being taken. DefaultAPISecSampleInterval = 30 * time.Second + // DefaultAPISecProxySampleRate is the default rate (schemas per minute) at which API Security schemas are extracted from requests + DefaultAPISecProxySampleRate = 300 + // DefaultAPISecProxySampleInterval is the default time window for the API Security proxy sampler rate limiter. + DefaultAPISecProxySampleInterval = time.Minute // DefaultObfuscatorKeyRegex is the default regexp used to obfuscate keys DefaultObfuscatorKeyRegex = `(?i)pass|pw(?:or)?d|secret|(?:api|private|public|access)[_-]?key|token|consumer[_-]?(?:id|key|secret)|sign(?:ed|ature)|bearer|authorization|jsessionid|phpsessid|asp\.net[_-]sessionid|sid|jwt` // DefaultObfuscatorValueRegex is the default regexp used to obfuscate values @@ -63,6 +69,7 @@ const ( type APISecConfig struct { Sampler apisec.Sampler Enabled bool + IsProxy bool // Deprecated: use the new [APISecConfig.Sampler] instead. SampleRate float64 } @@ -79,12 +86,23 @@ type APISecOption func(*APISecConfig) func NewAPISecConfig(opts ...APISecOption) APISecConfig { cfg := APISecConfig{ Enabled: boolEnv(EnvAPISecEnabled, true), - Sampler: apisec.NewSamplerWithInterval(durationEnv(envAPISecSampleDelay, "s", DefaultAPISecSampleInterval)), SampleRate: readAPISecuritySampleRate(), } for _, opt := range opts { opt(&cfg) } + + if cfg.Sampler != nil { + return cfg + } + + if cfg.IsProxy { + rate := intEnv(EnvAPISecProxySampleRate, DefaultAPISecProxySampleRate) + cfg.Sampler = apisec.NewProxySampler(rate, DefaultAPISecProxySampleInterval) + } else { + cfg.Sampler = apisec.NewSamplerWithInterval(durationEnv(envAPISecSampleDelay, "s", DefaultAPISecSampleInterval)) + } + return cfg } @@ -116,6 +134,13 @@ func WithAPISecSampler(sampler apisec.Sampler) APISecOption { } } +// WithProxy configures API Security for a proxy environment. +func WithProxy() APISecOption { + return func(c *APISecConfig) { + c.IsProxy = true + } +} + // RASPEnabled returns true if RASP functionalities are enabled through the env, or if DD_APPSEC_RASP_ENABLED // is not set func RASPEnabled() bool { @@ -243,3 +268,16 @@ func durationEnv(key string, unit string, def time.Duration) time.Duration { } return val } + +func intEnv(key string, def int) int { + strVal, ok := os.LookupEnv(key) + if !ok { + return def + } + val, err := strconv.Atoi(strVal) + if err != nil { + logEnvVarParsingError(key, strVal, err, def) + return def + } + return val +} diff --git a/limiter/limiter.go b/limiter/limiter.go index f1f16d3..fbe4f3e 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -28,14 +28,21 @@ type Limiter interface { type TokenTicker struct { tokens atomic.Int64 // The amount of tokens currently available maxTokens int64 // The maximum amount of tokens the bucket can hold + interval time.Duration // The interval at which the tokens are refilled ticker *time.Ticker // The ticker used to update the bucket (nil if not started yet) stopChan chan struct{} // The channel to stop the ticker updater (nil if not started yet) } // NewTokenTicker is a utility function that allocates a token ticker, initializes necessary fields and returns it func NewTokenTicker(tokens, maxTokens int64) *TokenTicker { + return NewTokenTickerWithInterval(tokens, maxTokens, time.Second) +} + +// NewTokenTickerWithInterval is a utility function that allocates a token ticker with a custom interval +func NewTokenTickerWithInterval(tokens, maxTokens int64, interval time.Duration) *TokenTicker { t := &TokenTicker{ maxTokens: maxTokens, + interval: interval, } t.tokens.Store(tokens) return t @@ -44,7 +51,7 @@ func NewTokenTicker(tokens, maxTokens int64) *TokenTicker { // updateBucket performs a select loop to update the token amount in the bucket. // Used in a goroutine by the rate limiter. func (t *TokenTicker) updateBucket(startTime time.Time, ticksChan <-chan time.Time, stopChan <-chan struct{}, syncChan chan<- struct{}) { - nsPerToken := time.Second.Nanoseconds() / t.maxTokens + nsPerToken := t.interval.Nanoseconds() / t.maxTokens elapsedNs := int64(0) prevStamp := startTime diff --git a/limiter/limiter_test.go b/limiter/limiter_test.go index 6eaa3b3..7d0467c 100644 --- a/limiter/limiter_test.go +++ b/limiter/limiter_test.go @@ -349,3 +349,47 @@ func (t *TestTicker) tick(delta time.Duration) { func (t *TestTicker) Allow() bool { return t.t.Allow() } + +func newTestTickerWithInterval(tokens, maxTokens int64, interval time.Duration) *TestTicker { + return &TestTicker{ + C: make(chan time.Time), + t: NewTokenTickerWithInterval(tokens, maxTokens, interval), + } +} + +func TestLimiterWithInterval(t *testing.T) { + startTime := time.Now() + t.Run("60-per-minute-rate", func(t *testing.T) { + defer goleak.VerifyNone(t) + + l := newTestTickerWithInterval(1, 60, time.Minute) // 60 tokens per minute, so 1 per second + l.start(startTime) + defer l.stop() + require.True(t, l.Allow(), "First call should be allowed") + require.False(t, l.Allow(), "Second call should be disallowed") + + l.tick(500 * time.Millisecond) + require.False(t, l.Allow(), "A call after 0.5s should be disallowed") + + l.tick(500 * time.Millisecond) // Total 1 second passed + require.True(t, l.Allow(), "A call after 1s should be allowed") + require.False(t, l.Allow(), "Another call should be disallowed") + }) + + t.Run("1-per-100ms-rate", func(t *testing.T) { + defer goleak.VerifyNone(t) + + l := newTestTickerWithInterval(1, 1, 100*time.Millisecond) // 1 token per 100ms + l.start(startTime) + defer l.stop() + require.True(t, l.Allow(), "First call should be allowed") + require.False(t, l.Allow(), "Second call should be disallowed") + + l.tick(50 * time.Millisecond) + require.False(t, l.Allow(), "A call after 50ms should be disallowed") + + l.tick(50 * time.Millisecond) // Total 100ms passed + require.True(t, l.Allow(), "A call after 100ms should be allowed") + require.False(t, l.Allow(), "Another call should be disallowed") + }) +}