@@ -79,9 +79,9 @@ type connection struct {
7979 driverConnectionID uint64
8080 generation uint64
8181
82- // awaitingResponse indicates the size of server response that was not completely
82+ // awaitRemainingBytes indicates the size of server response that was not completely
8383 // read before returning the connection to the pool.
84- awaitingResponse * int32
84+ awaitRemainingBytes * int32
8585
8686 // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
8787 // accessTokens in the OIDC authenticator cache.
@@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
115115 return c
116116}
117117
118- // DriverConnectionID returns the driver connection ID.
119- // TODO(GODRIVER-2824): change return type to int64.
120- func (c * connection ) DriverConnectionID () uint64 {
121- return c .driverConnectionID
122- }
123-
124118// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
125119// configuration.
126120func (c * connection ) setGenerationNumber () {
@@ -142,6 +136,39 @@ func (c *connection) hasGenerationNumber() bool {
142136 return c .desc .LoadBalanced ()
143137}
144138
139+ func configureTLS (ctx context.Context ,
140+ tlsConnSource tlsConnectionSource ,
141+ nc net.Conn ,
142+ addr address.Address ,
143+ config * tls.Config ,
144+ ocspOpts * ocsp.VerifyOptions ,
145+ ) (net.Conn , error ) {
146+ // Ensure config.ServerName is always set for SNI.
147+ if config .ServerName == "" {
148+ hostname := addr .String ()
149+ colonPos := strings .LastIndex (hostname , ":" )
150+ if colonPos == - 1 {
151+ colonPos = len (hostname )
152+ }
153+
154+ hostname = hostname [:colonPos ]
155+ config .ServerName = hostname
156+ }
157+
158+ client := tlsConnSource .Client (nc , config )
159+ if err := clientHandshake (ctx , client ); err != nil {
160+ return nil , err
161+ }
162+
163+ // Only do OCSP verification if TLS verification is requested.
164+ if ! config .InsecureSkipVerify {
165+ if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
166+ return nil , ocspErr
167+ }
168+ }
169+ return client , nil
170+ }
171+
145172// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
146173// handshakes. All errors returned by connect are considered "before the handshake completes" and
147174// must be handled by calling the appropriate SDAM handshake error handler.
@@ -317,6 +344,10 @@ func (c *connection) closeConnectContext() {
317344 }
318345}
319346
347+ func (c * connection ) cancellationListenerCallback () {
348+ _ = c .close ()
349+ }
350+
320351func transformNetworkError (ctx context.Context , originalError error , contextDeadlineUsed bool ) error {
321352 if originalError == nil {
322353 return nil
@@ -339,10 +370,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
339370 return originalError
340371}
341372
342- func (c * connection ) cancellationListenerCallback () {
343- _ = c .close ()
344- }
345-
346373func (c * connection ) writeWireMessage (ctx context.Context , wm []byte ) error {
347374 var err error
348375 if atomic .LoadInt64 (& c .state ) != connConnected {
@@ -423,7 +450,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
423450
424451 dst , errMsg , err := c .read (ctx )
425452 if err != nil {
426- if c .awaitingResponse == nil {
453+ if c .awaitRemainingBytes == nil {
427454 // If the connection was not marked as awaiting response, use the
428455 // pre-CSOT behavior and close the connection because we don't know
429456 // if there are other bytes left to read.
@@ -443,6 +470,29 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
443470 return dst , nil
444471}
445472
473+ func (c * connection ) parseWmSizeBytes (wmSizeBytes [4 ]byte ) (int32 , error ) {
474+ // read the length as an int32
475+ size := (int32 (wmSizeBytes [0 ])) |
476+ (int32 (wmSizeBytes [1 ]) << 8 ) |
477+ (int32 (wmSizeBytes [2 ]) << 16 ) |
478+ (int32 (wmSizeBytes [3 ]) << 24 )
479+
480+ if size < 4 {
481+ return 0 , fmt .Errorf ("malformed message length: %d" , size )
482+ }
483+ // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
484+ // defaultMaxMessageSize instead.
485+ maxMessageSize := c .desc .MaxMessageSize
486+ if maxMessageSize == 0 {
487+ maxMessageSize = defaultMaxMessageSize
488+ }
489+ if uint32 (size ) > maxMessageSize {
490+ return 0 , errResponseTooLarge
491+ }
492+
493+ return size , nil
494+ }
495+
446496func (c * connection ) read (ctx context.Context ) (bytesRead []byte , errMsg string , err error ) {
447497 go c .cancellationListener .Listen (ctx , c .cancellationListenerCallback )
448498 defer func () {
@@ -475,35 +525,23 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
475525 n , err := io .ReadFull (c .nc , sizeBuf [:])
476526 if err != nil {
477527 if l := int32 (n ); l == 0 && needToWait (err ) {
478- c .awaitingResponse = & l
528+ c .awaitRemainingBytes = & l
479529 }
480530 return nil , "incomplete read of message header" , err
481531 }
482-
483- // read the length as an int32
484- size := (int32 (sizeBuf [0 ])) | (int32 (sizeBuf [1 ]) << 8 ) | (int32 (sizeBuf [2 ]) << 16 ) | (int32 (sizeBuf [3 ]) << 24 )
485-
486- if size < 4 {
487- err = fmt .Errorf ("malformed message length: %d" , size )
532+ size , err := c .parseWmSizeBytes (sizeBuf )
533+ if err != nil {
488534 return nil , err .Error (), err
489535 }
490- // In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
491- // defaultMaxMessageSize instead.
492- maxMessageSize := c .desc .MaxMessageSize
493- if maxMessageSize == 0 {
494- maxMessageSize = defaultMaxMessageSize
495- }
496- if uint32 (size ) > maxMessageSize {
497- return nil , errResponseTooLarge .Error (), errResponseTooLarge
498- }
499536
500537 dst := make ([]byte , size )
501538 copy (dst , sizeBuf [:])
502539
503540 n , err = io .ReadFull (c .nc , dst [4 :])
504541 if err != nil {
505- if l := size - 4 - int32 (n ); l > 0 && needToWait (err ) {
506- c .awaitingResponse = & l
542+ remainingBytes := size - 4 - int32 (n )
543+ if remainingBytes > 0 && needToWait (err ) {
544+ c .awaitRemainingBytes = & remainingBytes
507545 }
508546 return dst , "incomplete read of full message" , err
509547 }
@@ -551,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) {
551589 c .canStream = canStream
552590}
553591
554- func (c initConnection ) supportsStreaming () bool {
555- return c .canStream
556- }
557-
558592func (c * connection ) setStreaming (streaming bool ) {
559593 c .currentlyStreaming = streaming
560594}
@@ -568,6 +602,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) {
568602 c .writeTimeout = timeout
569603}
570604
605+ // DriverConnectionID returns the driver connection ID.
606+ // TODO(GODRIVER-2824): change return type to int64.
607+ func (c * connection ) DriverConnectionID () uint64 {
608+ return c .driverConnectionID
609+ }
610+
571611func (c * connection ) ID () string {
572612 return c .id
573613}
@@ -576,6 +616,14 @@ func (c *connection) ServerConnectionID() *int64 {
576616 return c .serverConnectionID
577617}
578618
619+ func (c * connection ) OIDCTokenGenID () uint64 {
620+ return c .oidcTokenGenID
621+ }
622+
623+ func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
624+ c .oidcTokenGenID = genID
625+ }
626+
579627// initConnection is an adapter used during connection initialization. It has the minimum
580628// functionality necessary to implement the driver.Connection interface, which is required to pass a
581629// *connection to a Handshaker.
@@ -613,7 +661,7 @@ func (c initConnection) CurrentlyStreaming() bool {
613661 return c .getCurrentlyStreaming ()
614662}
615663func (c initConnection ) SupportsStreaming () bool {
616- return c .supportsStreaming ()
664+ return c .canStream
617665}
618666
619667// Connection implements the driver.Connection interface to allow reading and writing wire
@@ -847,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 {
847895 return c .connection .DriverConnectionID ()
848896}
849897
850- func configureTLS (ctx context.Context ,
851- tlsConnSource tlsConnectionSource ,
852- nc net.Conn ,
853- addr address.Address ,
854- config * tls.Config ,
855- ocspOpts * ocsp.VerifyOptions ,
856- ) (net.Conn , error ) {
857- // Ensure config.ServerName is always set for SNI.
858- if config .ServerName == "" {
859- hostname := addr .String ()
860- colonPos := strings .LastIndex (hostname , ":" )
861- if colonPos == - 1 {
862- colonPos = len (hostname )
863- }
864-
865- hostname = hostname [:colonPos ]
866- config .ServerName = hostname
867- }
868-
869- client := tlsConnSource .Client (nc , config )
870- if err := clientHandshake (ctx , client ); err != nil {
871- return nil , err
872- }
873-
874- // Only do OCSP verification if TLS verification is requested.
875- if ! config .InsecureSkipVerify {
876- if ocspErr := ocsp .Verify (ctx , client .ConnectionState (), ocspOpts ); ocspErr != nil {
877- return nil , ocspErr
878- }
879- }
880- return client , nil
881- }
882-
883898// OIDCTokenGenID returns the OIDC token generation ID.
884899func (c * Connection ) OIDCTokenGenID () uint64 {
885900 return c .oidcTokenGenID
@@ -933,11 +948,3 @@ func (c *cancellListener) StopListening() bool {
933948 c .done <- struct {}{}
934949 return c .aborted
935950}
936-
937- func (c * connection ) OIDCTokenGenID () uint64 {
938- return c .oidcTokenGenID
939- }
940-
941- func (c * connection ) SetOIDCTokenGenID (genID uint64 ) {
942- c .oidcTokenGenID = genID
943- }
0 commit comments