@@ -9,6 +9,7 @@ package topology
99import (
1010 "context"
1111 "crypto/tls"
12+ "encoding/binary"
1213 "errors"
1314 "fmt"
1415 "io"
@@ -80,9 +81,9 @@ type connection struct {
8081 // accessTokens in the OIDC authenticator cache.
8182 oidcTokenGenID uint64
8283
83- // awaitingResponse indicates that the server response was not completely
84+ // awaitRemainingBytes indicates the size of server response that was not completely
8485 // read before returning the connection to the pool.
85- awaitingResponse bool
86+ awaitRemainingBytes * int32
8687}
8788
8889// newConnection handles the creation of a connection. It does not connect the connection.
@@ -111,11 +112,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
111112 return c
112113}
113114
114- // DriverConnectionID returns the driver connection ID.
115- func (c * connection ) DriverConnectionID () int64 {
116- return c .driverConnectionID
117- }
118-
119115// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
120116// configuration.
121117func (c * connection ) setGenerationNumber () {
@@ -137,6 +133,39 @@ func (c *connection) hasGenerationNumber() bool {
137133 return driverutil .IsServerLoadBalanced (c .desc )
138134}
139135
136+ func configureTLS (ctx context.Context ,
137+ tlsConnSource tlsConnectionSource ,
138+ nc net.Conn ,
139+ addr address.Address ,
140+ config * tls.Config ,
141+ ocspOpts * ocsp.VerifyOptions ,
142+ ) (net.Conn , error ) {
143+ // Ensure config.ServerName is always set for SNI.
144+ if config .ServerName == "" {
145+ hostname := addr .String ()
146+ colonPos := strings .LastIndex (hostname , ":" )
147+ if colonPos == - 1 {
148+ colonPos = len (hostname )
149+ }
150+
151+ hostname = hostname [:colonPos ]
152+ config .ServerName = hostname
153+ }
154+
155+ client := tlsConnSource .Client (nc , config )
156+ if err := clientHandshake (ctx , client ); err != nil {
157+ return nil , err
158+ }
159+
160+ // Only do OCSP verification if TLS verification is requested.
161+ if ! config .InsecureSkipVerify {
162+ if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
163+ return nil , ocspErr
164+ }
165+ }
166+ return client , nil
167+ }
168+
140169// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
141170// handshakes. All errors returned by connect are considered "before the handshake completes" and
142171// must be handled by calling the appropriate SDAM handshake error handler.
@@ -291,6 +320,10 @@ func (c *connection) closeConnectContext() {
291320 }
292321}
293322
323+ func (c * connection ) cancellationListenerCallback () {
324+ _ = c .close ()
325+ }
326+
294327func transformNetworkError (ctx context.Context , originalError error , contextDeadlineUsed bool ) error {
295328 if originalError == nil {
296329 return nil
@@ -313,10 +346,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
313346 return originalError
314347}
315348
316- func (c * connection ) cancellationListenerCallback () {
317- _ = c .close ()
318- }
319-
320349func (c * connection ) writeWireMessage (ctx context.Context , wm []byte ) error {
321350 var err error
322351 if atomic .LoadInt64 (& c .state ) != connConnected {
@@ -377,14 +406,9 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
377406
378407 dst , errMsg , err := c .read (ctx )
379408 if err != nil {
380- if nerr := net .Error (nil ); errors .As (err , & nerr ) && nerr .Timeout () {
381- // If the error was a timeout error, instead of closing the
382- // connection mark it as awaiting response so the pool can read the
383- // response before making it available to other operations.
384- c .awaitingResponse = true
385- } else {
386- // Otherwise, and close the connection because we don't know what
387- // the connection state is.
409+ if c .awaitRemainingBytes == nil {
410+ // If the connection was not marked as awaiting response, close the
411+ // connection because we don't know what the connection state is.
388412 c .close ()
389413 }
390414 message := errMsg
@@ -401,6 +425,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
401425 return dst , nil
402426}
403427
428+ func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
429+ // read the length as an int32
430+ size := int32 (binary .LittleEndian .Uint32 (wmSizeBytes [:]))
431+
432+ if size < 4 {
433+ return 0 , fmt .Errorf ("malformed message length: %d" , size )
434+ }
435+ // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
436+ // defaultMaxMessageSize instead.
437+ maxMessageSize := c .desc .MaxMessageSize
438+ if maxMessageSize == 0 {
439+ maxMessageSize = defaultMaxMessageSize
440+ }
441+ if uint32 (size ) > maxMessageSize {
442+ return 0 , errResponseTooLarge
443+ }
444+
445+ return size , nil
446+ }
447+
404448func (c * connection ) read (ctx context.Context ) (bytesRead []byte , errMsg string , err error ) {
405449 go c .cancellationListener .Listen (ctx , c .cancellationListenerCallback )
406450 defer func () {
@@ -414,36 +458,42 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
414458 }
415459 }()
416460
461+ isCSOTTimeout := func (err error ) bool {
462+ // If the error was a timeout error, instead of closing the
463+ // connection mark it as awaiting response so the pool can read the
464+ // response before making it available to other operations.
465+ nerr := net .Error (nil )
466+ return errors .As (err , & nerr ) && nerr .Timeout ()
467+ }
468+
417469 // We use an array here because it only costs 4 bytes on the stack and means we'll only need to
418470 // reslice dst once instead of twice.
419471 var sizeBuf [4 ]byte
420472
421473 // We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
422474 // because there might be more than one wire message waiting to be read, for example when
423475 // reading messages from an exhaust cursor.
424- _ , err = io .ReadFull (c .nc , sizeBuf [:])
476+ n , err : = io .ReadFull (c .nc , sizeBuf [:])
425477 if err != nil {
478+ if l := int32 (n ); l == 0 && isCSOTTimeout (err ) {
479+ c .awaitRemainingBytes = & l
480+ }
426481 return nil , "incomplete read of message header" , err
427482 }
428-
429- // read the length as an int32
430- size := (int32 (sizeBuf [0 ])) | (int32 (sizeBuf [1 ]) << 8 ) | (int32 (sizeBuf [2 ]) << 16 ) | (int32 (sizeBuf [3 ]) << 24 )
431-
432- // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
433- // defaultMaxMessageSize instead.
434- maxMessageSize := c .desc .MaxMessageSize
435- if maxMessageSize == 0 {
436- maxMessageSize = defaultMaxMessageSize
437- }
438- if uint32 (size ) > maxMessageSize {
439- return nil , errResponseTooLarge .Error (), errResponseTooLarge
483+ size , err := c .parseWmSizeBytes (sizeBuf )
484+ if err != nil {
485+ return nil , err .Error (), err
440486 }
441487
442488 dst := make ([]byte , size )
443489 copy (dst , sizeBuf [:])
444490
445- _ , err = io .ReadFull (c .nc , dst [4 :])
491+ n , err = io .ReadFull (c .nc , dst [4 :])
446492 if err != nil {
493+ remainingBytes := size - 4 - int32 (n )
494+ if remainingBytes > 0 && isCSOTTimeout (err ) {
495+ c .awaitRemainingBytes = & remainingBytes
496+ }
447497 return dst , "incomplete read of full message" , err
448498 }
449499
@@ -496,10 +546,6 @@ func (c *connection) setCanStream(canStream bool) {
496546 c .canStream = canStream
497547}
498548
499- func (c initConnection ) supportsStreaming () bool {
500- return c .canStream
501- }
502-
503549func (c * connection ) setStreaming (streaming bool ) {
504550 c .currentlyStreaming = streaming
505551}
@@ -508,6 +554,14 @@ func (c *connection) getCurrentlyStreaming() bool {
508554 return c .currentlyStreaming
509555}
510556
557+ func (c * connection ) previousCanceled () bool {
558+ if val := c .prevCanceled .Load (); val != nil {
559+ return val .(bool )
560+ }
561+
562+ return false
563+ }
564+
511565func (c * connection ) ID () string {
512566 return c .id
513567}
@@ -516,12 +570,17 @@ func (c *connection) ServerConnectionID() *int64 {
516570 return c .serverConnectionID
517571}
518572
519- func ( c * connection ) previousCanceled () bool {
520- if val := c . prevCanceled . Load (); val != nil {
521- return val .( bool )
522- }
573+ // DriverConnectionID returns the driver connection ID.
574+ func ( c * connection ) DriverConnectionID () int64 {
575+ return c . driverConnectionID
576+ }
523577
524- return false
578+ func (c * connection ) OIDCTokenGenID () uint64 {
579+ return c .oidcTokenGenID
580+ }
581+
582+ func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
583+ c .oidcTokenGenID = genID
525584}
526585
527586// initConnection is an adapter used during connection initialization. It has the minimum
@@ -562,7 +621,7 @@ func (c initConnection) CurrentlyStreaming() bool {
562621 return c .getCurrentlyStreaming ()
563622}
564623func (c initConnection ) SupportsStreaming () bool {
565- return c .supportsStreaming ()
624+ return c .canStream
566625}
567626
568627// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -797,39 +856,6 @@ func (c *Connection) DriverConnectionID() int64 {
797856 return c .connection .DriverConnectionID ()
798857}
799858
800- func configureTLS (ctx context.Context ,
801- tlsConnSource tlsConnectionSource ,
802- nc net.Conn ,
803- addr address.Address ,
804- config * tls.Config ,
805- ocspOpts * ocsp.VerifyOptions ,
806- ) (net.Conn , error ) {
807- // Ensure config.ServerName is always set for SNI.
808- if config .ServerName == "" {
809- hostname := addr .String ()
810- colonPos := strings .LastIndex (hostname , ":" )
811- if colonPos == - 1 {
812- colonPos = len (hostname )
813- }
814-
815- hostname = hostname [:colonPos ]
816- config .ServerName = hostname
817- }
818-
819- client := tlsConnSource .Client (nc , config )
820- if err := clientHandshake (ctx , client ); err != nil {
821- return nil , err
822- }
823-
824- // Only do OCSP verification if TLS verification is requested.
825- if ! config .InsecureSkipVerify {
826- if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
827- return nil , ocspErr
828- }
829- }
830- return client , nil
831- }
832-
833859// OIDCTokenGenID returns the OIDC token generation ID.
834860func (c * Connection ) OIDCTokenGenID () uint64 {
835861 return c .oidcTokenGenID
@@ -839,11 +865,3 @@ func (c *Connection) OIDCTokenGenID() uint64 {
839865func (c * Connection ) SetOIDCTokenGenID (genID uint64 ) {
840866 c .oidcTokenGenID = genID
841867}
842-
843- func (c * connection ) OIDCTokenGenID () uint64 {
844- return c .oidcTokenGenID
845- }
846-
847- func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
848- c .oidcTokenGenID = genID
849- }
0 commit comments