Skip to content

Commit 8a629fb

Browse files
committed
fix race in tests
1 parent d39da69 commit 8a629fb

File tree

4 files changed

+32
-21
lines changed

4 files changed

+32
-21
lines changed

auth/conn_reauth_credentials_listener.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,7 @@ func (c *ConnReAuthCredentialsListener) OnNext(credentials Credentials) {
3737
// this is important because the connection pool may be in the process of reconnecting the connection
3838
// and we don't want to interfere with that process
3939
// but we also don't want to block for too long, so incorporate a timeout
40-
for {
41-
// we were able to mark the connection as unusable
42-
if c.conn.Usable.CompareAndSwap(true, false) {
43-
break
44-
}
45-
40+
for !c.conn.Usable.CompareAndSwap(true, false) {
4641
select {
4742
case <-timeout:
4843
err = pool.ErrConnUnusableTimeout

internal/pool/buffer_size_test.go

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package pool_test
33
import (
44
"bufio"
55
"context"
6-
"net"
76
"unsafe"
87

98
. "github.com/bsm/ginkgo/v2"
@@ -124,20 +123,27 @@ var _ = Describe("Buffer Size Configuration", func() {
124123
})
125124

126125
// Helper functions to extract buffer sizes using unsafe pointers
126+
// The struct layout must match pool.Conn exactly to avoid checkptr violations
127127
func getWriterBufSizeUnsafe(cn *pool.Conn) int {
128+
// Import required for atomic types
129+
type atomicBool struct{ _ uint32 }
130+
type atomicInt64 struct{ _ int64 }
131+
128132
cnPtr := (*struct {
129-
usedAt int64
130-
netConn net.Conn
131-
rd *proto.Reader
132-
bw *bufio.Writer
133-
wr *proto.Writer
134-
// ... other fields
133+
id uint64 // First field in pool.Conn
134+
usedAt int64 // Second field (atomic)
135+
netConnAtomic interface{} // atomic.Value (interface{} has same size)
136+
rd *proto.Reader
137+
bw *bufio.Writer
138+
wr *proto.Writer
139+
// We only need fields up to bw, so we can stop here
135140
})(unsafe.Pointer(cn))
136141

137142
if cnPtr.bw == nil {
138143
return -1
139144
}
140145

146+
// bufio.Writer internal structure
141147
bwPtr := (*struct {
142148
err error
143149
buf []byte
@@ -150,18 +156,20 @@ func getWriterBufSizeUnsafe(cn *pool.Conn) int {
150156

151157
func getReaderBufSizeUnsafe(cn *pool.Conn) int {
152158
cnPtr := (*struct {
153-
usedAt int64
154-
netConn net.Conn
155-
rd *proto.Reader
156-
bw *bufio.Writer
157-
wr *proto.Writer
158-
// ... other fields
159+
id uint64 // First field in pool.Conn
160+
usedAt int64 // Second field (atomic)
161+
netConnAtomic interface{} // atomic.Value (interface{} has same size)
162+
rd *proto.Reader
163+
bw *bufio.Writer
164+
wr *proto.Writer
165+
// We only need fields up to rd, so we can stop here
159166
})(unsafe.Pointer(cn))
160167

161168
if cnPtr.rd == nil {
162169
return -1
163170
}
164171

172+
// proto.Reader internal structure
165173
rdPtr := (*struct {
166174
rd *bufio.Reader
167175
})(unsafe.Pointer(cnPtr.rd))
@@ -170,6 +178,7 @@ func getReaderBufSizeUnsafe(cn *pool.Conn) int {
170178
return -1
171179
}
172180

181+
// bufio.Reader internal structure
173182
bufReaderPtr := (*struct {
174183
buf []byte
175184
rd interface{}

redis.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func (c *baseClient) connReAuthCredentialsListener(poolCn *pool.Conn) (auth.Cred
308308
credListener, ok := c.credListeners[poolCn]
309309
c.credListenersLock.RUnlock()
310310
if ok {
311-
return credListener.(auth.CredentialsListener), func() {
311+
return credListener, func() {
312312
c.removeCredListener(poolCn)
313313
}
314314
}

redis_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,7 @@ var _ = Describe("Credentials Provider Priority", func() {
872872
})
873873

874874
type mockStreamingProvider struct {
875+
mu sync.RWMutex
875876
credentials auth.Credentials
876877
err error
877878
updates chan auth.Credentials
@@ -885,12 +886,18 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au
885886
// Start goroutine to handle updates
886887
go func() {
887888
for creds := range m.updates {
889+
m.mu.Lock()
888890
m.credentials = creds
891+
m.mu.Unlock()
889892
listener.OnNext(creds)
890893
}
891894
}()
892895

893-
return m.credentials, func() (err error) {
896+
m.mu.RLock()
897+
currentCreds := m.credentials
898+
m.mu.RUnlock()
899+
900+
return currentCreds, func() (err error) {
894901
defer func() {
895902
if r := recover(); r != nil {
896903
// this is just a mock:

0 commit comments

Comments
 (0)