diff --git a/option.go b/option.go index 8c5a59c191d8..4772f9881f4d 100644 --- a/option.go +++ b/option.go @@ -7,7 +7,7 @@ import ( ) // Global writer group for logs to output to -var WriterGroup = new(MirrorWriter) +var WriterGroup = NewMirrorWriter() type Option func() diff --git a/writer.go b/writer.go index 6f38bbe60682..b5f1b9a3c495 100644 --- a/writer.go +++ b/writer.go @@ -1,66 +1,240 @@ package log import ( + "fmt" "io" "sync" - "time" ) +var MaxWriterBuffer = 512 * 1024 + +var log = Logger("eventlog") + type MirrorWriter struct { - writers []io.WriteCloser - lk sync.Mutex + active bool + activelk sync.Mutex + + // channel for incoming writers + writerAdd chan io.WriteCloser + + // slices of writer/sync-channel pairs + writers []*bufWriter + + // synchronization channel for incoming writes + msgSync chan []byte +} + +type writerSync struct { + w io.WriteCloser + br chan []byte +} + +func NewMirrorWriter() *MirrorWriter { + mw := &MirrorWriter{ + msgSync: make(chan []byte, 64), // sufficiently large buffer to avoid callers waiting + writerAdd: make(chan io.WriteCloser), + } + + go mw.logRoutine() + + return mw } func (mw *MirrorWriter) Write(b []byte) (int, error) { - mw.lk.Lock() - // write to all writers, and nil out the broken ones. - var dropped bool - done := make(chan error, 1) - for i, w := range mw.writers { - go func(out chan error) { - _, err := w.Write(b) - out <- err - }(done) + mycopy := make([]byte, len(b)) + copy(mycopy, b) + mw.msgSync <- mycopy + return len(b), nil +} + +func (mw *MirrorWriter) Close() error { + // it is up to the caller to ensure that write is not called during or + // after close is called. + close(mw.msgSync) + return nil +} + +func (mw *MirrorWriter) doClose() { + for _, w := range mw.writers { + w.writer.Close() + } +} + +func (mw *MirrorWriter) logRoutine() { + // rebind to avoid races on nilling out struct fields + msgSync := mw.msgSync + writerAdd := mw.writerAdd + + defer mw.doClose() + + for { select { - case err := <-done: - if err != nil { - mw.writers[i].Close() - mw.writers[i] = nil - dropped = true + case b, ok := <-msgSync: + if !ok { + return } - case <-time.After(time.Millisecond * 500): - mw.writers[i].Close() + + // write to all writers + dropped := mw.broadcastMessage(b) + + // consolidate the slice + if dropped { + mw.clearDeadWriters() + } + case w := <-writerAdd: + mw.writers = append(mw.writers, newBufWriter(w)) + + mw.activelk.Lock() + mw.active = true + mw.activelk.Unlock() + } + } +} + +// broadcastMessage sends the given message to every writer +// if any writer is killed during the send, 'true' is returned +func (mw *MirrorWriter) broadcastMessage(b []byte) bool { + var dropped bool + for i, w := range mw.writers { + _, err := w.Write(b) + if err != nil { mw.writers[i] = nil dropped = true - - // clear channel out - done = make(chan error, 1) } } + return dropped +} - // consolidate the slice - if dropped { - writers := mw.writers - mw.writers = nil - for _, w := range writers { - if w != nil { - mw.writers = append(mw.writers, w) - } +func (mw *MirrorWriter) clearDeadWriters() { + writers := mw.writers + mw.writers = nil + for _, w := range writers { + if w != nil { + mw.writers = append(mw.writers, w) } } - mw.lk.Unlock() - return len(b), nil + if len(mw.writers) == 0 { + mw.activelk.Lock() + mw.active = false + mw.activelk.Unlock() + } } func (mw *MirrorWriter) AddWriter(w io.WriteCloser) { - mw.lk.Lock() - mw.writers = append(mw.writers, w) - mw.lk.Unlock() + mw.writerAdd <- w } func (mw *MirrorWriter) Active() (active bool) { - mw.lk.Lock() - active = len(mw.writers) > 0 - mw.lk.Unlock() + mw.activelk.Lock() + active = mw.active + mw.activelk.Unlock() return } + +func newBufWriter(w io.WriteCloser) *bufWriter { + bw := &bufWriter{ + writer: w, + incoming: make(chan []byte, 1), + } + + go bw.loop() + return bw +} + +type bufWriter struct { + writer io.WriteCloser + + incoming chan []byte + + deathLock sync.Mutex + dead bool +} + +var errDeadWriter = fmt.Errorf("writer is dead") + +func (bw *bufWriter) Write(b []byte) (int, error) { + bw.deathLock.Lock() + dead := bw.dead + bw.deathLock.Unlock() + if dead { + if bw.incoming != nil { + close(bw.incoming) + bw.incoming = nil + } + return 0, errDeadWriter + } + + bw.incoming <- b + return len(b), nil +} + +func (bw *bufWriter) die() { + bw.deathLock.Lock() + bw.dead = true + bw.writer.Close() + bw.deathLock.Unlock() +} + +func (bw *bufWriter) loop() { + bufsize := 0 + bufBase := make([][]byte, 0, 16) // some initial memory + buffered := bufBase + nextCh := make(chan []byte) + + var nextMsg []byte + + go func() { + for b := range nextCh { + _, err := bw.writer.Write(b) + if err != nil { + log.Info("eventlog write error: %s", err) + bw.die() + return + } + } + }() + + // collect and buffer messages + incoming := bw.incoming + for { + if nextMsg == nil || nextCh == nil { + // nextCh == nil implies we are 'dead' and draining the incoming channel + // until the caller notices and closes it for us + select { + case b, ok := <-incoming: + if !ok { + return + } + nextMsg = b + } + } + + select { + case b, ok := <-incoming: + if !ok { + return + } + bufsize += len(b) + buffered = append(buffered, b) + if bufsize > MaxWriterBuffer { + // if we have too many messages buffered, kill the writer + bw.die() + close(nextCh) + nextCh = nil + // explicity keep going here to drain incoming + } + case nextCh <- nextMsg: + nextMsg = nil + if len(buffered) > 0 { + nextMsg = buffered[0] + buffered = buffered[1:] + bufsize -= len(nextMsg) + } + + if len(buffered) == 0 { + // reset slice position + buffered = bufBase[:0] + } + } + } +} diff --git a/writer_test.go b/writer_test.go new file mode 100644 index 000000000000..55466abc05d1 --- /dev/null +++ b/writer_test.go @@ -0,0 +1,160 @@ +package log + +import ( + "fmt" + "hash/fnv" + "io" + "sync" + "testing" + "time" + + randbo "github.com/dustin/randbo" +) + +type hangwriter struct { + c chan struct{} +} + +func newHangWriter() *hangwriter { + return &hangwriter{make(chan struct{})} +} + +func (hw *hangwriter) Write([]byte) (int, error) { + <-make(chan struct{}) + return 0, fmt.Errorf("write on closed writer") +} + +func (hw *hangwriter) Close() error { + close(hw.c) + return nil +} + +func TestMirrorWriterHang(t *testing.T) { + mw := NewMirrorWriter() + + hw := newHangWriter() + pr, pw := io.Pipe() + + mw.AddWriter(hw) + mw.AddWriter(pw) + + msg := "Hello!" + mw.Write([]byte(msg)) + + // make sure writes through can happen even with one writer hanging + done := make(chan struct{}) + go func() { + buf := make([]byte, 10) + n, err := pr.Read(buf) + if err != nil { + t.Fatal(err) + } + + if n != len(msg) { + t.Fatal("read wrong amount") + } + + if string(buf[:n]) != msg { + t.Fatal("didnt read right content") + } + + done <- struct{}{} + }() + + select { + case <-time.After(time.Second * 5): + t.Fatal("write to mirrorwriter hung") + case <-done: + } + + if !mw.Active() { + t.Fatal("writer should still be active") + } + + pw.Close() + + if !mw.Active() { + t.Fatal("writer should still be active") + } + + // now we just have the hangwriter + + // write a bunch to it + buf := make([]byte, 8192) + for i := 0; i < 128; i++ { + mw.Write(buf) + } + + // wait for goroutines to sync up + time.Sleep(time.Millisecond * 500) + + // the hangwriter should have been killed, causing the mirrorwriter to be inactive now + if mw.Active() { + t.Fatal("should be inactive now") + } +} + +func TestStress(t *testing.T) { + mw := NewMirrorWriter() + + nreaders := 20 + + var readers []io.Reader + for i := 0; i < nreaders; i++ { + pr, pw := io.Pipe() + mw.AddWriter(pw) + readers = append(readers, pr) + } + + hashout := make(chan []byte) + + numwriters := 20 + writesize := 1024 + writecount := 300 + + f := func(r io.Reader) { + h := fnv.New64a() + sum, err := io.Copy(h, r) + if err != nil { + t.Fatal(err) + } + + if sum != int64(numwriters*writesize*writecount) { + t.Fatal("read wrong number of bytes") + } + + hashout <- h.Sum(nil) + } + + for _, r := range readers { + go f(r) + } + + work := sync.WaitGroup{} + for i := 0; i < numwriters; i++ { + work.Add(1) + go func() { + defer work.Done() + r := randbo.New() + buf := make([]byte, writesize) + for j := 0; j < writecount; j++ { + r.Read(buf) + mw.Write(buf) + time.Sleep(time.Millisecond * 5) + } + }() + } + + work.Wait() + mw.Close() + + check := make(map[string]bool) + for i := 0; i < nreaders; i++ { + h := <-hashout + check[string(h)] = true + } + + if len(check) > 1 { + t.Fatal("writers received different data!") + } +}