Skip to content

Commit b565a23

Browse files
committed
code cleanup
1 parent ddbd3e9 commit b565a23

File tree

3 files changed

+91
-94
lines changed

3 files changed

+91
-94
lines changed

x/mongo/driver/topology/connection.go

Lines changed: 84 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ type connection struct {
7979
driverConnectionID uint64
8080
generation uint64
8181

82-
// awaitingResponse indicates the size of server response that was not completely
82+
// awaitRemainingBytes indicates the size of server response that was not completely
8383
// read before returning the connection to the pool.
84-
awaitingResponse *int32
84+
awaitRemainingBytes *int32
8585

8686
// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
8787
// accessTokens in the OIDC authenticator cache.
@@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
115115
return c
116116
}
117117

118-
// DriverConnectionID returns the driver connection ID.
119-
// TODO(GODRIVER-2824): change return type to int64.
120-
func (c *connection) DriverConnectionID() uint64 {
121-
return c.driverConnectionID
122-
}
123-
124118
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
125119
// configuration.
126120
func (c *connection) setGenerationNumber() {
@@ -142,6 +136,39 @@ func (c *connection) hasGenerationNumber() bool {
142136
return c.desc.LoadBalanced()
143137
}
144138

139+
func configureTLS(ctx context.Context,
140+
tlsConnSource tlsConnectionSource,
141+
nc net.Conn,
142+
addr address.Address,
143+
config *tls.Config,
144+
ocspOpts *ocsp.VerifyOptions,
145+
) (net.Conn, error) {
146+
// Ensure config.ServerName is always set for SNI.
147+
if config.ServerName == "" {
148+
hostname := addr.String()
149+
colonPos := strings.LastIndex(hostname, ":")
150+
if colonPos == -1 {
151+
colonPos = len(hostname)
152+
}
153+
154+
hostname = hostname[:colonPos]
155+
config.ServerName = hostname
156+
}
157+
158+
client := tlsConnSource.Client(nc, config)
159+
if err := clientHandshake(ctx, client); err != nil {
160+
return nil, err
161+
}
162+
163+
// Only do OCSP verification if TLS verification is requested.
164+
if !config.InsecureSkipVerify {
165+
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
166+
return nil, ocspErr
167+
}
168+
}
169+
return client, nil
170+
}
171+
145172
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
146173
// handshakes. All errors returned by connect are considered "before the handshake completes" and
147174
// must be handled by calling the appropriate SDAM handshake error handler.
@@ -317,6 +344,10 @@ func (c *connection) closeConnectContext() {
317344
}
318345
}
319346

347+
func (c *connection) cancellationListenerCallback() {
348+
_ = c.close()
349+
}
350+
320351
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
321352
if originalError == nil {
322353
return nil
@@ -339,10 +370,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
339370
return originalError
340371
}
341372

