diff --git a/agent/grpc-external/services/peerstream/stream_test.go b/agent/grpc-external/services/peerstream/stream_test.go index 5f5d399af9e..a286e6fefe1 100644 --- a/agent/grpc-external/services/peerstream/stream_test.go +++ b/agent/grpc-external/services/peerstream/stream_test.go @@ -1431,10 +1431,6 @@ func makeClient(t *testing.T, srv *testServer, peerID string) *MockClient { receivedSub3, err := client.Recv() require.NoError(t, err) - // This is required when the client subscribes to server address replication messages. - // We assert for the handler to be called at least once but the data doesn't matter. - srv.mockSnapshotHandler.expect("", 0, 0, nil) - // Issue services, roots, and server address subscription to server. // Note that server address may not come as an initial message for _, resourceURL := range []string{ @@ -3057,9 +3053,9 @@ func requireEqualInstances(t *testing.T, expect, got structs.CheckServiceNodes) type testServer struct { *Server - // mockSnapshotHandler is solely used for handling autopilot events + // readyServersSnapshotHandler is solely used for handling autopilot events // which don't come from the state store. - mockSnapshotHandler *mockSnapshotHandler + readyServersSnapshotHandler *dummyReadyServersSnapshotHandler } func newTestServer(t *testing.T, configFn func(c *Config)) (*testServer, *state.Store) { @@ -3101,8 +3097,8 @@ func newTestServer(t *testing.T, configFn func(c *Config)) (*testServer, *state. t.Cleanup(grpcServer.Stop) return &testServer{ - Server: srv, - mockSnapshotHandler: handler, + Server: srv, + readyServersSnapshotHandler: handler, }, store } diff --git a/agent/grpc-external/services/peerstream/subscription_manager_test.go b/agent/grpc-external/services/peerstream/subscription_manager_test.go index c7b77edec96..c9e664477f3 100644 --- a/agent/grpc-external/services/peerstream/subscription_manager_test.go +++ b/agent/grpc-external/services/peerstream/subscription_manager_test.go @@ -2,12 +2,12 @@ package peerstream import ( "context" + "fmt" "sort" "sync" "testing" "time" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/hashicorp/consul/acl" @@ -678,7 +678,7 @@ func TestSubscriptionManager_ServerAddrs(t *testing.T) { }, } // mock handler only gets called once during the initial subscription - backend.handler.expect("", 0, 1, payload) + backend.handler.SetPayload(1, payload) // Only configure a tracker for server address events. tracker := newResourceSubscriptionTracker() @@ -1013,7 +1013,7 @@ func TestFlattenChecks(t *testing.T) { type testSubscriptionBackend struct { state.EventPublisher store *state.Store - handler *mockSnapshotHandler + handler *dummyReadyServersSnapshotHandler lastIdx uint64 } @@ -1130,11 +1130,11 @@ func setupTestPeering(t *testing.T, store *state.Store, name string, index uint6 return p.ID } -func newStateStore(t *testing.T, publisher *stream.EventPublisher) (*state.Store, *mockSnapshotHandler) { +func newStateStore(t *testing.T, publisher *stream.EventPublisher) (*state.Store, *dummyReadyServersSnapshotHandler) { gc, err := state.NewTombstoneGC(time.Second, time.Millisecond) require.NoError(t, err) - handler := newMockSnapshotHandler(t) + handler := &dummyReadyServersSnapshotHandler{} store := state.NewStateStoreWithEventPublisher(gc, publisher) require.NoError(t, publisher.RegisterHandler(state.EventTopicServiceHealth, store.ServiceHealthSnapshot, false)) @@ -1294,38 +1294,34 @@ func pbCheck(node, svcID, svcName, status string, entMeta *pbcommon.EnterpriseMe } } -// mockSnapshotHandler is copied from server_discovery/server_test.go -type mockSnapshotHandler struct { - mock.Mock +type dummyReadyServersSnapshotHandler struct { + lock sync.Mutex + eventIndex uint64 + payload autopilotevents.EventPayloadReadyServers } -func newMockSnapshotHandler(t *testing.T) *mockSnapshotHandler { - handler := &mockSnapshotHandler{} - t.Cleanup(func() { - handler.AssertExpectations(t) - }) - return handler +func (h *dummyReadyServersSnapshotHandler) SetPayload(idx uint64, payload autopilotevents.EventPayloadReadyServers) { + h.lock.Lock() + defer h.lock.Unlock() + h.eventIndex = idx + h.payload = payload } -func (m *mockSnapshotHandler) handle(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { - ret := m.Called(req, buf) - return ret.Get(0).(uint64), ret.Error(1) -} +func (h *dummyReadyServersSnapshotHandler) handle(req stream.SubscribeRequest, buf stream.SnapshotAppender) (uint64, error) { + if req.Topic != autopilotevents.EventTopicReadyServers { + return 0, fmt.Errorf("bad request") + } + if req.Subject != stream.SubjectNone { + return 0, fmt.Errorf("bad request") + } -func (m *mockSnapshotHandler) expect(token string, requestIndex uint64, eventIndex uint64, payload autopilotevents.EventPayloadReadyServers) { - m.On("handle", stream.SubscribeRequest{ + h.lock.Lock() + defer h.lock.Unlock() + buf.Append([]stream.Event{{ Topic: autopilotevents.EventTopicReadyServers, - Subject: stream.SubjectNone, - Token: token, - Index: requestIndex, - }, mock.Anything).Run(func(args mock.Arguments) { - buf := args.Get(1).(stream.SnapshotAppender) - buf.Append([]stream.Event{ - { - Topic: autopilotevents.EventTopicReadyServers, - Index: eventIndex, - Payload: payload, - }, - }) - }).Return(eventIndex, nil) + Index: h.eventIndex, + Payload: h.payload, + }}) + + return h.eventIndex, nil }