diff --git a/CHANGELOG.md b/CHANGELOG.md index 354d10a45..b83f6e2b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,3 +84,4 @@ * [BUGFIX] stringslicecsv: handle unmarshalling empty yaml string #206 * [BUGFIX] Memberlist: retry joining memberlist cluster on startup when no nodes are resolved. #215 * [BUGFIX] Ring status page: display 100% ownership as "100%", rather than "1e+02%". #231 +* [BUGFIX] Memberlist: fix crash when methods from `memberlist.Delegate` interface are called on `*memberlist.KV` before the service is fully started. #244 diff --git a/kv/memberlist/memberlist_client.go b/kv/memberlist/memberlist_client.go index 390eca606..d18b6d452 100644 --- a/kv/memberlist/memberlist_client.go +++ b/kv/memberlist/memberlist_client.go @@ -18,6 +18,7 @@ import ( "github.com/go-kit/log/level" "github.com/hashicorp/memberlist" "github.com/prometheus/client_golang/prometheus" + "go.uber.org/atomic" "github.com/grafana/dskit/backoff" "github.com/grafana/dskit/flagext" @@ -233,9 +234,9 @@ type KV struct { provider DNSProvider // Protects access to memberlist and broadcasts fields. - initWG sync.WaitGroup - memberlist *memberlist.Memberlist - broadcasts *memberlist.TransmitLimitedQueue + delegateReady atomic.Bool + memberlist *memberlist.Memberlist + broadcasts *memberlist.TransmitLimitedQueue // KV Store. storeMu sync.Mutex @@ -451,7 +452,6 @@ func (m *KV) starting(ctx context.Context) error { // // Note: We cannot check for Starting state, as we want to use delegate during cluster joining process // that happens in Starting state. - m.initWG.Add(1) list, err := memberlist.Create(mlCfg) if err != nil { return fmt.Errorf("failed to create memberlist: %v", err) @@ -462,7 +462,7 @@ func (m *KV) starting(ctx context.Context) error { NumNodes: list.NumMembers, RetransmitMult: mlCfg.RetransmitMult, } - m.initWG.Done() + m.delegateReady.Store(true) // Try to fast-join memberlist cluster in Starting state, so that we don't start with empty KV store. if len(m.cfg.JoinMembers) > 0 { @@ -992,6 +992,10 @@ func (m *KV) NodeMeta(limit int) []byte { // NotifyMsg is method from Memberlist Delegate interface // Called when single message is received, i.e. what our broadcastNewValue has sent. func (m *KV) NotifyMsg(msg []byte) { + if !m.delegateReady.Load() { + return + } + m.numberOfReceivedMessages.Inc() m.totalSizeOfReceivedMessages.Add(float64(len(msg))) @@ -1101,7 +1105,9 @@ func (m *KV) queueBroadcast(key string, content []string, version uint, message // GetBroadcasts is method from Memberlist Delegate interface // It returns all pending broadcasts (within the size limit) func (m *KV) GetBroadcasts(overhead, limit int) [][]byte { - m.initWG.Wait() + if !m.delegateReady.Load() { + return nil + } return m.broadcasts.GetBroadcasts(overhead, limit) } @@ -1112,7 +1118,9 @@ func (m *KV) GetBroadcasts(overhead, limit int) [][]byte { // Here we dump our entire state -- all keys and their values. There is no limit on message size here, // as Memberlist uses 'stream' operations for transferring this state. func (m *KV) LocalState(join bool) []byte { - m.initWG.Wait() + if !m.delegateReady.Load() { + return nil + } m.numberOfPulls.Inc() @@ -1184,9 +1192,11 @@ func (m *KV) LocalState(join bool) []byte { // // Data is full state of remote KV store, as generated by LocalState method (run on another node). func (m *KV) MergeRemoteState(data []byte, join bool) { - received := time.Now() + if !m.delegateReady.Load() { + return + } - m.initWG.Wait() + received := time.Now() m.numberOfPushes.Inc() m.totalSizeOfPushes.Add(float64(len(data))) diff --git a/kv/memberlist/memberlist_client_test.go b/kv/memberlist/memberlist_client_test.go index eefcd617f..e45001a92 100644 --- a/kv/memberlist/memberlist_client_test.go +++ b/kv/memberlist/memberlist_client_test.go @@ -1435,6 +1435,49 @@ func TestFastJoin(t *testing.T) { require.Equal(t, JOINING, val.(*data).Members[memberKey].State) } +func TestDelegateMethodsDontCrashBeforeKVStarts(t *testing.T) { + codec := dataCodec{} + + cfg := KVConfig{} + cfg.Codecs = append(cfg.Codecs, codec) + + kv := NewKV(cfg, log.NewNopLogger(), &dnsProviderMock{}, prometheus.NewPedanticRegistry()) + + // Make sure we can call delegate methods on unstarted service, and they don't crash nor block. + kv.LocalState(true) + kv.MergeRemoteState(nil, true) + kv.GetBroadcasts(100, 100) + + now := time.Now() + msg := &data{ + Members: map[string]member{ + "a": {Timestamp: now.Unix() - 5, State: ACTIVE, Tokens: []uint32{}}, + "b": {Timestamp: now.Unix() + 5, State: ACTIVE, Tokens: []uint32{1, 2, 3}}, + "c": {Timestamp: now.Unix(), State: ACTIVE, Tokens: []uint32{}}, + }} + + kv.NotifyMsg(marshalKeyValuePair(t, key, codec, msg)) + + // Verify that message was not added to KV. + time.Sleep(time.Millisecond * 100) + val, err := kv.Get(key, codec) + require.NoError(t, err) + require.Nil(t, val) + + // Now start the service, and try NotifyMsg again + require.NoError(t, services.StartAndAwaitRunning(context.Background(), kv)) + defer services.StopAndAwaitTerminated(context.Background(), kv) //nolint:errcheck + + kv.NotifyMsg(marshalKeyValuePair(t, key, codec, msg)) + + // Wait until processing finished, and check the message again. + time.Sleep(time.Millisecond * 100) + + val, err = kv.Get(key, codec) + require.NoError(t, err) + assert.Equal(t, msg, val) +} + func decodeDataFromMarshalledKeyValuePair(t *testing.T, marshalledKVP []byte, key string, codec dataCodec) *data { kvp := KeyValuePair{} require.NoError(t, kvp.Unmarshal(marshalledKVP))