diff --git a/manual.go b/manual.go index d5fdb9d..9e14787 100644 --- a/manual.go +++ b/manual.go @@ -1,35 +1,71 @@ package progress +import ( + "sync" + "sync/atomic" +) + type Manual struct { - N int64 - Total int64 - Err error + n int64 + total int64 + err error + errMutex sync.Mutex +} + +func NewManual(size int64) *Manual { + return &Manual{ + total: size, + } } -func (p Manual) Current() int64 { - return int64(p.N) +func (p *Manual) Current() int64 { + return atomic.LoadInt64(&p.n) } -func (p Manual) Size() int64 { - return int64(p.Total) +func (p *Manual) Size() int64 { + return atomic.LoadInt64(&p.total) } -func (p Manual) Error() error { - return p.Err +func (p *Manual) Error() error { + p.errMutex.Lock() + defer p.errMutex.Unlock() + return p.err } -func (p Manual) Progress() Progress { +func (p *Manual) SetError(err error) { + p.errMutex.Lock() + defer p.errMutex.Unlock() + p.err = err +} + +func (p *Manual) Progress() Progress { return Progress{ - current: p.N, - size: p.Total, - err: p.Err, + current: p.Current(), + size: p.Size(), + err: p.Error(), } } +func (p *Manual) Add(n int64) { + atomic.AddInt64(&p.n, n) +} + +func (p *Manual) Increment() { + atomic.AddInt64(&p.n, 1) +} + +func (p *Manual) Set(n int64) { + atomic.StoreInt64(&p.n, n) +} + +func (p *Manual) SetTotal(total int64) { + atomic.StoreInt64(&p.total, total) +} + func (p *Manual) SetCompleted() { - p.Err = ErrCompleted - if p.N > 0 && p.Total <= 0 { - p.Total = p.N + p.SetError(ErrCompleted) + if p.Current() > 0 && p.Size() <= 0 { + p.SetTotal(p.Current()) return } } diff --git a/reader.go b/reader.go index 4b6c07a..f782662 100644 --- a/reader.go +++ b/reader.go @@ -8,31 +8,27 @@ import ( // Reader should wrap another reader (acts as a bytes pass through) type Reader struct { - reader io.Reader + reader io.Reader monitor *Manual } func NewSizedReader(reader io.Reader, size int64) *Reader { return &Reader{ - reader: reader, - monitor: &Manual{ - Total: size, - }, + reader: reader, + monitor: NewManual(size), } } func NewReader(reader io.Reader) *Reader { return &Reader{ - reader: reader, - monitor: &Manual{ - Total: -1, - }, + reader: reader, + monitor: NewManual(-1), } } func NewProxyReader(reader io.Reader, monitor *Manual) *Reader { return &Reader{ - reader: reader, + reader: reader, monitor: monitor, } } @@ -42,26 +38,26 @@ func (r *Reader) SetReader(reader io.Reader) { } func (r *Reader) SetCompleted() { - r.monitor.Err = multierror.Append(r.monitor.Err, ErrCompleted) + r.monitor.SetError(multierror.Append(r.monitor.Error(), ErrCompleted)) } func (r *Reader) Read(p []byte) (n int, err error) { bytes, err := r.reader.Read(p) - r.monitor.N += int64(bytes) + r.monitor.Add(int64(bytes)) if err != nil { - r.monitor.Err = multierror.Append(r.monitor.Err, err) + r.monitor.SetError(multierror.Append(r.monitor.Error(), err)) } return bytes, err } func (r *Reader) Current() int64 { - return r.monitor.N + return r.monitor.Current() } func (r *Reader) Size() int64 { - return r.monitor.Total + return r.monitor.Size() } func (r *Reader) Error() error { - return r.monitor.Err + return r.monitor.Error() }