diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 3967bf7f82..239074db6d 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -364,7 +364,9 @@ func (t *T) ResetClient(opts *options.ClientOptions) { t.clientOpts = opts } - _ = t.Client.Disconnect(context.Background()) + if t.Client != nil { + _ = t.Client.Disconnect(context.Background()) + } t.createTestClient() t.DB = t.Client.Database(t.dbName) t.Coll = t.DB.Collection(t.collName, t.collOpts) diff --git a/internal/integration/sessions_test.go b/internal/integration/sessions_test.go index 02a345fd38..f0fcd21eb6 100644 --- a/internal/integration/sessions_test.go +++ b/internal/integration/sessions_test.go @@ -13,10 +13,12 @@ import ( "fmt" "reflect" "sync" + "sync/atomic" "testing" "time" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" @@ -508,6 +510,78 @@ func TestSessionsProse(t *testing.T) { assert.True(mt, limitedSessionUse, limitedSessMsg, len(ops)) }) + + mt.ResetClient(options.Client()) + client := mt.Client + heartbeatStarted := make(chan struct{}, 1) + heartbeatSucceeded := make(chan struct{}, 1) + var clusterTimeAdvanced uint32 + serverMonitor := &event.ServerMonitor{ + ServerHeartbeatStarted: func(*event.ServerHeartbeatStartedEvent) { + if atomic.LoadUint32(&clusterTimeAdvanced) == 1 { + select { + case heartbeatStarted <- struct{}{}: + // NOOP + default: + // NOOP + } + } + }, + ServerHeartbeatSucceeded: func(*event.ServerHeartbeatSucceededEvent) { + if atomic.LoadUint32(&clusterTimeAdvanced) == 1 { + select { + case heartbeatSucceeded <- struct{}{}: + // NOOP + default: + // NOOP + } + } + }, + } + pingOpts := mtest.NewOptions(). + CreateCollection(false). + ClientOptions(options.Client(). + SetServerMonitor(serverMonitor). + SetHeartbeatInterval(500 * time.Millisecond). // Minimum interval + SetDirect(true)). + ClientType(mtest.Pinned) + mt.RunOpts("20 Drivers do not gossip $clusterTime on SDAM commands", pingOpts, func(mt *mtest.T) { + wait := func(mt *mtest.T, ch <-chan struct{}, label string) { + mt.Helper() + + select { + case <-ch: + case <-time.After(5 * time.Second): + mt.Fatalf("timed out waiting for %s", label) + } + } + + err := mt.Client.Ping(context.Background(), readpref.Primary()) + assert.NoError(mt, err, "expected no error, got: %v", err) + + _, err = client.Database("test").Collection("test").InsertOne(context.Background(), bson.D{{"advance", "$clusterTime"}}) + require.NoError(mt, err, "expected no error inserting document, got: %v", err) + + atomic.StoreUint32(&clusterTimeAdvanced, 1) + wait(mt, heartbeatStarted, "ServerHeartbeatStartedEvent") + wait(mt, heartbeatSucceeded, "ServerHeartbeatSucceededEvent") + + err = mt.Client.Ping(context.Background(), readpref.Primary()) + require.NoError(mt, err, "expected no error, got: %v", err) + + succeededEvents := mt.GetAllSucceededEvents() + require.Len(mt, succeededEvents, 2, "expected 2 succeeded events, got: %v", len(succeededEvents)) + require.Equal(mt, "ping", succeededEvents[0].CommandName, "expected first command to be ping, got: %v", succeededEvents[0].CommandName) + initialClusterTime, err := succeededEvents[0].Reply.LookupErr("$clusterTime") + require.NoError(mt, err, "$clusterTime not found in response") + + startedEvents := mt.GetAllStartedEvents() + require.Len(mt, startedEvents, 2, "expected 2 started events, got: %v", len(startedEvents)) + require.Equal(mt, "ping", startedEvents[1].CommandName, "expected second command to be ping, got: %v", startedEvents[1].CommandName) + currentClusterTime, err := startedEvents[1].Command.LookupErr("$clusterTime") + require.NoError(mt, err, "$clusterTime not found in command") + assert.Equal(mt, initialClusterTime, currentClusterTime, "expected same cluster time, got %v and %v", initialClusterTime, currentClusterTime) + }) } type sessionFunction struct { diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 50136456e4..b3a36d79a3 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1103,9 +1103,11 @@ func (op Operation) readWireMessage(ctx context.Context, conn *mnet.Connection) // decode res, err := op.decodeResult(opcode, rem) - // Update cluster/operation time and recovery tokens before handling the error to ensure we're properly updating - // everything. - op.updateClusterTimes(res) + // When a cluster clock is given, update cluster/operation time and recovery tokens before handling the error + // to ensure we're properly updating everything. + if op.Clock != nil { + op.updateClusterTimes(res) + } op.updateOperationTime(res) op.Client.UpdateRecoveryToken(bson.Raw(res)) @@ -1699,7 +1701,10 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) if (clock == nil && client == nil) || !sessionsSupported(desc.WireVersion) { return dst } - clusterTime := clock.GetClusterTime() + var clusterTime bson.Raw + if clock != nil { + clusterTime = clock.GetClusterTime() + } if client != nil { clusterTime = session.MaxClusterTime(clusterTime, client.ClusterTime) } @@ -1711,7 +1716,6 @@ func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) return dst } return append(bsoncore.AppendHeader(dst, bsoncore.Type(val.Type), "$clusterTime"), val.Value...) - // return bsoncore.AppendDocumentElement(dst, "$clusterTime", clusterTime) } // calculateMaxTimeMS calculates the value of the 'maxTimeMS' field to potentially append diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 7fab1136f8..2d870809bf 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -842,7 +842,6 @@ func (s *Server) setupHeartbeatConnection(ctx context.Context) error { func (s *Server) createBaseOperation(conn *mnet.Connection) *operation.Hello { return operation. NewHello(). - ClusterClock(s.cfg.clock). Deployment(driver.SingleConnectionDeployment{C: conn}). ServerAPI(s.cfg.serverAPI) }