342-
func (c *connection) cancellationListenerCallback() {
343-
_ = c.close()
344-
}
345-
346373
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
347374
var err error
348375
if atomic.LoadInt64(&c.state) != connConnected {
@@ -423,7 +450,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
423450

424451
dst, errMsg, err := c.read(ctx)
425452
if err != nil {
426-
if c.awaitingResponse == nil {
453+
if c.awaitRemainingBytes == nil {
427454
// If the connection was not marked as awaiting response, use the
428455
// pre-CSOT behavior and close the connection because we don't know
429456
// if there are other bytes left to read.
@@ -443,6 +470,29 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
443470
return dst, nil
444471
}
445472

473+
func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
474+
// read the length as an int32
475+
size := (int32(wmSizeBytes[0])) |
476+
(int32(wmSizeBytes[1]) << 8) |
477+
(int32(wmSizeBytes[2]) << 16) |
478+
(int32(wmSizeBytes[3]) << 24)
479+
480+
if size < 4 {
481+
return 0, fmt.Errorf("malformed message length: %d", size)
482+
}
483+
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
484+
// defaultMaxMessageSize instead.
485+
maxMessageSize := c.desc.MaxMessageSize
486+
if maxMessageSize == 0 {
487+
maxMessageSize = defaultMaxMessageSize
488+
}
489+
if uint32(size) > maxMessageSize {
490+
return 0, errResponseTooLarge
491+
}
492+
493+
return size, nil
494+
}
495+
446496
func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
447497
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
448498
defer func() {
@@ -475,35 +525,23 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
475525
n, err := io.ReadFull(c.nc, sizeBuf[:])
476526
if err != nil {
477527
if l := int32(n); l == 0 && needToWait(err) {
478-
c.awaitingResponse = &l
528+
c.awaitRemainingBytes = &l
479529
}
480530
return nil, "incomplete read of message header", err
481531
}
482-
483-
// read the length as an int32
484-
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
485-
486-
if size < 4 {
487-
err = fmt.Errorf("malformed message length: %d", size)
532+
size, err := c.parseWmSizeBytes(sizeBuf)
533+
if err != nil {
488534
return nil, err.Error(), err
489535
}
490-
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
491-
// defaultMaxMessageSize instead.
492-
maxMessageSize := c.desc.MaxMessageSize
493-
if maxMessageSize == 0 {
494-
maxMessageSize = defaultMaxMessageSize
495-
}
496-
if uint32(size) > maxMessageSize {
497-
return nil, errResponseTooLarge.Error(), errResponseTooLarge
498-
}
499536

500537
dst := make([]byte, size)
501538
copy(dst, sizeBuf[:])
502539

503540
n, err = io.ReadFull(c.nc, dst[4:])
504541
if err != nil {
505-
if l := size - 4 - int32(n); l > 0 && needToWait(err) {
506-
c.awaitingResponse = &l
542+
remainingBytes := size - 4 - int32(n)
543+
if remainingBytes > 0 && needToWait(err) {
544+
c.awaitRemainingBytes = &remainingBytes
507545
}
508546
return dst, "incomplete read of full message", err
509547
}
@@ -551,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) {
551589
c.canStream = canStream
552590
}
553591

554-
func (c initConnection) supportsStreaming() bool {
555-
return c.canStream
556-
}
557-
558592
func (c *connection) setStreaming(streaming bool) {
559593
c.currentlyStreaming = streaming
560594
}
@@ -568,6 +602,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) {
568602
c.writeTimeout = timeout
569603
}
570604

605+
// DriverConnectionID returns the driver connection ID.
606+
// TODO(GODRIVER-2824): change return type to int64.
607+
func (c *connection) DriverConnectionID() uint64 {
608+
return c.driverConnectionID
609+
}
610+
571611
func (c *connection) ID() string {
572612
return c.id
573613
}
@@ -576,6 +616,14 @@ func (c *connection) ServerConnectionID() *int64 {
576616
return c.serverConnectionID
577617
}
578618

619+
func (c *connection) OIDCTokenGenID() uint64 {
620+
return c.oidcTokenGenID
621+
}
622+
623+
func (c *connection) SetOIDCTokenGenID(genID uint64) {
624+
c.oidcTokenGenID = genID
625+
}
626+
579627
// initConnection is an adapter used during connection initialization. It has the minimum
580628
// functionality necessary to implement the driver.Connection interface, which is required to pass a
581629
// *connection to a Handshaker.
@@ -613,7 +661,7 @@ func (c initConnection) CurrentlyStreaming() bool {
613661
return c.getCurrentlyStreaming()
614662
}
615663
func (c initConnection) SupportsStreaming() bool {
616-
return c.supportsStreaming()
664+
return c.canStream
617665
}
618666

619667
// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -847,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 {
847895
return c.connection.DriverConnectionID()
848896
}
849897

850-
func configureTLS(ctx context.Context,
851-
tlsConnSource tlsConnectionSource,
852-
nc net.Conn,
853-
addr address.Address,
854-
config *tls.Config,
855-
ocspOpts *ocsp.VerifyOptions,
856-
) (net.Conn, error) {
857-
// Ensure config.ServerName is always set for SNI.
858-
if config.ServerName == "" {
859-
hostname := addr.String()
860-
colonPos := strings.LastIndex(hostname, ":")
861-
if colonPos == -1 {
862-
colonPos = len(hostname)
863-
}
864-
865-
hostname = hostname[:colonPos]
866-
config.ServerName = hostname
867-
}
868-
869-
client := tlsConnSource.Client(nc, config)
870-
if err := clientHandshake(ctx, client); err != nil {
871-
return nil, err
872-
}
873-
874-
// Only do OCSP verification if TLS verification is requested.
875-
if !config.InsecureSkipVerify {
876-
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
877-
return nil, ocspErr
878-
}
879-
}
880-
return client, nil
881-
}
882-
883898
// OIDCTokenGenID returns the OIDC token generation ID.
884899
func (c *Connection) OIDCTokenGenID() uint64 {
885900
return c.oidcTokenGenID
@@ -933,11 +948,3 @@ func (c *cancellListener) StopListening() bool {
933948
c.done <- struct{}{}
934949
return c.aborted
935950
}
936-
937-
func (c *connection) OIDCTokenGenID() uint64 {
938-
return c.oidcTokenGenID
939-
}
940-
941-
func (c *connection) SetOIDCTokenGenID(genID uint64) {
942-
c.oidcTokenGenID = genID
943-
}

x/mongo/driver/topology/pool.go

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"context"
1111
"fmt"
1212
"io"
13-
"io/ioutil"
1413
"net"
1514
"sync"
1615
"sync/atomic"
@@ -833,22 +832,13 @@ func bgRead(pool *pool, conn *connection, size int32) {
833832
err = fmt.Errorf("error reading the message size: %w", err)
834833
return
835834
}
836-
size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
837-
if size < 4 {
838-
err = fmt.Errorf("malformed message length: %d", size)
839-
return
840-
}
841-
maxMessageSize := conn.desc.MaxMessageSize
842-
if maxMessageSize == 0 {
843-
maxMessageSize = defaultMaxMessageSize
844-
}
845-
if uint32(size) > maxMessageSize {
846-
err = errResponseTooLarge
835+
size, err = conn.parseWmSizeBytes(sizeBuf)
836+
if err != nil {
847837
return
848838
}
849839
size -= 4
850840
}
851-
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size))
841+
_, err = io.CopyN(io.Discard, conn.nc, int64(size))
852842
if err != nil {
853843
err = fmt.Errorf("error reading message of %d: %w", size, err)
854844
}
@@ -901,9 +891,9 @@ func (p *pool) checkInNoEvent(conn *connection) error {
901891
// means that connections in "awaiting response" state are checked in but
902892
// not usable, which is not covered by the current pool events. We may need
903893
// to add pool event information in the future to communicate that.
904-
if conn.awaitingResponse != nil {
905-
size := *conn.awaitingResponse
906-
conn.awaitingResponse = nil
894+
if conn.awaitRemainingBytes != nil {
895+
size := *conn.awaitRemainingBytes
896+
conn.awaitRemainingBytes = nil
907897
go bgRead(p, conn, size)
908898
return nil
909899
}

x/mongo/driver/topology/pool_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1197,7 +1197,7 @@ func TestPool(t *testing.T) {
11971197
`^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`,
11981198
)
11991199
assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err)
1200-
assert.Nil(t, conn.awaitingResponse, "conn.awaitingResponse should be nil")
1200+
assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitingResponse should be nil")
12011201
wg.Wait()
12021202
p.close(context.Background())
12031203
close(errCh)

0 commit comments

Comments
 (0)