Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport of Fix tests - Update MongoDB driver into release/1.11.x #17715

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 8 additions & 87 deletions helper/testhelpers/mongodb/mongodbhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@ package mongodb

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"os"
"strconv"
"strings"
"testing"
"time"

"github.com/hashicorp/vault/helper/testhelpers/docker"
"gopkg.in/mgo.v2"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)

// PrepareTestContainer calls PrepareTestContainerWithDatabase without a
Expand Down Expand Up @@ -44,22 +40,16 @@ func PrepareTestContainerWithDatabase(t *testing.T, version, dbName string) (fun
if dbName != "" {
connURL = fmt.Sprintf("%s/%s", connURL, dbName)
}
dialInfo, err := ParseMongoURL(connURL)
if err != nil {
return nil, err
}

session, err := mgo.DialWithInfo(dialInfo)
ctx, _ = context.WithTimeout(context.Background(), 1*time.Minute)
client, err := mongo.Connect(ctx, options.Client().ApplyURI(connURL))
if err != nil {
return nil, err
}
defer session.Close()

session.SetSyncTimeout(1 * time.Minute)
session.SetSocketTimeout(1 * time.Minute)
err = session.Ping()
if err != nil {
return nil, err
err = client.Ping(ctx, readpref.Primary())
if err = client.Disconnect(ctx); err != nil {
t.Fatal()
}

return docker.NewServiceURLParse(connURL)
Expand All @@ -70,72 +60,3 @@ func PrepareTestContainerWithDatabase(t *testing.T, version, dbName string) (fun

return svc.Cleanup, svc.Config.URL().String()
}

// ParseMongoURL will parse a connection string and return a configured dialer
func ParseMongoURL(rawURL string) (*mgo.DialInfo, error) {
url, err := url.Parse(rawURL)
if err != nil {
return nil, err
}

info := mgo.DialInfo{
Addrs: strings.Split(url.Host, ","),
Database: strings.TrimPrefix(url.Path, "/"),
Timeout: 10 * time.Second,
}

if url.User != nil {
info.Username = url.User.Username()
info.Password, _ = url.User.Password()
}

query := url.Query()
for key, values := range query {
var value string
if len(values) > 0 {
value = values[0]
}

switch key {
case "authSource":
info.Source = value
case "authMechanism":
info.Mechanism = value
case "gssapiServiceName":
info.Service = value
case "replicaSet":
info.ReplicaSetName = value
case "maxPoolSize":
poolLimit, err := strconv.Atoi(value)
if err != nil {
return nil, errors.New("bad value for maxPoolSize: " + value)
}
info.PoolLimit = poolLimit
case "ssl":
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
// ourselves. See https://github.com/go-mgo/mgo/issues/84
ssl, err := strconv.ParseBool(value)
if err != nil {
return nil, errors.New("bad value for ssl: " + value)
}
if ssl {
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), &tls.Config{})
}
}
case "connect":
if value == "direct" {
info.Direct = true
break
}
if value == "replicaSet" {
break
}
fallthrough
default:
return nil, errors.New("unsupported connection URL option: " + key + "=" + value)
}
}

return &info, nil
}
19 changes: 5 additions & 14 deletions plugins/database/mongodb/connection_producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ import (
"time"

"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/ory/dockertest"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"gopkg.in/mgo.v2"
)

func TestInit_clientTLS(t *testing.T) {
Expand Down Expand Up @@ -215,19 +213,12 @@ func startMongoWithTLS(t *testing.T, version string, confDir string) (retURL str
// exponential backoff-retry
err = pool.Retry(func() error {
var err error
dialInfo, err := mongodb.ParseMongoURL(retURL)
if err != nil {
return err
}

session, err := mgo.DialWithInfo(dialInfo)
if err != nil {
return err
ctx, _ := context.WithTimeout(context.Background(), 1*time.Minute)
client, err := mongo.Connect(ctx, options.Client().ApplyURI(retURL))
if err = client.Disconnect(ctx); err != nil {
t.Fatal()
}
defer session.Close()
session.SetSyncTimeout(1 * time.Minute)
session.SetSocketTimeout(1 * time.Minute)
return session.Ping()
return client.Ping(ctx, readpref.Primary())
})
if err != nil {
cleanup()
Expand Down
13 changes: 6 additions & 7 deletions plugins/database/mongodb/mongodb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
"github.com/hashicorp/vault/helper/testhelpers/mongodb"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
Expand All @@ -27,7 +26,7 @@ import (
const mongoAdminRole = `{ "db": "admin", "roles": [ { "role": "readWrite" } ] }`

func TestMongoDB_Initialize(t *testing.T) {
cleanup, connURL := mongodb.PrepareTestContainer(t, "5.0.10")
cleanup, connURL := mongodb.PrepareTestContainer(t, "latest")
defer cleanup()

db := new()
Expand Down Expand Up @@ -120,7 +119,7 @@ func TestNewUser_usernameTemplate(t *testing.T) {

for name, test := range tests {
t.Run(name, func(t *testing.T) {
cleanup, connURL := mongodb.PrepareTestContainer(t, "5.0.10")
cleanup, connURL := mongodb.PrepareTestContainer(t, "latest")
defer cleanup()

db := new()
Expand All @@ -146,7 +145,7 @@ func TestNewUser_usernameTemplate(t *testing.T) {
}

func TestMongoDB_CreateUser(t *testing.T) {
cleanup, connURL := mongodb.PrepareTestContainer(t, "5.0.10")
cleanup, connURL := mongodb.PrepareTestContainer(t, "latest")
defer cleanup()

db := new()
Expand Down Expand Up @@ -178,7 +177,7 @@ func TestMongoDB_CreateUser(t *testing.T) {
}

func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
cleanup, connURL := mongodb.PrepareTestContainer(t, "5.0.10")
cleanup, connURL := mongodb.PrepareTestContainer(t, "latest")
defer cleanup()

initReq := dbplugin.InitializeRequest{
Expand Down Expand Up @@ -212,7 +211,7 @@ func TestMongoDB_CreateUser_writeConcern(t *testing.T) {
}

func TestMongoDB_DeleteUser(t *testing.T) {
cleanup, connURL := mongodb.PrepareTestContainer(t, "5.0.10")
cleanup, connURL := mongodb.PrepareTestContainer(t, "latest")
defer cleanup()

db := new()
Expand Down Expand Up @@ -252,7 +251,7 @@ func TestMongoDB_DeleteUser(t *testing.T) {
}

func TestMongoDB_UpdateUser_Password(t *testing.T) {
cleanup, connURL := mongodb.PrepareTestContainer(t, "5.0.10")
cleanup, connURL := mongodb.PrepareTestContainer(t, "latest")
defer cleanup()

// The docker test method PrepareTestContainer defaults to a database "test"
Expand Down