@@ -9,6 +9,7 @@ package topology
99import (
1010 "context"
1111 "crypto/tls"
12+ "encoding/binary"
1213 "errors"
1314 "fmt"
1415 "io"
@@ -18,6 +19,7 @@ import (
1819 "sync/atomic"
1920 "time"
2021
22+ "go.mongodb.org/mongo-driver/v2/internal/csot"
2123 "go.mongodb.org/mongo-driver/v2/internal/driverutil"
2224 "go.mongodb.org/mongo-driver/v2/mongo/address"
2325 "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
@@ -80,9 +82,9 @@ type connection struct {
8082 // accessTokens in the OIDC authenticator cache.
8183 oidcTokenGenID uint64
8284
83- // awaitingResponse indicates that the server response was not completely
85+ // awaitRemainingBytes indicates the size of server response that was not completely
8486 // read before returning the connection to the pool.
85- awaitingResponse bool
87+ awaitRemainingBytes * int32
8688}
8789
8890// newConnection handles the creation of a connection. It does not connect the connection.
@@ -111,11 +113,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
111113 return c
112114}
113115
114- // DriverConnectionID returns the driver connection ID.
115- func (c * connection ) DriverConnectionID () int64 {
116- return c .driverConnectionID
117- }
118-
119116// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
120117// configuration.
121118func (c * connection ) setGenerationNumber () {
@@ -137,6 +134,39 @@ func (c *connection) hasGenerationNumber() bool {
137134 return driverutil .IsServerLoadBalanced (c .desc )
138135}
139136
137+ func configureTLS (ctx context.Context ,
138+ tlsConnSource tlsConnectionSource ,
139+ nc net.Conn ,
140+ addr address.Address ,
141+ config * tls.Config ,
142+ ocspOpts * ocsp.VerifyOptions ,
143+ ) (net.Conn , error ) {
144+ // Ensure config.ServerName is always set for SNI.
145+ if config .ServerName == "" {
146+ hostname := addr .String ()
147+ colonPos := strings .LastIndex (hostname , ":" )
148+ if colonPos == - 1 {
149+ colonPos = len (hostname )
150+ }
151+
152+ hostname = hostname [:colonPos ]
153+ config .ServerName = hostname
154+ }
155+
156+ client := tlsConnSource .Client (nc , config )
157+ if err := clientHandshake (ctx , client ); err != nil {
158+ return nil , err
159+ }
160+
161+ // Only do OCSP verification if TLS verification is requested.
162+ if ! config .InsecureSkipVerify {
163+ if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
164+ return nil , ocspErr
165+ }
166+ }
167+ return client , nil
168+ }
169+
140170// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
141171// handshakes. All errors returned by connect are considered "before the handshake completes" and
142172// must be handled by calling the appropriate SDAM handshake error handler.
@@ -291,6 +321,10 @@ func (c *connection) closeConnectContext() {
291321 }
292322}
293323
324+ func (c * connection ) cancellationListenerCallback () {
325+ _ = c .close ()
326+ }
327+
294328func transformNetworkError (ctx context.Context , originalError error , contextDeadlineUsed bool ) error {
295329 if originalError == nil {
296330 return nil
@@ -313,10 +347,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
313347 return originalError
314348}
315349
316- func (c * connection ) cancellationListenerCallback () {
317- _ = c .close ()
318- }
319-
320350func (c * connection ) writeWireMessage (ctx context.Context , wm []byte ) error {
321351 var err error
322352 if atomic .LoadInt64 (& c .state ) != connConnected {
@@ -377,14 +407,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
377407
378408 dst , errMsg , err := c .read (ctx )
379409 if err != nil {
380- if nerr := net .Error (nil ); errors .As (err , & nerr ) && nerr .Timeout () {
381- // If the error was a timeout error, instead of closing the
382- // connection mark it as awaiting response so the pool can read the
383- // response before making it available to other operations.
384- c .awaitingResponse = true
385- } else {
386- // Otherwise, and close the connection because we don't know what
387- // the connection state is.
410+ if c .awaitRemainingBytes == nil {
411+ // If the connection was not marked as awaiting response, use the
412+ // pre-CSOT behavior and close the connection because we don't know
413+ // if there are other bytes left to read.
388414 c .close ()
389415 }
390416 message := errMsg
@@ -401,6 +427,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
401427 return dst , nil
402428}
403429
430+ func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
431+ // read the length as an int32
432+ size := int32 (binary .LittleEndian .Uint32 (wmSizeBytes [:]))
433+
434+ if size < 4 {
435+ return 0 , fmt .Errorf ("malformed message length: %d" , size )
436+ }
437+ // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
438+ // defaultMaxMessageSize instead.
439+ maxMessageSize := c .desc .MaxMessageSize
440+ if maxMessageSize == 0 {
441+ maxMessageSize = defaultMaxMessageSize
442+ }
443+ if uint32 (size ) > maxMessageSize {
444+ return 0 , errResponseTooLarge
445+ }
446+
447+ return size , nil
448+ }
449+
404450func (c * connection ) read (ctx context.Context ) (bytesRead []byte , errMsg string , err error ) {
405451 go c .cancellationListener .Listen (ctx , c .cancellationListenerCallback )
406452 defer func () {
@@ -414,36 +460,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
414460 }
415461 }()
416462
463+ isCSOTTimeout := func (err error ) bool {
464+ // If the error was a timeout error and CSOT is enabled, instead of
465+ // closing the connection mark it as awaiting response so the pool
466+ // can read the response before making it available to other
467+ // operations.
468+ nerr := net .Error (nil )
469+ return errors .As (err , & nerr ) && nerr .Timeout () && csot .IsTimeoutContext (ctx )
470+ }
471+
417472 // We use an array here because it only costs 4 bytes on the stack and means we'll only need to
418473 // reslice dst once instead of twice.
419474 var sizeBuf [4 ]byte
420475
421476 // We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
422477 // because there might be more than one wire message waiting to be read, for example when
423478 // reading messages from an exhaust cursor.
424- _ , err = io .ReadFull (c .nc , sizeBuf [:])
479+ n , err : = io .ReadFull (c .nc , sizeBuf [:])
425480 if err != nil {
481+ if l := int32 (n ); l == 0 && isCSOTTimeout (err ) {
482+ c .awaitRemainingBytes = & l
483+ }
426484 return nil , "incomplete read of message header" , err
427485 }
428-
429- // read the length as an int32
430- size := (int32 (sizeBuf [0 ])) | (int32 (sizeBuf [1 ]) << 8 ) | (int32 (sizeBuf [2 ]) << 16 ) | (int32 (sizeBuf [3 ]) << 24 )
431-
432- // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
433- // defaultMaxMessageSize instead.
434- maxMessageSize := c .desc .MaxMessageSize
435- if maxMessageSize == 0 {
436- maxMessageSize = defaultMaxMessageSize
437- }
438- if uint32 (size ) > maxMessageSize {
439- return nil , errResponseTooLarge .Error (), errResponseTooLarge
486+ size , err := c .parseWmSizeBytes (sizeBuf )
487+ if err != nil {
488+ return nil , err .Error (), err
440489 }
441490
442491 dst := make ([]byte , size )
443492 copy (dst , sizeBuf [:])
444493
445- _ , err = io .ReadFull (c .nc , dst [4 :])
494+ n , err = io .ReadFull (c .nc , dst [4 :])
446495 if err != nil {
496+ remainingBytes := size - 4 - int32 (n )
497+ if remainingBytes > 0 && isCSOTTimeout (err ) {
498+ c .awaitRemainingBytes = & remainingBytes
499+ }
447500 return dst , "incomplete read of full message" , err
448501 }
449502
@@ -496,10 +549,6 @@ func (c *connection) setCanStream(canStream bool) {
496549 c .canStream = canStream
497550}
498551
499- func (c initConnection ) supportsStreaming () bool {
500- return c .canStream
501- }
502-
503552func (c * connection ) setStreaming (streaming bool ) {
504553 c .currentlyStreaming = streaming
505554}
@@ -508,6 +557,14 @@ func (c *connection) getCurrentlyStreaming() bool {
508557 return c .currentlyStreaming
509558}
510559
560+ func (c * connection ) previousCanceled () bool {
561+ if val := c .prevCanceled .Load (); val != nil {
562+ return val .(bool )
563+ }
564+
565+ return false
566+ }
567+
511568func (c * connection ) ID () string {
512569 return c .id
513570}
@@ -516,12 +573,17 @@ func (c *connection) ServerConnectionID() *int64 {
516573 return c .serverConnectionID
517574}
518575
519- func ( c * connection ) previousCanceled () bool {
520- if val := c . prevCanceled . Load (); val != nil {
521- return val .( bool )
522- }
576+ // DriverConnectionID returns the driver connection ID.
577+ func ( c * connection ) DriverConnectionID () int64 {
578+ return c . driverConnectionID
579+ }
523580
524- return false
581+ func (c * connection ) OIDCTokenGenID () uint64 {
582+ return c .oidcTokenGenID
583+ }
584+
585+ func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
586+ c .oidcTokenGenID = genID
525587}
526588
527589// initConnection is an adapter used during connection initialization. It has the minimum
@@ -562,7 +624,7 @@ func (c initConnection) CurrentlyStreaming() bool {
562624 return c .getCurrentlyStreaming ()
563625}
564626func (c initConnection ) SupportsStreaming () bool {
565- return c .supportsStreaming ()
627+ return c .canStream
566628}
567629
568630// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -797,39 +859,6 @@ func (c *Connection) DriverConnectionID() int64 {
797859 return c .connection .DriverConnectionID ()
798860}
799861
800- func configureTLS (ctx context.Context ,
801- tlsConnSource tlsConnectionSource ,
802- nc net.Conn ,
803- addr address.Address ,
804- config * tls.Config ,
805- ocspOpts * ocsp.VerifyOptions ,
806- ) (net.Conn , error ) {
807- // Ensure config.ServerName is always set for SNI.
808- if config .ServerName == "" {
809- hostname := addr .String ()
810- colonPos := strings .LastIndex (hostname , ":" )
811- if colonPos == - 1 {
812- colonPos = len (hostname )
813- }
814-
815- hostname = hostname [:colonPos ]
816- config .ServerName = hostname
817- }
818-
819- client := tlsConnSource .Client (nc , config )
820- if err := clientHandshake (ctx , client ); err != nil {
821- return nil , err
822- }
823-
824- // Only do OCSP verification if TLS verification is requested.
825- if ! config .InsecureSkipVerify {
826- if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
827- return nil , ocspErr
828- }
829- }
830- return client , nil
831- }
832-
833862// OIDCTokenGenID returns the OIDC token generation ID.
834863func (c * Connection ) OIDCTokenGenID () uint64 {
835864 return c .oidcTokenGenID
@@ -839,11 +868,3 @@ func (c *Connection) OIDCTokenGenID() uint64 {
839868func (c * Connection ) SetOIDCTokenGenID (genID uint64 ) {
840869 c .oidcTokenGenID = genID
841870}
842-
843- func (c * connection ) OIDCTokenGenID () uint64 {
844- return c .oidcTokenGenID
845- }
846-
847- func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
848- c .oidcTokenGenID = genID
849- }
0 commit comments