diff --git a/statsd/aggregator.go b/statsd/aggregator.go index 139e0684..ed18f8f5 100644 --- a/statsd/aggregator.go +++ b/statsd/aggregator.go @@ -14,18 +14,31 @@ type ( bufferedMetricMap map[string]*bufferedMetric ) +type countShard struct { + sync.RWMutex + counts countsMap +} + +type gaugeShard struct { + sync.RWMutex + gauges gaugesMap +} + +type setShard struct { + sync.RWMutex + sets setsMap +} + type aggregator struct { nbContextGauge uint64 nbContextCount uint64 nbContextSet uint64 - countsM sync.RWMutex - gaugesM sync.RWMutex - setsM sync.RWMutex + shardsCount int + countShards []*countShard + gaugeShards []*gaugeShard + setShards []*setShard - gauges gaugesMap - counts countsMap - sets setsMap histograms bufferedMetricContexts distributions bufferedMetricContexts timings bufferedMetricContexts @@ -43,18 +56,25 @@ type aggregator struct { wg sync.WaitGroup } -func newAggregator(c *ClientEx, maxSamplesPerContext int64) *aggregator { - return &aggregator{ +func newAggregator(c *ClientEx, maxSamplesPerContext int64, shardsCount int) *aggregator { + agg := &aggregator{ client: c, - counts: countsMap{}, - gauges: gaugesMap{}, - sets: setsMap{}, + shardsCount: shardsCount, + countShards: make([]*countShard, shardsCount), + gaugeShards: make([]*gaugeShard, shardsCount), + setShards: make([]*setShard, shardsCount), histograms: newBufferedContexts(newHistogramMetric, maxSamplesPerContext), distributions: newBufferedContexts(newDistributionMetric, maxSamplesPerContext), timings: newBufferedContexts(newTimingMetric, maxSamplesPerContext), closed: make(chan struct{}), stopChannelMode: make(chan struct{}), } + for i := 0; i < shardsCount; i++ { + agg.countShards[i] = &countShard{counts: countsMap{}} + agg.gaugeShards[i] = &gaugeShard{gauges: gaugesMap{}} + agg.setShards[i] = &setShard{sets: setsMap{}} + } + return agg } func (a *aggregator) start(flushInterval time.Duration) { @@ -135,40 +155,43 @@ func (a *aggregator) flushMetrics() []metric { // We reset the values 
to avoid sending 'zero' values for metrics not // sampled during this flush interval - a.setsM.Lock() - sets := a.sets - a.sets = setsMap{} - a.setsM.Unlock() - - for _, s := range sets { - metrics = append(metrics, s.flushUnsafe()...) + for _, shard := range a.setShards { + shard.Lock() + sets := shard.sets + shard.sets = setsMap{} + shard.Unlock() + for _, s := range sets { + metrics = append(metrics, s.flushUnsafe()...) + } + atomic.AddUint64(&a.nbContextSet, uint64(len(sets))) } - a.gaugesM.Lock() - gauges := a.gauges - a.gauges = gaugesMap{} - a.gaugesM.Unlock() - - for _, g := range gauges { - metrics = append(metrics, g.flushUnsafe()) + for _, shard := range a.gaugeShards { + shard.Lock() + gauges := shard.gauges + shard.gauges = gaugesMap{} + shard.Unlock() + for _, g := range gauges { + metrics = append(metrics, g.flushUnsafe()) + } + atomic.AddUint64(&a.nbContextGauge, uint64(len(gauges))) } - a.countsM.Lock() - counts := a.counts - a.counts = countsMap{} - a.countsM.Unlock() - - for _, c := range counts { - metrics = append(metrics, c.flushUnsafe()) + for _, shard := range a.countShards { + shard.Lock() + counts := shard.counts + shard.counts = countsMap{} + shard.Unlock() + for _, c := range counts { + metrics = append(metrics, c.flushUnsafe()) + } + atomic.AddUint64(&a.nbContextCount, uint64(len(counts))) } metrics = a.histograms.flush(metrics) metrics = a.distributions.flush(metrics) metrics = a.timings.flush(metrics) - atomic.AddUint64(&a.nbContextCount, uint64(len(counts))) - atomic.AddUint64(&a.nbContextGauge, uint64(len(gauges))) - atomic.AddUint64(&a.nbContextSet, uint64(len(sets))) return metrics } @@ -223,76 +246,86 @@ func getContextAndTags(name string, tags []string, cardinality Cardinality) (str return s, s[len(name)+len(nameSeparatorSymbol)+cardStringLen:] } +func getShardIndex(shardsCount int, context string) int { + if shardsCount <= 1 { + return 0 + } + return int(hashString32(context) % uint32(shardsCount)) +} + func (a *aggregator) 
count(name string, value int64, tags []string, cardinality Cardinality) error { context := getContext(name, tags, cardinality) - a.countsM.RLock() - if count, found := a.counts[context]; found { + shard := a.countShards[getShardIndex(a.shardsCount, context)] + shard.RLock() + if count, found := shard.counts[context]; found { count.sample(value) - a.countsM.RUnlock() + shard.RUnlock() return nil } - a.countsM.RUnlock() + shard.RUnlock() metric := newCountMetric(name, value, tags, cardinality) - a.countsM.Lock() + shard.Lock() // Check if another goroutines hasn't created the value between the RUnlock and 'Lock' - if count, found := a.counts[context]; found { + if count, found := shard.counts[context]; found { count.sample(value) - a.countsM.Unlock() + shard.Unlock() return nil } - a.counts[context] = metric - a.countsM.Unlock() + shard.counts[context] = metric + shard.Unlock() return nil } func (a *aggregator) gauge(name string, value float64, tags []string, cardinality Cardinality) error { context := getContext(name, tags, cardinality) - a.gaugesM.RLock() - if gauge, found := a.gauges[context]; found { + shard := a.gaugeShards[getShardIndex(a.shardsCount, context)] + shard.RLock() + if gauge, found := shard.gauges[context]; found { gauge.sample(value) - a.gaugesM.RUnlock() + shard.RUnlock() return nil } - a.gaugesM.RUnlock() + shard.RUnlock() gauge := newGaugeMetric(name, value, tags, cardinality) - a.gaugesM.Lock() - // Check if another goroutines hasn't created the value betwen the 'RUnlock' and 'Lock' - if gauge, found := a.gauges[context]; found { + shard.Lock() + // Check if another goroutines hasn't created the value between the 'RUnlock' and 'Lock' + if gauge, found := shard.gauges[context]; found { gauge.sample(value) - a.gaugesM.Unlock() + shard.Unlock() return nil } - a.gauges[context] = gauge - a.gaugesM.Unlock() + shard.gauges[context] = gauge + shard.Unlock() return nil } func (a *aggregator) set(name string, value string, tags []string, cardinality 
Cardinality) error { context := getContext(name, tags, cardinality) - a.setsM.RLock() - if set, found := a.sets[context]; found { + shard := a.setShards[getShardIndex(a.shardsCount, context)] + shard.RLock() + if set, found := shard.sets[context]; found { set.sample(value) - a.setsM.RUnlock() + shard.RUnlock() return nil } - a.setsM.RUnlock() + shard.RUnlock() metric := newSetMetric(name, value, tags, cardinality) - a.setsM.Lock() - // Check if another goroutines hasn't created the value betwen the 'RUnlock' and 'Lock' - if set, found := a.sets[context]; found { + shard.Lock() + // Check if another goroutines hasn't created the value between the 'RUnlock' and 'Lock' + if set, found := shard.sets[context]; found { set.sample(value) - a.setsM.Unlock() + shard.Unlock() return nil } - a.sets[context] = metric - a.setsM.Unlock() + shard.sets[context] = metric + shard.Unlock() return nil } diff --git a/statsd/aggregator_benchmark_test.go b/statsd/aggregator_benchmark_test.go new file mode 100644 index 00000000..269c6ebb --- /dev/null +++ b/statsd/aggregator_benchmark_test.go @@ -0,0 +1,44 @@ +package statsd + +import ( + "fmt" + "testing" +) + +// Prevent compiler from optimizing away function calls +var benchErr error + +func BenchmarkAggregatorSharding(b *testing.B) { + shardCounts := []int{1, 2, 3, 4, 5, 6, 8, 16, 32, 64, 128, 256} + + // Pre-generate metric names to avoid measuring fmt.Sprintf performance + const numMetrics = 100000 + metricNames := make([]string, numMetrics) + for i := 0; i < numMetrics; i++ { + metricNames[i] = fmt.Sprintf("metric.%d", i) + } + + for _, shards := range shardCounts { + b.Run(fmt.Sprintf("Shards_%d", shards), func(b *testing.B) { + a := newAggregator(nil, 0, shards) + tags := []string{"tag:1", "tag:2"} + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var err error + i := 0 + for pb.Next() { + name := metricNames[i%numMetrics] + i++ + err = a.count(name, 1, tags, CardinalityLow) + err = a.gauge(name, 10.0, tags, 
CardinalityLow) + err = a.set(name, "val", tags, CardinalityLow) + } + benchErr = err + }) + }) + if benchErr != nil { + b.Fatal(benchErr) + } + } +} diff --git a/statsd/aggregator_test.go b/statsd/aggregator_test.go index 38a27434..896f3ba2 100644 --- a/statsd/aggregator_test.go +++ b/statsd/aggregator_test.go @@ -11,27 +11,67 @@ import ( "github.com/stretchr/testify/require" ) +func getAllCounts(a *aggregator) countsMap { + counts := countsMap{} + for _, shard := range a.countShards { + shard.RLock() + for k, v := range shard.counts { + counts[k] = v + } + shard.RUnlock() + } + return counts +} + +func getAllGauges(a *aggregator) gaugesMap { + gauges := gaugesMap{} + for _, shard := range a.gaugeShards { + shard.RLock() + for k, v := range shard.gauges { + gauges[k] = v + } + shard.RUnlock() + } + return gauges +} + +func getAllSets(a *aggregator) setsMap { + sets := setsMap{} + for _, shard := range a.setShards { + shard.RLock() + for k, v := range shard.sets { + sets[k] = v + } + shard.RUnlock() + } + return sets +} + func TestAggregatorSample(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) tags := []string{"tag1", "tag2"} for i := 0; i < 2; i++ { a.gauge("gaugeTest", 21, tags, CardinalityNotSet) - assert.Len(t, a.gauges, 1) - assert.Contains(t, a.gauges, "gaugeTest:tag1,tag2") + gauges := getAllGauges(a) + assert.Len(t, gauges, 1) + assert.Contains(t, gauges, "gaugeTest:tag1,tag2") a.count("countTest", 21, tags, CardinalityNotSet) - assert.Len(t, a.counts, 1) - assert.Contains(t, a.counts, "countTest:tag1,tag2") + counts := getAllCounts(a) + assert.Len(t, counts, 1) + assert.Contains(t, counts, "countTest:tag1,tag2") a.set("setTest", "value1", tags, CardinalityNotSet) - assert.Len(t, a.sets, 1) - assert.Contains(t, a.sets, "setTest:tag1,tag2") + sets := getAllSets(a) + assert.Len(t, sets, 1) + assert.Contains(t, sets, "setTest:tag1,tag2") a.set("setTest", "value1", tags, CardinalityNotSet) - assert.Len(t, a.sets, 1) - 
assert.Contains(t, a.sets, "setTest:tag1,tag2") + sets = getAllSets(a) + assert.Len(t, sets, 1) + assert.Contains(t, sets, "setTest:tag1,tag2") a.histogram("histogramTest", 21, tags, 1, CardinalityNotSet) assert.Len(t, a.histograms.values, 1) @@ -48,7 +88,7 @@ func TestAggregatorSample(t *testing.T) { } func TestAggregatorFlush(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) tags := []string{"tag1", "tag2"} @@ -79,9 +119,9 @@ func TestAggregatorFlush(t *testing.T) { metrics := a.flushMetrics() - assert.Len(t, a.gauges, 0) - assert.Len(t, a.counts, 0) - assert.Len(t, a.sets, 0) + assert.Len(t, getAllGauges(a), 0) + assert.Len(t, getAllCounts(a), 0) + assert.Len(t, getAllSets(a), 0) assert.Len(t, a.histograms.values, 0) assert.Len(t, a.distributions.values, 0) assert.Len(t, a.timings.values, 0) @@ -212,7 +252,7 @@ func TestAggregatorFlush(t *testing.T) { func TestAggregatorFlushWithMaxSamplesPerContext(t *testing.T) { // In this test we keep only 2 samples per context for metrics where it's relevant. 
maxSamples := int64(2) - a := newAggregator(nil, maxSamples) + a := newAggregator(nil, maxSamples, 8) tags := []string{"tag1", "tag2"} @@ -242,9 +282,9 @@ func TestAggregatorFlushWithMaxSamplesPerContext(t *testing.T) { metrics := a.flushMetrics() - assert.Len(t, a.gauges, 0) - assert.Len(t, a.counts, 0) - assert.Len(t, a.sets, 0) + assert.Len(t, getAllGauges(a), 0) + assert.Len(t, getAllCounts(a), 0) + assert.Len(t, getAllSets(a), 0) assert.Len(t, a.histograms.values, 0) assert.Len(t, a.distributions.values, 0) assert.Len(t, a.timings.values, 0) @@ -331,7 +371,7 @@ func TestAggregatorFlushWithMaxSamplesPerContext(t *testing.T) { } func TestAggregatorFlushConcurrency(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) var wg sync.WaitGroup wg.Add(10) @@ -363,7 +403,7 @@ func TestAggregatorFlushConcurrency(t *testing.T) { } func TestAggregatorTagsCopy(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) tags := []string{"tag1", "tag2"} a.gauge("gauge", 21, tags, CardinalityLow) @@ -437,7 +477,7 @@ func BenchmarkGetContextNoTags(b *testing.B) { } func TestAggregatorCardinalitySeparation(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) tags := []string{"env:prod", "service:api"} a.gauge("test.metric", 10, tags, CardinalityLow) @@ -507,7 +547,7 @@ func TestAggregatorCardinalitySeparation(t *testing.T) { } func TestAggregatorCardinalityPreservation(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) tags := []string{"env:prod"} // Test that cardinality is preserved in flushed metrics. 
@@ -535,7 +575,7 @@ func TestAggregatorCardinalityPreservation(t *testing.T) { } func TestAggregatorCardinalityWithBufferedMetrics(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) tags := []string{"env:prod"} a.histogram("test.hist", 10, tags, 1, CardinalityLow) @@ -610,7 +650,7 @@ func TestAggregatorCardinalityWithBufferedMetrics(t *testing.T) { } func TestAggregatorCardinalityEmptyVsNonEmpty(t *testing.T) { - a := newAggregator(nil, 0) + a := newAggregator(nil, 0, 8) tags := []string{"env:prod"} a.gauge("test.metric", 10, tags, CardinalityNotSet) diff --git a/statsd/options.go b/statsd/options.go index a9f6c5f6..225a5aea 100644 --- a/statsd/options.go +++ b/statsd/options.go @@ -28,6 +28,7 @@ var ( defaultOriginDetection = true defaultChannelModeErrorsWhenFull = false defaultErrorHandler = func(error) {} + defaultAggregatorShardCount = 1 ) // Options contains the configuration options for a client. @@ -49,6 +50,7 @@ type Options struct { aggregation bool extendedAggregation bool maxBufferedSamplesPerContext int + aggregatorShardCount int telemetryAddr string originDetection bool containerID string @@ -79,6 +81,7 @@ func resolveOptions(options []Option) (*Options, error) { originDetection: defaultOriginDetection, channelModeErrorsWhenFull: defaultChannelModeErrorsWhenFull, errorHandler: defaultErrorHandler, + aggregatorShardCount: defaultAggregatorShardCount, } for _, option := range options { @@ -424,3 +427,17 @@ func WithCardinality(card Cardinality) Option { return nil } } + +// WithAggregatorShardCount sets the number of shards used for the aggregator. +// Higher values reduce lock contention but increase memory usage. +// +// The default is 1 to mimic current behavior. 
+func WithAggregatorShardCount(shardCount int) Option { + return func(o *Options) error { + if shardCount < 1 { + return fmt.Errorf("shardCount must be a positive integer") + } + o.aggregatorShardCount = shardCount + return nil + } +} diff --git a/statsd/options_test.go b/statsd/options_test.go index 9253d0d4..24452442 100644 --- a/statsd/options_test.go +++ b/statsd/options_test.go @@ -28,6 +28,7 @@ func TestDefaultOptions(t *testing.T) { assert.Equal(t, options.extendedAggregation, defaultExtendedAggregation) assert.Zero(t, options.telemetryAddr) assert.Nil(t, options.tagCardinality) + assert.Equal(t, options.aggregatorShardCount, defaultAggregatorShardCount) } func TestOptions(t *testing.T) { @@ -44,6 +45,7 @@ func TestOptions(t *testing.T) { testAggregationWindow := 10 * time.Second testTelemetryAddr := "localhost:1234" testTagCardinality := CardinalityHigh + testAggregatorShardCount := 4 options, err := resolveOptions([]Option{ WithNamespace(testNamespace), @@ -62,6 +64,7 @@ func TestOptions(t *testing.T) { WithClientSideAggregation(), WithTelemetryAddr(testTelemetryAddr), WithCardinality(testTagCardinality), + WithAggregatorShardCount(testAggregatorShardCount), }) assert.NoError(t, err) @@ -82,6 +85,7 @@ func TestOptions(t *testing.T) { assert.Equal(t, options.extendedAggregation, false) assert.Equal(t, options.telemetryAddr, testTelemetryAddr) assert.Equal(t, *options.tagCardinality, testTagCardinality) + assert.Equal(t, options.aggregatorShardCount, testAggregatorShardCount) } func TestExtendedAggregation(t *testing.T) { diff --git a/statsd/statsdex.go b/statsd/statsdex.go index f051cb67..faa3a194 100644 --- a/statsd/statsdex.go +++ b/statsd/statsdex.go @@ -525,7 +525,7 @@ func newWithWriter(w Transport, o *Options, writerName string) (*ClientEx, error } if o.aggregation || o.extendedAggregation || o.maxBufferedSamplesPerContext > 0 { - c.agg = newAggregator(&c, int64(o.maxBufferedSamplesPerContext)) + c.agg = newAggregator(&c, 
int64(o.maxBufferedSamplesPerContext), o.aggregatorShardCount) c.agg.start(o.aggregationFlushInterval) if o.extendedAggregation {