From d66fd4339cd9c9b20fbfaf416c16333471e406f0 Mon Sep 17 00:00:00 2001 From: Mike Cohen Date: Mon, 5 Apr 2021 08:13:37 +0000 Subject: [PATCH] Added local file buffer for replication service. This allows the service to store events to a local file while the master node is down. --- api/mock/api_mock.go | 150 +++++++++++++++- api/mock/gen.go | 2 +- api/replication.go | 2 + services/hunt_manager/hunt_manager.go | 7 +- services/journal/buffer.go | 247 ++++++++++++++++++++++++++ services/journal/replication.go | 103 +++++++++-- services/journal/replication_test.go | 213 ++++++++++++++++++++++ services/journal/utils.go | 2 - vtesting/helpers.go | 19 +- 9 files changed, 720 insertions(+), 25 deletions(-) create mode 100644 services/journal/buffer.go create mode 100644 services/journal/replication_test.go diff --git a/api/mock/api_mock.go b/api/mock/api_mock.go index a7a6b29781..f8d83ec280 100644 --- a/api/mock/api_mock.go +++ b/api/mock/api_mock.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: www.velocidex.com/golang/velociraptor/api/proto (interfaces: APIClient) +// Source: www.velocidex.com/golang/velociraptor/api/proto (interfaces: APIClient,API_WatchEventClient) // Package mock_proto is a generated GoMock package. package mock_proto @@ -10,6 +10,7 @@ import ( gomock "github.com/golang/mock/gomock" grpc "google.golang.org/grpc" + metadata "google.golang.org/grpc/metadata" emptypb "google.golang.org/protobuf/types/known/emptypb" proto "www.velocidex.com/golang/velociraptor/actions/proto" proto0 "www.velocidex.com/golang/velociraptor/api/proto" @@ -780,6 +781,26 @@ func (mr *MockAPIClientMockRecorder) NotifyClients(arg0, arg1 interface{}, arg2 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotifyClients", reflect.TypeOf((*MockAPIClient)(nil).NotifyClients), varargs...) } +// PushEvents mocks base method. +func (m *MockAPIClient) PushEvents(arg0 context.Context, arg1 *proto0.PushEventRequest, arg2 ...grpc.CallOption) (*emptypb.Empty, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "PushEvents", varargs...) + ret0, _ := ret[0].(*emptypb.Empty) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PushEvents indicates an expected call of PushEvents. +func (mr *MockAPIClientMockRecorder) PushEvents(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushEvents", reflect.TypeOf((*MockAPIClient)(nil).PushEvents), varargs...) +} + // Query mocks base method. func (m *MockAPIClient) Query(arg0 context.Context, arg1 *proto.VQLCollectorArgs, arg2 ...grpc.CallOption) (proto0.API_QueryClient, error) { m.ctrl.T.Helper() @@ -1081,14 +1102,14 @@ func (mr *MockAPIClientMockRecorder) VFSStatDownload(arg0, arg1 interface{}, arg } // WatchEvent mocks base method. -func (m *MockAPIClient) WatchEvent(arg0 context.Context, arg1 *proto0.EventRequest, arg2 ...grpc.CallOption) (*proto0.EventResponse, error) { +func (m *MockAPIClient) WatchEvent(arg0 context.Context, arg1 *proto0.EventRequest, arg2 ...grpc.CallOption) (proto0.API_WatchEventClient, error) { m.ctrl.T.Helper() varargs := []interface{}{arg0, arg1} for _, a := range arg2 { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "WatchEvent", varargs...) - ret0, _ := ret[0].(*proto0.EventResponse) + ret0, _ := ret[0].(proto0.API_WatchEventClient) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1119,3 +1140,126 @@ func (mr *MockAPIClientMockRecorder) WriteEvent(arg0, arg1 interface{}, arg2 ... varargs := append([]interface{}{arg0, arg1}, arg2...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteEvent", reflect.TypeOf((*MockAPIClient)(nil).WriteEvent), varargs...) } + +// MockAPI_WatchEventClient is a mock of API_WatchEventClient interface. +type MockAPI_WatchEventClient struct { + ctrl *gomock.Controller + recorder *MockAPI_WatchEventClientMockRecorder +} + +// MockAPI_WatchEventClientMockRecorder is the mock recorder for MockAPI_WatchEventClient. +type MockAPI_WatchEventClientMockRecorder struct { + mock *MockAPI_WatchEventClient +} + +// NewMockAPI_WatchEventClient creates a new mock instance. +func NewMockAPI_WatchEventClient(ctrl *gomock.Controller) *MockAPI_WatchEventClient { + mock := &MockAPI_WatchEventClient{ctrl: ctrl} + mock.recorder = &MockAPI_WatchEventClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAPI_WatchEventClient) EXPECT() *MockAPI_WatchEventClientMockRecorder { + return m.recorder +} + +// CloseSend mocks base method. +func (m *MockAPI_WatchEventClient) CloseSend() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseSend") + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseSend indicates an expected call of CloseSend. +func (mr *MockAPI_WatchEventClientMockRecorder) CloseSend() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseSend", reflect.TypeOf((*MockAPI_WatchEventClient)(nil).CloseSend)) +} + +// Context mocks base method. +func (m *MockAPI_WatchEventClient) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockAPI_WatchEventClientMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockAPI_WatchEventClient)(nil).Context)) +} + +// Header mocks base method. +func (m *MockAPI_WatchEventClient) Header() (metadata.MD, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Header") + ret0, _ := ret[0].(metadata.MD) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Header indicates an expected call of Header. +func (mr *MockAPI_WatchEventClientMockRecorder) Header() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Header", reflect.TypeOf((*MockAPI_WatchEventClient)(nil).Header)) +} + +// Recv mocks base method. +func (m *MockAPI_WatchEventClient) Recv() (*proto0.EventResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*proto0.EventResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv. +func (mr *MockAPI_WatchEventClientMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockAPI_WatchEventClient)(nil).Recv)) +} + +// RecvMsg mocks base method. +func (m *MockAPI_WatchEventClient) RecvMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecvMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RecvMsg indicates an expected call of RecvMsg. +func (mr *MockAPI_WatchEventClientMockRecorder) RecvMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecvMsg", reflect.TypeOf((*MockAPI_WatchEventClient)(nil).RecvMsg), arg0) +} + +// SendMsg mocks base method. +func (m *MockAPI_WatchEventClient) SendMsg(arg0 interface{}) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMsg", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMsg indicates an expected call of SendMsg. +func (mr *MockAPI_WatchEventClientMockRecorder) SendMsg(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMsg", reflect.TypeOf((*MockAPI_WatchEventClient)(nil).SendMsg), arg0) +} + +// Trailer mocks base method. +func (m *MockAPI_WatchEventClient) Trailer() metadata.MD { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Trailer") + ret0, _ := ret[0].(metadata.MD) + return ret0 +} + +// Trailer indicates an expected call of Trailer. +func (mr *MockAPI_WatchEventClientMockRecorder) Trailer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Trailer", reflect.TypeOf((*MockAPI_WatchEventClient)(nil).Trailer)) +} diff --git a/api/mock/gen.go b/api/mock/gen.go index 54c92ddfb8..1ac0b63834 100644 --- a/api/mock/gen.go +++ b/api/mock/gen.go @@ -1,3 +1,3 @@ -//go:generate mockgen -destination api_mock.go www.velocidex.com/golang/velociraptor/api/proto APIClient +//go:generate mockgen -destination api_mock.go www.velocidex.com/golang/velociraptor/api/proto APIClient,API_WatchEventClient package mock_proto diff --git a/api/replication.go b/api/replication.go index 14843e7d5a..60c662aca7 100644 --- a/api/replication.go +++ b/api/replication.go @@ -37,6 +37,8 @@ func streamEvents( return err } + // The API service is running on the master only! This means + // the journal service is local. output_chan, cancel := journal.Watch(ctx, in.Queue) defer cancel() diff --git a/services/hunt_manager/hunt_manager.go b/services/hunt_manager/hunt_manager.go index 0a35e93dc9..71339f3e3b 100644 --- a/services/hunt_manager/hunt_manager.go +++ b/services/hunt_manager/hunt_manager.go @@ -272,10 +272,15 @@ func (self *HuntManager) ProcessParticipation( participation_row.ClientId, err) } + // If the hunt ran on the client already we just ignore + // it. This is possible because the client may not have + // updated its last hunt number in time to have a number of + // hunt participation messages sent for it from different + // frontends. err = checkHuntRanOnClient(config_obj, participation_row.ClientId, participation_row.HuntId) if err != nil { - return err + return nil } // Get hunt information about this hunt. diff --git a/services/journal/buffer.go b/services/journal/buffer.go new file mode 100644 index 0000000000..58014405c3 --- /dev/null +++ b/services/journal/buffer.go @@ -0,0 +1,247 @@ +// A ring buffer to queue messages + +// Similar to the client ring buffer but this one has no limit because +// we never want to block writers. + +package journal + +import ( + "encoding/binary" + "errors" + "io" + "os" + "sync" + + "google.golang.org/protobuf/proto" + api_proto "www.velocidex.com/golang/velociraptor/api/proto" + config_proto "www.velocidex.com/golang/velociraptor/config/proto" + "www.velocidex.com/golang/velociraptor/constants" + logging "www.velocidex.com/golang/velociraptor/logging" +) + +// The below is similar to http_comms.FileBasedRingBuffer except: +// * Size of the file is not limited. +// * Leasing a single message at once. +// * Messages are of type api_proto.PushEventRequest + +const ( + FileMagic = "VRB\x5f" + FirstRecordOffset = 50 +) + +var ( + ErrorsCorrupted = errors.New("File is corrupted") +) + +type Header struct { + ReadPointer int64 // Leasing will start at this file offset. + WritePointer int64 // Enqueue will write at this file position. +} + +func (self *Header) MarshalBinary() ([]byte, error) { + data := make([]byte, FirstRecordOffset) + copy(data, FileMagic) + + binary.LittleEndian.PutUint64(data[4:12], uint64(self.ReadPointer)) + binary.LittleEndian.PutUint64(data[12:20], uint64(self.WritePointer)) + + return data, nil +} + +func (self *Header) UnmarshalBinary(data []byte) error { + if len(data) < FirstRecordOffset { + return errors.New("Invalid header length") + } + + if string(data[:4]) != FileMagic { + return errors.New("Invalid Magic") + } + + self.ReadPointer = int64(binary.LittleEndian.Uint64(data[4:12])) + self.WritePointer = int64(binary.LittleEndian.Uint64(data[12:20])) + + return nil +} + +type BufferFile struct { + config_obj *config_proto.Config + + mu sync.Mutex + + fd *os.File + Header *Header + + read_buf []byte + write_buf []byte + + log_ctx *logging.LogContext +} + +// Enqueue the item into the ring buffer and append to the end. +func (self *BufferFile) Enqueue(item *api_proto.PushEventRequest) error { + serialized, err := proto.Marshal(item) + if err != nil { + return err + } + + self.mu.Lock() + defer self.mu.Unlock() + + // Write the new message to the end of the file at the WritePointer + binary.LittleEndian.PutUint64(self.write_buf, uint64(len(serialized))) + _, err = self.fd.WriteAt(self.write_buf, int64(self.Header.WritePointer)) + if err != nil { + // File is corrupt now, reset it. + self.Reset() + return err + } + + n, err := self.fd.WriteAt(serialized, int64(self.Header.WritePointer+8)) + if err != nil { + self.Reset() + return err + } + + self.Header.WritePointer += 8 + int64(n) + + // Update the header + serialized, err = self.Header.MarshalBinary() + if err != nil { + return err + } + _, err = self.fd.WriteAt(serialized, 0) + if err != nil { + self.Reset() + return err + } + + return nil +} + +// Returns some messages message from the file. +func (self *BufferFile) Lease() (*api_proto.PushEventRequest, error) { + self.mu.Lock() + defer self.mu.Unlock() + + result := &api_proto.PushEventRequest{} + + // The file is empty. + if self.Header.WritePointer <= self.Header.ReadPointer { + return nil, io.EOF + } + + // Read the next chunk (length+value) from the current leased pointer. + n, err := self.fd.ReadAt(self.read_buf, self.Header.ReadPointer) + if err != nil || n != len(self.read_buf) { + self.log_ctx.Error("Possible corruption detected: file too short.") + self._Truncate() + return nil, ErrorsCorrupted + } + + length := int64(binary.LittleEndian.Uint64(self.read_buf)) + // File might be corrupt - just reset the + // entire file. + if length > constants.MAX_MEMORY*2 || length <= 0 { + self.log_ctx.Error("Possible corruption detected - item length is too large.") + self._Truncate() + return nil, ErrorsCorrupted + } + + // Unmarshal one item at a time. + serialized := make([]byte, length) + n, _ = self.fd.ReadAt(serialized, self.Header.ReadPointer+8) + if int64(n) != length { + self.log_ctx.Errorf( + "Possible corruption detected - expected item of length %v received %v.", + length, n) + self._Truncate() + return nil, ErrorsCorrupted + } + + err = proto.Unmarshal(serialized, result) + if err != nil { + self.log_ctx.Errorf( + "Possible corruption detected - unable to decode item.") + self._Truncate() + return nil, ErrorsCorrupted + } + + // Advance the read pointer + self.Header.ReadPointer += 8 + int64(n) + + // We read up to the write pointer, we may truncate the file + // now. + if self.Header.ReadPointer == self.Header.WritePointer { + self._Truncate() + } + + return result, nil +} + +// _Truncate returns the file to a virgin state. Assumes +// FileBasedRingBuffer is already under lock. +func (self *BufferFile) _Truncate() { + _ = self.fd.Truncate(0) + self.Header.ReadPointer = FirstRecordOffset + self.Header.WritePointer = FirstRecordOffset + serialized, _ := self.Header.MarshalBinary() + _, _ = self.fd.WriteAt(serialized, 0) +} + +func (self *BufferFile) Reset() { + self.mu.Lock() + defer self.mu.Unlock() + + self._Truncate() +} + +// Closes the underlying file and shut down the readers. +func (self *BufferFile) Close() { + self.fd.Close() + os.Remove(self.fd.Name()) +} + +func NewBufferFile( + config_obj *config_proto.Config, fd *os.File) (*BufferFile, error) { + + log_ctx := logging.GetLogger(config_obj, &logging.FrontendComponent) + + header := &Header{ + // Pad the header a bit to allow for extensions. + WritePointer: FirstRecordOffset, + ReadPointer: FirstRecordOffset, + } + data := make([]byte, FirstRecordOffset) + n, err := fd.ReadAt(data, 0) + if n > 0 && n < FirstRecordOffset && err == io.EOF { + log_ctx.Error("Possible corruption detected: file too short.") + err = fd.Truncate(0) + if err != nil { + return nil, err + } + } + + if n > 0 && (err == nil || err == io.EOF) { + err := header.UnmarshalBinary(data[:n]) + // The header is not valid, truncate the file and + // start again. + if err != nil { + log_ctx.Errorf("Possible corruption detected: %v.", err) + err = fd.Truncate(0) + if err != nil { + return nil, err + } + } + } + + result := &BufferFile{ + config_obj: config_obj, + fd: fd, + Header: header, + read_buf: make([]byte, 8), + write_buf: make([]byte, 8), + log_ctx: log_ctx, + } + + return result, nil +} diff --git a/services/journal/replication.go b/services/journal/replication.go index ec3d9c3ff9..bb746f33c3 100644 --- a/services/journal/replication.go +++ b/services/journal/replication.go @@ -15,7 +15,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" api_proto "www.velocidex.com/golang/velociraptor/api/proto" config_proto "www.velocidex.com/golang/velociraptor/config/proto" - "www.velocidex.com/golang/velociraptor/file_store/directory" "www.velocidex.com/golang/velociraptor/json" "www.velocidex.com/golang/velociraptor/logging" "www.velocidex.com/golang/velociraptor/services" @@ -39,14 +38,52 @@ var ( ) type ReplicationService struct { - config_obj *config_proto.Config - file_buffer *directory.FileBasedRingBuffer - tmpfile *os.File + config_obj *config_proto.Config + Buffer *BufferFile + tmpfile *os.File + ctx context.Context + RetryDuration time.Duration + + sender chan *api_proto.PushEventRequest api_client api_proto.APIClient closer func() error } +func (self *ReplicationService) pumpEventFromBufferFile() { + for { + event, err := self.Buffer.Lease() + // No events available or some other error, sleep and + // try again later. + if err != nil { + select { + case <-self.ctx.Done(): + return + + case <-time.After(self.RetryDuration): + continue + } + } + + // Retry to send the event. + for { + _, err := self.api_client.PushEvents(self.ctx, event) + if err == nil { + break + } + // We are unable to send it, sleep and + // try again later. + select { + case <-self.ctx.Done(): + return + + case <-time.After(self.RetryDuration): + continue + } + } + } +} + func (self *ReplicationService) Start( ctx context.Context, wg *sync.WaitGroup) (err error) { @@ -56,27 +93,54 @@ func (self *ReplicationService) Start( if err != nil { return err } + + // Initialize our default values and start the service for + // real. self.api_client = api_client self.closer = closer + self.ctx = ctx + self.sender = make(chan *api_proto.PushEventRequest, 100) + self.RetryDuration = time.Second self.tmpfile, err = ioutil.TempFile("", "replication") if err != nil { return err } - self.file_buffer, err = directory.NewFileBasedRingBuffer( - self.config_obj, self.tmpfile) + self.Buffer, err = NewBufferFile(self.config_obj, self.tmpfile) if err != nil { return err } + go self.pumpEventFromBufferFile() + wg.Add(1) go func() { defer wg.Done() - defer self.Close() - <-ctx.Done() + for { + select { + case <-ctx.Done(): + return + + // Read events from the channel and + // try to send them + case request, ok := <-self.sender: + if !ok { + return + } + _, err = self.api_client.PushEvents(ctx, request) + if err != nil { + replicationTotalSendErrors.Inc() + + // Attempt to push the events + // to the buffer file instead + // for later delivery. + _ = self.Buffer.Enqueue(request) + } + } + } }() logger := logging.GetLogger(self.config_obj, &logging.FrontendComponent) @@ -91,9 +155,6 @@ func (self *ReplicationService) PushRowsToArtifact( replicationTotalSent.Inc() - // FIXME: implement buffer file here. - ctx := context.Background() - serialized, err := json.MarshalJsonl(rows) if err != nil { return err @@ -110,11 +171,14 @@ func (self *ReplicationService) PushRowsToArtifact( logger.Debug("ReplicationService Sending %v rows to %v for %v.", len(rows), artifact, client_id) - _, err = self.api_client.PushEvents(ctx, request) - if err != nil { - replicationTotalSendErrors.Inc() + // Should not block! If the channel is full we save the event + // into the file buffer for later. + select { + case self.sender <- request: + return nil + default: + return self.Buffer.Enqueue(request) } - return err } func (self *ReplicationService) Watch(ctx context.Context, queue string) ( @@ -131,7 +195,12 @@ func (self *ReplicationService) Watch(ctx context.Context, queue string) ( output_chan <- event } - time.Sleep(10 * time.Second) + select { + case <-self.ctx.Done(): + return + case <-time.After(self.RetryDuration): + } + logger := logging.GetLogger(self.config_obj, &logging.FrontendComponent) logger.Info("ReplicationService Reconnect: "+ @@ -194,6 +263,6 @@ func (self *ReplicationService) watchOnce(ctx context.Context, queue string) <-c func (self *ReplicationService) Close() { self.closer() - self.file_buffer.Close() + self.Buffer.Close() os.Remove(self.tmpfile.Name()) // clean up file buffer } diff --git a/services/journal/replication_test.go b/services/journal/replication_test.go new file mode 100644 index 0000000000..a9ef099fe7 --- /dev/null +++ b/services/journal/replication_test.go @@ -0,0 +1,213 @@ +package journal_test + +import ( + "context" + "testing" + "time" + + "github.com/Velocidex/ordereddict" + "github.com/alecthomas/assert" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/types/known/emptypb" + mock_proto "www.velocidex.com/golang/velociraptor/api/mock" + api_proto "www.velocidex.com/golang/velociraptor/api/proto" + "www.velocidex.com/golang/velociraptor/config" + config_proto "www.velocidex.com/golang/velociraptor/config/proto" + "www.velocidex.com/golang/velociraptor/file_store/test_utils" + "www.velocidex.com/golang/velociraptor/services" + "www.velocidex.com/golang/velociraptor/services/inventory" + "www.velocidex.com/golang/velociraptor/services/journal" + "www.velocidex.com/golang/velociraptor/services/launcher" + "www.velocidex.com/golang/velociraptor/services/notifications" + "www.velocidex.com/golang/velociraptor/services/repository" + "www.velocidex.com/golang/velociraptor/vtesting" +) + +type MockFrontendService struct { + mock *mock_proto.MockAPIClient +} + +func (self MockFrontendService) IsMaster() bool { + return false +} + +// The slave replicates to the master node. +func (self MockFrontendService) GetMasterAPIClient(ctx context.Context) ( + api_proto.APIClient, func() error, error) { + return self.mock, func() error { return nil }, nil +} + +type ReplicationTestSuite struct { + suite.Suite + config_obj *config_proto.Config + sm *services.Service + ctrl *gomock.Controller + mock *mock_proto.MockAPIClient +} + +func (self *ReplicationTestSuite) startServices() { + ctx, _ := context.WithTimeout(context.Background(), time.Second*60) + self.sm = services.NewServiceManager(ctx, self.config_obj) + + t := self.T() + assert.NoError(t, self.sm.Start(journal.StartJournalService)) + assert.NoError(t, self.sm.Start(notifications.StartNotificationService)) + assert.NoError(t, self.sm.Start(inventory.StartInventoryService)) + assert.NoError(t, self.sm.Start(launcher.StartLauncherService)) + assert.NoError(t, self.sm.Start(repository.StartRepositoryManagerForTest)) + + // Set retry to be faster. + journal_service, err := services.GetJournal() + assert.NoError(self.T(), err) + + replicator := journal_service.(*journal.ReplicationService) + replicator.RetryDuration = 100 * time.Millisecond +} + +func (self *ReplicationTestSuite) SetupTest() { + var err error + self.config_obj, err = new(config.Loader).WithFileLoader( + "../../http_comms/test_data/server.config.yaml"). + WithRequiredFrontend().WithWriteback(). + LoadAndValidate() + require.NoError(self.T(), err) + + self.ctrl = gomock.NewController(self.T()) + self.mock = mock_proto.NewMockAPIClient(self.ctrl) + + // Replication service only runs on the slave node. We mock + // the slave frontend manager so we can inject the RPC mock. + services.RegisterFrontendManager(&MockFrontendService{self.mock}) +} + +func (self *ReplicationTestSuite) TearDownTest() { + self.sm.Close() + self.ctrl.Finish() + + test_utils.GetMemoryFileStore(self.T(), self.config_obj).Clear() + test_utils.GetMemoryDataStore(self.T(), self.config_obj).Clear() +} + +func (self *ReplicationTestSuite) TestReplicationServiceStandardWatchers() { + + // The ReplicationService will call WatchEvents for both the + // Server.Internal.Ping and Server.Internal.Notifications + // queues. + stream := mock_proto.NewMockAPI_WatchEventClient(self.ctrl) + stream.EXPECT().Recv().AnyTimes().Return(nil, errors.New("Error")) + + // Record the WatchEvents calls + watched := []string{} + mock_watch_event_recorder := func( + ctx context.Context, in *api_proto.EventRequest) ( + api_proto.API_WatchEventClient, error) { + watched = append(watched, in.Queue) + return stream, nil + } + + self.mock.EXPECT().WatchEvent(gomock.Any(), gomock.Any()). + //gomock.AssignableToTypeOf(ctxInterface), + //gomock.AssignableToTypeOf(&api_proto.EventRequest{})). + DoAndReturn(mock_watch_event_recorder).AnyTimes() + + self.startServices() + + // Wait here until we call all the watchers. + vtesting.WaitUntil(2*time.Second, self.T(), func() bool { + return vtesting.CompareStrings(watched, []string{ + // Watch for ping requests from the + // master. This is used to let the master know + // if a client is connected to us. + "Server.Internal.Ping", + + // The notifications service will watch for + // notifications through us. + "Server.Internal.Notifications", + }) + }) +} + +func (self *ReplicationTestSuite) TestSendingEvents() { + self.TestReplicationServiceStandardWatchers() + + events := []*api_proto.PushEventRequest{} + var last_error error + + // Sending some rows to an event queue + record_push_event := func(ctx context.Context, + in *api_proto.PushEventRequest) (*emptypb.Empty, error) { + // On error do not capture the request + if last_error != nil { + return nil, last_error + } + + events = append(events, in) + return &emptypb.Empty{}, last_error + } + + // Push an event into the journal service on the slave. It + // will result in an RPC on the master to pass the event on. + self.mock.EXPECT().PushEvents(gomock.Any(), gomock.Any()). + DoAndReturn(record_push_event).AnyTimes() + + my_event := []*ordereddict.Dict{ + ordereddict.NewDict().Set("Foo", "Bar")} + + journal_service, err := services.GetJournal() + assert.NoError(self.T(), err) + + replicator := journal_service.(*journal.ReplicationService) + replicator.RetryDuration = 100 * time.Millisecond + + events = nil + err = journal_service.PushRowsToArtifact(self.config_obj, + my_event, "Test.Artifact", "C.1234", "F.123") + assert.NoError(self.T(), err) + + // Wait to see if the first event was properly delivered. + vtesting.WaitUntil(time.Second, self.T(), func() bool { + return len(events) > 0 + }) + assert.Equal(self.T(), len(events), 1) + + // Now emulate an RPC server error. + last_error = errors.New("Master is down!") + + events = nil + + // Pushing to the journal service will transparently queue the + // messages to a buffer file and will relay them later. NOTE: + // This does not block, callers can not be blocked since this + // is often on the critical path. We just dump 1000 messages + // into the queue - this should overflow into the file. + for i := 0; i < 1000; i++ { + err = journal_service.PushRowsToArtifact(self.config_obj, + my_event, "Test.Artifact", "C.1234", "F.123") + assert.NoError(self.T(), err) + } + + // Make sure we wrote something to the buffer file. + assert.True(self.T(), replicator.Buffer.Header.WritePointer > 2000) + + // Wait a while to allow events to be delivered. + time.Sleep(time.Second) + + // Still no event got through + assert.Equal(self.T(), len(events), 0) + + // Now enable the server, it should just deliver all the + // messages from the buffer file after a while as the + // ReplicationService will retry. + last_error = nil + vtesting.WaitUntil(time.Second, self.T(), func() bool { + return len(events) == 1000 + }) + assert.Equal(self.T(), len(events), 1000) +} + +func TestReplication(t *testing.T) { + suite.Run(t, &ReplicationTestSuite{}) +} diff --git a/services/journal/utils.go b/services/journal/utils.go index 06e6608e7e..1cf1c79808 100644 --- a/services/journal/utils.go +++ b/services/journal/utils.go @@ -8,7 +8,6 @@ import ( "github.com/pkg/errors" config_proto "www.velocidex.com/golang/velociraptor/config/proto" flows_proto "www.velocidex.com/golang/velociraptor/flows/proto" - "www.velocidex.com/golang/velociraptor/json" "www.velocidex.com/golang/velociraptor/logging" "www.velocidex.com/golang/velociraptor/services" "www.velocidex.com/golang/velociraptor/utils" @@ -39,7 +38,6 @@ func WatchForCollectionWithCB(ctx context.Context, // This is not what we are looking for. if !utils.InString(flow.ArtifactsWithResults, artifact) { - json.Dump(flow) return nil } diff --git a/vtesting/helpers.go b/vtesting/helpers.go index f06cb472ca..51351597bb 100644 --- a/vtesting/helpers.go +++ b/vtesting/helpers.go @@ -22,8 +22,11 @@ package vtesting import ( "io/ioutil" + "runtime/debug" "testing" "time" + + "www.velocidex.com/golang/velociraptor/utils" ) func ReadFile(t *testing.T, filename string) []byte { @@ -48,6 +51,20 @@ func (self RealClock) After(d time.Duration) <-chan time.Time { return time.After(d) } +// Compares lists of strings regardless of order. +func CompareStrings(expected []string, watched []string) bool { + if len(expected) != len(watched) { + return false + } + + for _, item := range watched { + if !utils.InString(expected, item) { + return false + } + } + return true +} + func WaitUntil(deadline time.Duration, t *testing.T, cb func() bool) { end_time := time.Now().Add(deadline) @@ -60,5 +77,5 @@ func WaitUntil(deadline time.Duration, t *testing.T, cb func() bool) { time.Sleep(50 * time.Millisecond) } - t.Fatalf("Timed out") + t.Fatalf("Timed out " + string(debug.Stack())) }