Skip to content

Commit 98bda43

Browse files
kumarlokeshprestonvasquez
authored andcommitted
GODRIVER-2117 - Check clientSession is not nil inside executeTestRunnerOperation (mongodb#1457)
1 parent 730e825 commit 98bda43

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

mongo/integration/unified_spec_test.go

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -463,46 +463,64 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation,
463463

464464
var fp mtest.FailPoint
465465
if err := bson.Unmarshal(fpDoc.Document(), &fp); err != nil {
466-
return fmt.Errorf("Unmarshal error: %v", err)
466+
return fmt.Errorf("Unmarshal error: %w", err)
467467
}
468468

469+
if clientSession == nil {
470+
return errors.New("expected valid session, got nil")
471+
}
469472
targetHost := clientSession.PinnedServer.Addr.String()
470473
opts := options.Client().ApplyURI(mtest.ClusterURI()).SetHosts([]string{targetHost})
471474
testutil.AddTestServerAPIVersion(opts)
472475
client, err := mongo.Connect(context.Background(), opts)
473476
if err != nil {
474-
return fmt.Errorf("Connect error for targeted client: %v", err)
477+
return fmt.Errorf("Connect error for targeted client: %w", err)
475478
}
476479
defer func() { _ = client.Disconnect(context.Background()) }()
477480

478481
if err = client.Database("admin").RunCommand(context.Background(), fp).Err(); err != nil {
479-
return fmt.Errorf("error setting targeted fail point: %v", err)
482+
return fmt.Errorf("error setting targeted fail point: %w", err)
480483
}
481484
mt.TrackFailPoint(fp.ConfigureFailPoint)
482485
case "configureFailPoint":
483486
fp, err := op.Arguments.LookupErr("failPoint")
484-
assert.Nil(mt, err, "failPoint not found in arguments")
487+
if err != nil {
488+
return fmt.Errorf("unable to find 'failPoint' in arguments: %w", err)
489+
}
485490
mt.SetFailPointFromDocument(fp.Document())
486491
case "assertSessionTransactionState":
487492
stateVal, err := op.Arguments.LookupErr("state")
488-
assert.Nil(mt, err, "state not found in arguments")
493+
if err != nil {
494+
return fmt.Errorf("unable to find 'state' in arguments: %w", err)
495+
}
489496
expectedState, ok := stateVal.StringValueOK()
490-
assert.True(mt, ok, "state argument is not a string")
497+
if !ok {
498+
return errors.New("expected 'state' argument to be string")
499+
}
491500

492-
assert.NotNil(mt, clientSession, "expected valid session, got nil")
501+
if clientSession == nil {
502+
return errors.New("expected valid session, got nil")
503+
}
493504
actualState := clientSession.TransactionState.String()
494505

495506
// actualState should match expectedState, but "in progress" is the same as
496507
// "in_progress".
497508
stateMatch := actualState == expectedState ||
498509
actualState == "in progress" && expectedState == "in_progress"
499-
assert.True(mt, stateMatch, "expected transaction state %v, got %v",
500-
expectedState, actualState)
510+
if !stateMatch {
511+
return fmt.Errorf("expected transaction state %v, got %v", expectedState, actualState)
512+
}
501513
case "assertSessionPinned":
514+
if clientSession == nil {
515+
return errors.New("expected valid session, got nil")
516+
}
502517
if clientSession.PinnedServer == nil {
503518
return errors.New("expected pinned server, got nil")
504519
}
505520
case "assertSessionUnpinned":
521+
if clientSession == nil {
522+
return errors.New("expected valid session, got nil")
523+
}
506524
// We don't use a combined helper for assertSessionPinned and assertSessionUnpinned because the unpinned
507525
// case provides the pinned server address in the error msg for debugging.
508526
if clientSession.PinnedServer != nil {
@@ -545,7 +563,7 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation,
545563
case "waitForThread":
546564
waitForThread(mt, testCase, op)
547565
default:
548-
mt.Fatalf("unrecognized testRunner operation %v", op.Name)
566+
return fmt.Errorf("unrecognized testRunner operation %v", op.Name)
549567
}
550568

551569
return nil
@@ -572,7 +590,7 @@ func indexExists(dbName, collName, indexName string) (bool, error) {
572590
iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes()
573591
cursor, err := iv.List(context.Background())
574592
if err != nil {
575-
return false, fmt.Errorf("IndexView.List error: %v", err)
593+
return false, fmt.Errorf("IndexView.List error: %w", err)
576594
}
577595
defer cursor.Close(context.Background())
578596

@@ -607,7 +625,7 @@ func collectionExists(dbName, collName string) (bool, error) {
607625
// Use global client because listCollections cannot be executed inside a transaction.
608626
collections, err := mtest.GlobalClient().Database(dbName).ListCollectionNames(context.Background(), filter)
609627
if err != nil {
610-
return false, fmt.Errorf("ListCollectionNames error: %v", err)
628+
return false, fmt.Errorf("ListCollectionNames error: %w", err)
611629
}
612630

613631
return len(collections) > 0, nil
@@ -637,9 +655,8 @@ func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) err
637655
case "withTransaction":
638656
return executeWithTransaction(mt, sess, op.Arguments)
639657
default:
640-
mt.Fatalf("unrecognized session operation: %v", op.Name)
658+
return fmt.Errorf("unrecognized session operation: %v", op.Name)
641659
}
642-
return nil
643660
}
644661

645662
func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) error {

0 commit comments

Comments
 (0)