From a2c4bad40c092ed6fd45c94cbda11e9ffba2d7a6 Mon Sep 17 00:00:00 2001 From: Travis Bischel Date: Tue, 3 Jan 2023 15:19:38 -0700 Subject: [PATCH] kgo: universally switch to 1.19's atomics if on Go 1.19+ The current lint on arm should be ensuring alignment is proper, but apparently that is not always the case, as seen in #286. Go has compiler intrinsics to ensure proper alignment for the actual atomic number types introduced in 1.19. This doesn't fix 1.18, but it should fix 1.19+. Closes #286. --- pkg/kgo/atomic_maybe_work.go | 18 ++++++------ pkg/kgo/broker.go | 54 ++++++++++++++++++------------------ pkg/kgo/consumer.go | 8 +++--- pkg/kgo/go118.go | 31 +++++++++++++++++++++ pkg/kgo/go119.go | 10 +++++-- pkg/kgo/group_test.go | 5 ++-- pkg/kgo/helpers_test.go | 7 ++--- pkg/kgo/partitioner.go | 3 +- pkg/kgo/producer.go | 44 ++++++++++++++--------------- pkg/kgo/record_formatter.go | 5 ++-- pkg/kgo/sink.go | 36 ++++++++++++------------ pkg/kgo/source.go | 17 ++++++------ pkg/kgo/txn.go | 9 +++--- pkg/kgo/txn_test.go | 5 ++-- 14 files changed, 139 insertions(+), 113 deletions(-) diff --git a/pkg/kgo/atomic_maybe_work.go b/pkg/kgo/atomic_maybe_work.go index 15fddef5..10e51d6e 100644 --- a/pkg/kgo/atomic_maybe_work.go +++ b/pkg/kgo/atomic_maybe_work.go @@ -1,26 +1,24 @@ package kgo -import "sync/atomic" - const ( stateUnstarted = iota stateWorking stateContinueWorking ) -type workLoop struct{ state uint32 } +type workLoop struct{ state atomicU32 } // maybeBegin returns whether a work loop should begin. func (l *workLoop) maybeBegin() bool { var state uint32 var done bool for !done { - switch state = atomic.LoadUint32(&l.state); state { + switch state = l.state.Load(); state { case stateUnstarted: - done = atomic.CompareAndSwapUint32(&l.state, state, stateWorking) + done = l.state.CompareAndSwap(state, stateWorking) state = stateWorking case stateWorking: - done = atomic.CompareAndSwapUint32(&l.state, state, stateContinueWorking) + done = l.state.CompareAndSwap(state, stateContinueWorking) state = stateContinueWorking case stateContinueWorking: done = true @@ -43,18 +41,18 @@ func (l *workLoop) maybeBegin() bool { // since the loop itself calls MaybeFinish after it has been started, this // should never be called if the loop is unstarted. func (l *workLoop) maybeFinish(again bool) bool { - switch state := atomic.LoadUint32(&l.state); state { + switch state := l.state.Load(); state { // Working: // If again, we know we should continue; keep our state. // If not again, we try to downgrade state and stop. // If we cannot, then something slipped in to say keep going. case stateWorking: if !again { - again = !atomic.CompareAndSwapUint32(&l.state, state, stateUnstarted) + again = !l.state.CompareAndSwap(state, stateUnstarted) } // Continue: demote ourself and run again no matter what. case stateContinueWorking: - atomic.StoreUint32(&l.state, stateWorking) + l.state.Store(stateWorking) again = true } @@ -62,5 +60,5 @@ func (l *workLoop) maybeFinish(again bool) bool { } func (l *workLoop) hardFinish() { - atomic.StoreUint32(&l.state, stateUnstarted) + l.state.Store(stateUnstarted) } diff --git a/pkg/kgo/broker.go b/pkg/kgo/broker.go index d26c0aeb..cd525078 100644 --- a/pkg/kgo/broker.go +++ b/pkg/kgo/broker.go @@ -155,7 +155,7 @@ type broker struct { // reqs manages incoming message requests. reqs ringReq // dead is an atomic so a backed up reqs cannot block broker stoppage. 
- dead int32 + dead atomicBool } // brokerVersions is loaded once (and potentially a few times concurrently if @@ -214,7 +214,7 @@ func (cl *Client) newBroker(nodeID int32, host string, port int32, rack *string) // stopForever permanently disables this broker. func (b *broker) stopForever() { - if atomic.SwapInt32(&b.dead, 1) == 1 { + if b.dead.Swap(true) { return } @@ -502,7 +502,7 @@ func (b *broker) loadConnection(ctx context.Context, req kmsg.Request) (*brokerC pcxn = &b.cxnSlow } - if *pcxn != nil && atomic.LoadInt32(&(*pcxn).dead) == 0 { + if *pcxn != nil && !(*pcxn).dead.Load() { return *pcxn, nil } @@ -581,7 +581,7 @@ func (b *broker) reapConnections(idleTimeout time.Duration) (total int) { b.cxnGroup, b.cxnSlow, } { - if cxn == nil || atomic.LoadInt32(&cxn.dead) == 1 { + if cxn == nil || cxn.dead.Load() { continue } @@ -592,11 +592,11 @@ func (b *broker) reapConnections(idleTimeout time.Duration) (total int) { // - produce can write but never read // - fetch can hang for a while reading (infrequent writes) - lastWrite := time.Unix(0, atomic.LoadInt64(&cxn.lastWrite)) - lastRead := time.Unix(0, atomic.LoadInt64(&cxn.lastRead)) + lastWrite := time.Unix(0, cxn.lastWrite.Load()) + lastRead := time.Unix(0, cxn.lastRead.Load()) - writeIdle := time.Since(lastWrite) > idleTimeout && atomic.LoadUint32(&cxn.writing) == 0 - readIdle := time.Since(lastRead) > idleTimeout && atomic.LoadUint32(&cxn.reading) == 0 + writeIdle := time.Since(lastWrite) > idleTimeout && !cxn.writing.Load() + readIdle := time.Since(lastRead) > idleTimeout && !cxn.reading.Load() if writeIdle && readIdle { cxn.die() @@ -634,7 +634,7 @@ func (b *broker) connect(ctx context.Context) (net.Conn, error) { // brokerCxn manages an actual connection to a Kafka broker. This is separate // the broker struct to allow lazy connection (re)creation. type brokerCxn struct { - throttleUntil int64 // atomic nanosec + throttleUntil atomicI64 // atomic nanosec conn net.Conn @@ -651,17 +651,17 @@ type brokerCxn struct { // The following four fields are used for connection reaping. // Write is only updated in one location; read is updated in three // due to readConn, readConnAsync, and discard. - lastWrite int64 - lastRead int64 - writing uint32 - reading uint32 + lastWrite atomicI64 + lastRead atomicI64 + writing atomicBool + reading atomicBool successes uint64 // resps manages reading kafka responses. resps ringResp // dead is an atomic so that a backed up resps cannot block cxn death. - dead int32 + dead atomicBool // closed in cloneConn; allows throttle waiting to quit deadCh chan struct{} } @@ -982,7 +982,7 @@ func maybeUpdateCtxErr(clientCtx, reqCtx context.Context, err *error) { func (cxn *brokerCxn) writeRequest(ctx context.Context, enqueuedForWritingAt time.Time, req kmsg.Request) (corrID int32, bytesWritten int, writeWait, timeToWrite time.Duration, readEnqueue time.Time, writeErr error) { // A nil ctx means we cannot be throttled. 
if ctx != nil { - throttleUntil := time.Unix(0, atomic.LoadInt64(&cxn.throttleUntil)) + throttleUntil := time.Unix(0, cxn.throttleUntil.Load()) if sleep := time.Until(throttleUntil); sleep > 0 { after := time.NewTimer(sleep) select { @@ -1037,10 +1037,10 @@ func (cxn *brokerCxn) writeConn( timeout time.Duration, enqueuedForWritingAt time.Time, ) (bytesWritten int, writeWait, timeToWrite time.Duration, readEnqueue time.Time, writeErr error) { - atomic.SwapUint32(&cxn.writing, 1) + cxn.writing.Store(true) defer func() { - atomic.StoreInt64(&cxn.lastWrite, time.Now().UnixNano()) - atomic.SwapUint32(&cxn.writing, 0) + cxn.lastWrite.Store(time.Now().UnixNano()) + cxn.writing.Store(false) }() if ctx == nil { @@ -1085,10 +1085,10 @@ func (cxn *brokerCxn) readConn( timeout time.Duration, enqueuedForReadingAt time.Time, ) (nread int, buf []byte, readWait, timeToRead time.Duration, err error) { - atomic.SwapUint32(&cxn.reading, 1) + cxn.reading.Store(true) defer func() { - atomic.StoreInt64(&cxn.lastRead, time.Now().UnixNano()) - atomic.SwapUint32(&cxn.reading, 0) + cxn.lastRead.Store(time.Now().UnixNano()) + cxn.reading.Store(false) }() if ctx == nil { @@ -1256,7 +1256,7 @@ func (cxn *brokerCxn) closeConn() { // die kills a broker connection (which could be dead already) and replies to // all requests awaiting responses appropriately. func (cxn *brokerCxn) die() { - if cxn == nil || atomic.SwapInt32(&cxn.dead, 1) == 1 { + if cxn == nil || cxn.dead.Swap(true) { return } cxn.closeConn() @@ -1364,10 +1364,10 @@ func (cxn *brokerCxn) discard() { } deadlineMu.Unlock() - atomic.SwapUint32(&cxn.reading, 1) + cxn.reading.Store(true) defer func() { - atomic.StoreInt64(&cxn.lastRead, time.Now().UnixNano()) - atomic.SwapUint32(&cxn.reading, 0) + cxn.lastRead.Store(time.Now().UnixNano()) + cxn.reading.Store(false) }() readStart := time.Now() @@ -1470,8 +1470,8 @@ func (cxn *brokerCxn) handleResp(pr promisedResp) { if millis > 0 { if throttlesAfterResp { throttleUntil := time.Now().Add(time.Millisecond * time.Duration(millis)).UnixNano() - if throttleUntil > cxn.throttleUntil { - atomic.StoreInt64(&cxn.throttleUntil, throttleUntil) + if throttleUntil > cxn.throttleUntil.Load() { + cxn.throttleUntil.Store(throttleUntil) } } cxn.cl.cfg.hooks.each(func(h Hook) { diff --git a/pkg/kgo/consumer.go b/pkg/kgo/consumer.go index 92c72c3a..f5ceed1e 100644 --- a/pkg/kgo/consumer.go +++ b/pkg/kgo/consumer.go @@ -152,7 +152,7 @@ func (o Offset) At(at int64) Offset { } type consumer struct { - bufferedRecords int64 + bufferedRecords atomicI64 cl *Client @@ -272,7 +272,7 @@ func (c *consumer) unaddRebalance() { // problematic if for you if this function is consistently returning large // values. func (cl *Client) BufferedFetchRecords() int64 { - return atomic.LoadInt64(&cl.consumer.bufferedRecords) + return cl.consumer.bufferedRecords.Load() } type usedCursors map[*cursor]struct{} @@ -1224,7 +1224,7 @@ type consumerSession struct { desireFetchCh chan chan chan struct{} cancelFetchCh chan chan chan struct{} allowedFetches int - fetchManagerStarted uint32 // atomic, once 1, we start the fetch manager + fetchManagerStarted atomicBool // atomic, once true, we start the fetch manager // Workers signify the number of fetch and list / epoch goroutines that // are currently running within the context of this consumer session. 
@@ -1278,7 +1278,7 @@ func (c *consumer) newConsumerSession(tps *topicsPartitions) *consumerSession { } func (s *consumerSession) desireFetch() chan chan chan struct{} { - if atomic.SwapUint32(&s.fetchManagerStarted, 1) == 0 { + if !s.fetchManagerStarted.Swap(true) { go s.manageFetchConcurrency() } return s.desireFetchCh diff --git a/pkg/kgo/go118.go b/pkg/kgo/go118.go index 46b29b47..483c3e91 100644 --- a/pkg/kgo/go118.go +++ b/pkg/kgo/go118.go @@ -24,3 +24,34 @@ func (b *atomicBool) Swap(v bool) bool { } return atomic.SwapUint32((*uint32)(b), swap) == 1 } + +type atomicI32 int32 + +func (v *atomicI32) Add(s int32) int32 { return atomic.AddInt32((*int32)(v), s) } +func (v *atomicI32) Store(s int32) { atomic.StoreInt32((*int32)(v), s) } +func (v *atomicI32) Load() int32 { return atomic.LoadInt32((*int32)(v)) } +func (v *atomicI32) Swap(s int32) int32 { return atomic.SwapInt32((*int32)(v), s) } + +type atomicU32 uint32 + +func (v *atomicU32) Add(s uint32) uint32 { return atomic.AddUint32((*uint32)(v), s) } +func (v *atomicU32) Store(s uint32) { atomic.StoreUint32((*uint32)(v), s) } +func (v *atomicU32) Load() uint32 { return atomic.LoadUint32((*uint32)(v)) } +func (v *atomicU32) Swap(s uint32) uint32 { return atomic.SwapUint32((*uint32)(v), s) } +func (v *atomicU32) CompareAndSwap(old, new uint32) bool { + return atomic.CompareAndSwapUint32((*uint32)(v), old, new) +} + +type atomicI64 int64 + +func (v *atomicI64) Add(s int64) int64 { return atomic.AddInt64((*int64)(v), s) } +func (v *atomicI64) Store(s int64) { atomic.StoreInt64((*int64)(v), s) } +func (v *atomicI64) Load() int64 { return atomic.LoadInt64((*int64)(v)) } +func (v *atomicI64) Swap(s int64) int64 { return atomic.SwapInt64((*int64)(v), s) } + +type atomicU64 uint64 + +func (v *atomicU64) Add(s uint64) uint64 { return atomic.AddUint64((*uint64)(v), s) } +func (v *atomicU64) Store(s uint64) { atomic.StoreUint64((*uint64)(v), s) } +func (v *atomicU64) Load() uint64 { return atomic.LoadUint64((*uint64)(v)) } +func (v *atomicU64) Swap(s uint64) uint64 { return atomic.SwapUint64((*uint64)(v), s) } diff --git a/pkg/kgo/go119.go b/pkg/kgo/go119.go index ac35e915..7c8ade5e 100644 --- a/pkg/kgo/go119.go +++ b/pkg/kgo/go119.go @@ -5,6 +5,10 @@ package kgo import "sync/atomic" -type atomicBool struct { - atomic.Bool -} +type ( + atomicBool struct{ atomic.Bool } + atomicI32 struct{ atomic.Int32 } + atomicU32 struct{ atomic.Uint32 } + atomicI64 struct{ atomic.Int64 } + atomicU64 struct{ atomic.Uint64 } +) diff --git a/pkg/kgo/group_test.go b/pkg/kgo/group_test.go index 19c6ca86..4a03b013 100644 --- a/pkg/kgo/group_test.go +++ b/pkg/kgo/group_test.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "strconv" - "sync/atomic" "testing" "time" ) @@ -178,7 +177,7 @@ func (c *testConsumer) etl(etlsBeforeQuit int) { fetches := cl.PollRecords(ctx, 100) cancel() if fetches.Err() == context.DeadlineExceeded || fetches.Err() == ErrClientClosed { - if consumed := int(atomic.LoadUint64(&c.consumed)); consumed == testRecordLimit { + if consumed := int(c.consumed.Load()); consumed == testRecordLimit { return } else if consumed > testRecordLimit { panic(fmt.Sprintf("invalid: consumed too much from %s (group %s)", c.consumeFrom, c.group)) @@ -217,7 +216,7 @@ func (c *testConsumer) etl(etlsBeforeQuit int) { c.mu.Unlock() - atomic.AddUint64(&c.consumed, 1) + c.consumed.Add(1) cl.Produce( context.Background(), diff --git a/pkg/kgo/helpers_test.go b/pkg/kgo/helpers_test.go index 5d36194a..85c42cf4 100644 --- a/pkg/kgo/helpers_test.go +++ b/pkg/kgo/helpers_test.go @@ -12,7 
+12,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "testing" "time" @@ -66,7 +65,7 @@ func getSeedBrokers() Opt { return SeedBrokers(strings.Split(seeds, ",")...) } -var loggerNum int64 +var loggerNum atomicI64 var testLogLevel = func() LogLevel { level := strings.ToLower(os.Getenv("KGO_LOG_LEVEL")) @@ -80,7 +79,7 @@ var testLogLevel = func() LogLevel { }() func testLogger() Logger { - num := atomic.AddInt64(&loggerNum, 1) + num := loggerNum.Add(1) pfx := strconv.Itoa(int(num)) return BasicLogger(os.Stderr, testLogLevel, func() string { return time.Now().Format("[15:04:05 ") + pfx + "]" @@ -193,7 +192,7 @@ type testConsumer struct { expBody []byte // what every record body should be - consumed uint64 // shared atomically + consumed atomicU64 // shared atomically wg sync.WaitGroup mu sync.Mutex diff --git a/pkg/kgo/partitioner.go b/pkg/kgo/partitioner.go index 793f98f0..a2123698 100644 --- a/pkg/kgo/partitioner.go +++ b/pkg/kgo/partitioner.go @@ -3,7 +3,6 @@ package kgo import ( "math" "math/rand" - "sync/atomic" "time" "github.com/twmb/franz-go/pkg/kbin" @@ -200,7 +199,7 @@ type ( func (i *leastBackupInput) Next() (int, int64) { last := len(i.mapping) - 1 - buffered := atomic.LoadInt64(&i.mapping[last].records.buffered) + buffered := i.mapping[last].records.buffered.Load() i.mapping = i.mapping[:last] return last, buffered } diff --git a/pkg/kgo/producer.go b/pkg/kgo/producer.go index 02f3bd1a..72bfdfad 100644 --- a/pkg/kgo/producer.go +++ b/pkg/kgo/producer.go @@ -14,8 +14,8 @@ import ( ) type producer struct { - bufferedRecords int64 - inflight int64 // high 16: # waiters, low 48: # inflight + bufferedRecords atomicI64 + inflight atomicI64 // high 16: # waiters, low 48: # inflight cl *Client @@ -38,14 +38,14 @@ type producer struct { unknownTopics map[string]*unknownTopicProduces id atomic.Value - producingTxn uint32 // 1 if in txn + producingTxn atomicBool // We must have a producer field for flushing; we cannot just have a // field on recBufs that is toggled on flush. If we did, then a new // recBuf could be created and records sent to while we are flushing. - flushing int32 // >0 if flushing, can Flush many times concurrently + flushing atomicI32 // >0 if flushing, can Flush many times concurrently - aborting int32 // >0 if aborting, can abort many times concurrently + aborting atomicI32 // >0 if aborting, can abort many times concurrently idMu sync.Mutex idVersion int16 @@ -83,7 +83,7 @@ type producer struct { // flushing records produced by your client (which can help determine network / // cluster health). func (cl *Client) BufferedProduceRecords() int64 { - return atomic.LoadInt64(&cl.producer.bufferedRecords) + return cl.producer.bufferedRecords.Load() } type unknownTopicProduces struct { @@ -184,7 +184,7 @@ func (p *producer) purgeTopics(topics []string) { } } -func (p *producer) isAborting() bool { return atomic.LoadInt32(&p.aborting) > 0 } +func (p *producer) isAborting() bool { return p.aborting.Load() > 0 } func noPromise(*Record, error) {} @@ -258,7 +258,7 @@ func (cl *Client) ProduceSync(ctx context.Context, rs ...*Record) ProduceResults // This is similar to using ProduceResult's FirstErr function. type FirstErrPromise struct { wg sync.WaitGroup - once uint32 + once atomicBool err error cl *Client } @@ -278,7 +278,7 @@ func AbortingFirstErrPromise(cl *Client) *FirstErrPromise { // encountered. 
func (f *FirstErrPromise) promise(_ *Record, err error) { defer f.wg.Done() - if err != nil && atomic.SwapUint32(&f.once, 1) == 0 { + if err != nil && !f.once.Swap(true) { f.err = err if f.cl != nil { f.wg.Add(1) @@ -385,12 +385,12 @@ func (cl *Client) produce( p.promiseRecord(promisedRec{ctx, promise, r}, errNoTopic) return } - if cl.cfg.txnID != nil && atomic.LoadUint32(&p.producingTxn) != 1 { + if cl.cfg.txnID != nil && !p.producingTxn.Load() { p.promiseRecord(promisedRec{ctx, promise, r}, errNotInTransaction) return } - if atomic.AddInt64(&p.bufferedRecords, 1) > cl.cfg.maxBufferedRecords { + if p.bufferedRecords.Add(1) > cl.cfg.maxBufferedRecords { // If the client ctx cancels or the produce ctx cancels, we // need to un-count our buffering of this record. We also need // to drain a slot from the waitBuffer chan, which could be @@ -476,10 +476,10 @@ func (cl *Client) finishRecordPromise(pr promisedRec, err error) { // before Flush returns. pr.promise(pr.Record, err) - buffered := atomic.AddInt64(&p.bufferedRecords, -1) + buffered := p.bufferedRecords.Add(-1) if buffered >= cl.cfg.maxBufferedRecords { p.waitBuffer <- struct{}{} - } else if buffered == 0 && atomic.LoadInt32(&p.flushing) > 0 { + } else if buffered == 0 && p.flushing.Load() > 0 { p.mu.Lock() p.mu.Unlock() //nolint:gocritic,staticcheck // We use the lock as a barrier, unlocking immediately is safe. p.c.Broadcast() @@ -918,8 +918,8 @@ func (cl *Client) Flush(ctx context.Context) error { // Signal to finishRecord that we want to be notified once buffered hits 0. // Also forbid any new producing to start a linger. - atomic.AddInt32(&p.flushing, 1) - defer atomic.AddInt32(&p.flushing, -1) + p.flushing.Add(1) + defer p.flushing.Add(-1) cl.cfg.logger.Log(LogLevelInfo, "flushing") defer cl.cfg.logger.Log(LogLevelDebug, "flushed") @@ -943,7 +943,7 @@ func (cl *Client) Flush(ctx context.Context) error { defer p.mu.Unlock() defer close(done) - for !quit && atomic.LoadInt64(&p.bufferedRecords) > 0 { + for !quit && p.bufferedRecords.Load() > 0 { p.c.Wait() } }() @@ -961,7 +961,7 @@ func (cl *Client) Flush(ctx context.Context) error { } func (p *producer) pause(ctx context.Context) error { - atomic.AddInt64(&p.inflight, 1<<48) + p.inflight.Add(1 << 48) quit := false done := make(chan struct{}) @@ -969,7 +969,7 @@ func (p *producer) pause(ctx context.Context) error { p.mu.Lock() defer p.mu.Unlock() defer close(done) - for !quit && atomic.LoadInt64(&p.inflight)&((1<<48)-1) != 0 { + for !quit && p.inflight.Load()&((1<<48)-1) != 0 { p.c.Wait() } }() @@ -988,7 +988,7 @@ func (p *producer) pause(ctx context.Context) error { } func (p *producer) resume() { - if atomic.AddInt64(&p.inflight, -1<<48) == 0 { + if p.inflight.Add(-1<<48) == 0 { p.cl.allSinksAndSources(func(sns sinkAndSource) { sns.sink.maybeDrain() }) @@ -996,10 +996,10 @@ func (p *producer) resume() { } func (p *producer) maybeAddInflight() bool { - if atomic.LoadInt64(&p.inflight)>>48 > 0 { + if p.inflight.Load()>>48 > 0 { return false } - if atomic.AddInt64(&p.inflight, 1)>>48 > 0 { + if p.inflight.Add(1)>>48 > 0 { p.decInflight() return false } @@ -1007,7 +1007,7 @@ func (p *producer) maybeAddInflight() bool { } func (p *producer) decInflight() { - if atomic.AddInt64(&p.inflight, -1)>>48 > 0 { + if p.inflight.Add(-1)>>48 > 0 { p.c.Broadcast() } } diff --git a/pkg/kgo/record_formatter.go b/pkg/kgo/record_formatter.go index 8e3dade9..06e4fec9 100644 --- a/pkg/kgo/record_formatter.go +++ b/pkg/kgo/record_formatter.go @@ -13,7 +13,6 @@ import ( "regexp" "strconv" "strings" - 
"sync/atomic" "time" "unicode/utf8" @@ -26,7 +25,7 @@ import ( // RecordFormatter formats records. type RecordFormatter struct { - calls int64 + calls atomicI64 fns []func([]byte, *FetchPartition, *Record) []byte } @@ -339,7 +338,7 @@ func NewRecordFormatter(layout string) (*RecordFormatter, error) { }) case 'i': f.fns = append(f.fns, func(b []byte, _ *FetchPartition, _ *Record) []byte { - return numfn(b, atomic.AddInt64(&f.calls, 1)) + return numfn(b, f.calls.Add(1)) }) case 'x': f.fns = append(f.fns, func(b []byte, _ *FetchPartition, r *Record) []byte { diff --git a/pkg/kgo/sink.go b/pkg/kgo/sink.go index ea7c4eb4..8ac9babf 100644 --- a/pkg/kgo/sink.go +++ b/pkg/kgo/sink.go @@ -26,7 +26,7 @@ type sink struct { // response, we check what version was set in the request. If it is at // least 4, which 1.0 introduced, we upgrade the sem size. inflightSem atomic.Value - produceVersion int32 // atomic, negative is unset, positive is version + produceVersion atomicI32 // negative is unset, positive is version drainState workLoop @@ -43,7 +43,7 @@ type sink struct { // successful response. For simplicity, if we have a good response // following an error response before the error response's backoff // occurs, the backoff is not cleared. - consecutiveFailures uint32 + consecutiveFailures atomicU32 recBufsMu sync.Mutex // guards the following recBufs []*recBuf // contains all partition records for batch building @@ -60,10 +60,10 @@ type seqResp struct { func (cl *Client) newSink(nodeID int32) *sink { s := &sink{ - cl: cl, - nodeID: nodeID, - produceVersion: -1, + cl: cl, + nodeID: nodeID, } + s.produceVersion.Store(-1) maxInflight := 1 if cl.cfg.disableIdempotency { maxInflight = cl.cfg.maxProduceInflight @@ -113,7 +113,7 @@ func (s *sink) createReq(id int64, epoch int16) (*produceRequest, *kmsg.AddParti } batch := recBuf.batches[recBuf.batchDrainIdx] - if added := req.tryAddBatch(atomic.LoadInt32(&s.produceVersion), recBuf, batch); !added { + if added := req.tryAddBatch(s.produceVersion.Load(), recBuf, batch); !added { recBuf.mu.Unlock() moreToDrain = true continue @@ -181,7 +181,7 @@ func (t *txnReqBuilder) add(rb *recBuf) { } func (s *sink) maybeDrain() { - if s.cl.cfg.manualFlushing && atomic.LoadInt32(&s.cl.producer.flushing) == 0 { + if s.cl.cfg.manualFlushing && s.cl.producer.flushing.Load() == 0 { return } if s.drainState.maybeBegin() { @@ -201,7 +201,7 @@ func (s *sink) maybeBackoff() { s.cl.triggerUpdateMetadata(false, "opportunistic load during sink backoff") // as good a time as any - tries := int(atomic.AddUint32(&s.consecutiveFailures, 1)) + tries := int(s.consecutiveFailures.Add(1)) after := time.NewTimer(s.cl.cfg.retryBackoff(tries)) defer after.Stop() @@ -258,7 +258,7 @@ func (s *sink) produce(sem <-chan struct{}) bool { // We could have been triggered from a metadata update even though the // user is not producing at all. If we have no buffered records, let's // avoid potentially creating a producer ID. 
- if atomic.LoadInt64(&s.cl.producer.bufferedRecords) == 0 { + if s.cl.producer.bufferedRecords.Load() == 0 { return false } @@ -515,8 +515,8 @@ func (s *sink) issueTxnReq( // https://cwiki.apache.org/confluence/display/KAFKA/An+analysis+of+the+impact+of+max.in.flight.requests.per.connection+and+acks+on+Producer+performance // https://issues.apache.org/jira/browse/KAFKA-5494 func (s *sink) firstRespCheck(idempotent bool, version int16) { - if s.produceVersion < 0 { // this is the only place this can be checked non-atomically - atomic.StoreInt32(&s.produceVersion, int32(version)) + if s.produceVersion.Load() < 0 { + s.produceVersion.Store(int32(version)) if idempotent && version >= 4 { s.inflightSem.Store(make(chan struct{}, 4)) } @@ -582,7 +582,7 @@ func (s *sink) handleReqResp(br *broker, req *produceRequest, resp kmsg.Response return } s.firstRespCheck(req.idempotent(), req.version) - atomic.StoreUint32(&s.consecutiveFailures, 0) + s.consecutiveFailures.Store(0) defer req.metrics.hook(&s.cl.cfg, br) // defer to end so that non-written batches are removed var b *bytes.Buffer @@ -870,7 +870,7 @@ func (cl *Client) finishBatch(batch *recBatch, producerID int64, producerEpoch i // We remove this batch and finish all records appropriately. finished := len(batch.records) recBuf.batch0Seq = incrementSequence(recBuf.batch0Seq, int32(finished)) - atomic.AddInt64(&recBuf.buffered, -int64(finished)) + recBuf.buffered.Add(-int64(finished)) recBuf.batches[0] = nil recBuf.batches = recBuf.batches[1:] recBuf.batchDrainIdx-- @@ -1030,7 +1030,7 @@ type recBuf struct { // For LoadTopicPartitioner partitioning; atomically tracks the number // of records buffered in total on this recBuf. - buffered int64 + buffered atomicI64 mu sync.Mutex // guards r/w access to all fields below @@ -1158,7 +1158,7 @@ func (recBuf *recBuf) bufferRecord(pr promisedRec, abortOnNewBatch bool) bool { var ( newBatch = true onDrainBatch = recBuf.batchDrainIdx == len(recBuf.batches) - produceVersion = atomic.LoadInt32(&recBuf.sink.produceVersion) + produceVersion = recBuf.sink.produceVersion.Load() ) if !onDrainBatch { @@ -1200,7 +1200,7 @@ func (recBuf *recBuf) bufferRecord(pr promisedRec, abortOnNewBatch bool) bool { } } - atomic.AddInt64(&recBuf.buffered, 1) + recBuf.buffered.Add(1) return true } @@ -1221,7 +1221,7 @@ func (recBuf *recBuf) tryStopLingerForDraining() bool { // Begins a linger timer unless the producer is being flushed. func (recBuf *recBuf) lockedMaybeStartLinger() bool { - if atomic.LoadInt32(&recBuf.cl.producer.flushing) == 1 { + if recBuf.cl.producer.flushing.Load() > 0 { return false } recBuf.lingering = time.AfterFunc(recBuf.cl.cfg.linger, recBuf.sink.maybeDrain) @@ -1325,7 +1325,7 @@ func (recBuf *recBuf) failAllRecords(err error) { }) } recBuf.resetBatchDrainIdx() - atomic.StoreInt64(&recBuf.buffered, 0) + recBuf.buffered.Store(0) recBuf.batches = nil } diff --git a/pkg/kgo/source.go b/pkg/kgo/source.go index 87b00773..f961ff5b 100644 --- a/pkg/kgo/source.go +++ b/pkg/kgo/source.go @@ -7,7 +7,6 @@ import ( "hash/crc32" "sort" "sync" - "sync/atomic" "time" "github.com/twmb/franz-go/pkg/kbin" @@ -112,7 +111,7 @@ type cursor struct { // transitioning from used to usable. source *source - // useState is an atomic that has two states: unusable and usable. A + // useState is an atomic that has two states: unusable and usable. A // cursor can be used in a fetch request if it is in the usable state. 
// Once used, the cursor is unusable, and will be set back to usable // one the request lifecycle is complete (a usable fetch response, or @@ -123,7 +122,7 @@ type cursor struct { // // The used state is exclusively updated by either building a fetch // request or when the source is stopped. - useState uint32 + useState atomicBool topicPartitionData // updated in metadata when session is stopped @@ -161,7 +160,7 @@ func (c *cursor) use() *cursorOffsetNext { // A source using a cursor has exclusive access to the use field by // virtue of that source building a request during a live session, // or by virtue of the session being stopped. - c.useState = 0 + c.useState.Store(false) return &cursorOffsetNext{ cursorOffset: c.cursorOffset, from: c, @@ -173,7 +172,7 @@ func (c *cursor) use() *cursorOffsetNext { // to be consumed. This is called exclusively after sources are stopped. // This also unsets the cursor offset, which is assumed to be unused now. func (c *cursor) unset() { - c.useState = 0 + c.useState.Store(false) c.setOffset(cursorOffset{ offset: -1, lastConsumedEpoch: -1, @@ -183,14 +182,14 @@ func (c *cursor) unset() { // usable returns whether a cursor can be used for building a fetch request. func (c *cursor) usable() bool { - return atomic.LoadUint32(&c.useState) == 1 + return c.useState.Load() } // allowUsable allows a cursor to be fetched, and is called either in assigning // offsets, or when a buffered fetch is taken or discarded, or when listing / // epoch loading finishes. func (c *cursor) allowUsable() { - atomic.SwapUint32(&c.useState, 1) + c.useState.Swap(true) c.source.maybeConsume() } @@ -329,9 +328,9 @@ func (s *source) hook(f *Fetch, buffered, polled bool) { } } if buffered { - atomic.AddInt64(&s.cl.consumer.bufferedRecords, int64(nrecs)) + s.cl.consumer.bufferedRecords.Add(int64(nrecs)) } else { - atomic.AddInt64(&s.cl.consumer.bufferedRecords, -int64(nrecs)) + s.cl.consumer.bufferedRecords.Add(-int64(nrecs)) } } diff --git a/pkg/kgo/txn.go b/pkg/kgo/txn.go index b3ae9e03..685a3439 100644 --- a/pkg/kgo/txn.go +++ b/pkg/kgo/txn.go @@ -6,7 +6,6 @@ import ( "fmt" "strings" "sync" - "sync/atomic" "time" "github.com/twmb/franz-go/pkg/kerr" @@ -479,7 +478,7 @@ func (cl *Client) BeginTransaction() error { } cl.producer.inTxn = true - atomic.StoreUint32(&cl.producer.producingTxn, 1) // allow produces for txns now + cl.producer.producingTxn.Store(true) // allow produces for txns now cl.cfg.logger.Log(LogLevelInfo, "beginning transaction", "transactional_id", *cl.cfg.txnID) return nil @@ -754,8 +753,8 @@ func (cl *Client) EndAndBeginTransaction( // are known to not be in flight. This function is safe to call multiple times // concurrently, and safe to call concurrent with Flush. 
func (cl *Client) AbortBufferedRecords(ctx context.Context) error { - atomic.AddInt32(&cl.producer.aborting, 1) - defer atomic.AddInt32(&cl.producer.aborting, -1) + cl.producer.aborting.Add(1) + defer cl.producer.aborting.Add(-1) cl.cfg.logger.Log(LogLevelInfo, "producer state set to aborting; continuing to wait via flushing") defer cl.cfg.logger.Log(LogLevelDebug, "aborted buffered records") @@ -810,7 +809,7 @@ func (cl *Client) EndTransaction(ctx context.Context, commit TransactionEndTry) } cl.producer.inTxn = false - atomic.StoreUint32(&cl.producer.producingTxn, 0) // forbid any new produces while ending txn + cl.producer.producingTxn.Store(false) // forbid any new produces while ending txn // anyAdded tracks if any partitions were added to this txn, because // any partitions written to triggers AddPartitionToTxn, which triggers diff --git a/pkg/kgo/txn_test.go b/pkg/kgo/txn_test.go index 0fcc17bd..d0fc03c5 100644 --- a/pkg/kgo/txn_test.go +++ b/pkg/kgo/txn_test.go @@ -7,7 +7,6 @@ import ( "fmt" "os" "strconv" - "sync/atomic" "testing" "time" ) @@ -172,7 +171,7 @@ func (c *testConsumer) transact(txnsBeforeQuit int) { fetches := txnSess.PollFetches(ctx) cancel() if fetches.Err() == context.DeadlineExceeded || fetches.Err() == ErrClientClosed { - if consumed := int(atomic.LoadUint64(&c.consumed)); consumed == testRecordLimit { + if consumed := int(c.consumed.Load()); consumed == testRecordLimit { return } else if consumed > testRecordLimit { panic("invalid: consumed too much") @@ -254,7 +253,7 @@ func (c *testConsumer) transact(txnsBeforeQuit int) { if !rec.control { c.part2key[part] = append(c.part2key[part], rec.num) - atomic.AddUint64(&c.consumed, 1) + c.consumed.Add(1) } } }
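
Note (not part of the patch): the alignment issue the commit message refers to comes from sync/atomic's documented requirement that 64-bit operands be 64-bit aligned on 32-bit platforms (386, arm, 32-bit MIPS), where the compiler only guarantees 4-byte alignment for a plain int64/uint64 struct field. The Go 1.19 atomic.Int64/Uint64 types are always 8-byte aligned by the compiler, which is what the go119.go wrappers rely on; the go118.go fallback keeps the old function-based atomics and so cannot guarantee alignment. A minimal, illustrative sketch of that difference follows — the struct names are hypothetical and only demonstrate the failure mode, they are not from the patch:

package main

import "sync/atomic"

// unalignedOn32Bit mirrors the failure mode: on GOARCH=arm or 386 the
// int64 field is only 4-byte aligned, so it can sit at an address that
// is not 64-bit aligned, and sync/atomic calls on it may panic with an
// unaligned 64-bit atomic operation error.
type unalignedOn32Bit struct {
	flag bool
	n    int64
}

// alignedEverywhere uses the Go 1.19 type; the compiler guarantees
// 8-byte alignment of atomic.Int64 fields on all platforms.
type alignedEverywhere struct {
	flag bool
	n    atomic.Int64
}

func main() {
	var a unalignedOn32Bit
	atomic.AddInt64(&a.n, 1) // fine on 64-bit, may panic on 32-bit arm

	var b alignedEverywhere
	b.n.Add(1) // safe on every platform, Go 1.19+
}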