Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

database: Avoid race condition in connection creation #26147

Merged
merged 5 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions builtin/logical/database/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]m

type databaseBackend struct {
// connections holds configured database connections by config name
connections *syncmap.SyncMap[string, *dbPluginInstance]
logger log.Logger
createConnectionLock sync.Mutex
connections *syncmap.SyncMap[string, *dbPluginInstance]
logger log.Logger

*framework.Backend
// credRotationQueue is an in-memory priority queue used to track Static Roles
Expand Down Expand Up @@ -291,11 +292,23 @@ func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage,
}

func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) {
// fast path, reuse the existing connection
dbi := b.connections.Get(name)
if dbi != nil {
return dbi, nil
}

// slow path, create a new connection
// if we don't lock the rest of the operation, there is a race condition for multiple callers of this function
b.createConnectionLock.Lock()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we delete this block, the fix will still work; we still see a spike in open file descriptors on leadership transfers under load, but the file descriptors are closed correctly (because we switch to PutIfEmpty below).

defer b.createConnectionLock.Unlock()

// check again in case we lost the race
dbi = b.connections.Get(name)
if dbi != nil {
return dbi, nil
}

id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
Expand Down Expand Up @@ -332,14 +345,17 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri
name: name,
runningPluginVersion: pluginVersion,
}
oldConn := b.connections.Put(name, dbi)
if oldConn != nil {
err := oldConn.Close()
conn, ok := b.connections.PutIfEmpty(name, dbi)
if !ok {
// this is a bug
b.Logger().Warn("BUG: there was a race condition adding to the database connection map")
// There was already an existing connection, so we will use that and close our new one to avoid a race condition.
err := dbi.Close()
if err != nil {
b.Logger().Warn("Error closing database connection", "error", err)
b.Logger().Warn("Error closing new database connection", "error", err)
}
}
return dbi, nil
return conn, nil
}

// ClearConnection closes the database connection and
Expand Down
109 changes: 109 additions & 0 deletions builtin/logical/database/backend_get_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package database

import (
"context"
"sync"
"testing"

"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/queue"
)

// newSystemViewWrapper embeds the given view in a fresh systemViewWrapper and
// hands it back typed as a logical.SystemView.
func newSystemViewWrapper(view logical.SystemView) logical.SystemView {
	wrapped := &systemViewWrapper{SystemView: view}
	return wrapped
}

// systemViewWrapper embeds a logical.SystemView, so the embedded view's
// methods are promoted onto the wrapper; additional methods are declared on
// *systemViewWrapper elsewhere in this file.
type systemViewWrapper struct {
logical.SystemView
}

// Compile-time assertion that *systemViewWrapper implements
// logical.ExtendedSystemView.
var _ logical.ExtendedSystemView = (*systemViewWrapper)(nil)

// RequestWellKnownRedirect is an unimplemented stub: it ignores its arguments
// and panics with "nope" whenever it is called.
// NOTE(review): presumably present only so that *systemViewWrapper satisfies
// logical.ExtendedSystemView — confirm against that interface.
func (s *systemViewWrapper) RequestWellKnownRedirect(ctx context.Context, src, dest string) error {
panic("nope")
}

// DeregisterWellKnownRedirect is an unimplemented stub: it ignores its
// arguments and panics with "nope" whenever it is called.
// NOTE(review): presumably present only so that *systemViewWrapper satisfies
// logical.ExtendedSystemView — confirm against that interface.
func (s *systemViewWrapper) DeregisterWellKnownRedirect(ctx context.Context, src string) bool {
panic("nope")
}

// Auditor is an unimplemented stub: it panics with "nope" whenever it is
// called and never returns a logical.Auditor.
// NOTE(review): presumably present only so that *systemViewWrapper satisfies
// logical.ExtendedSystemView — confirm against that interface.
func (s *systemViewWrapper) Auditor() logical.Auditor {
panic("nope")
}

// ForwardGenericRequest is an unimplemented stub: it ignores its arguments and
// panics with "nope" whenever it is called.
// NOTE(review): presumably present only so that *systemViewWrapper satisfies
// logical.ExtendedSystemView — confirm against that interface.
func (s *systemViewWrapper) ForwardGenericRequest(ctx context.Context, request *logical.Request) (*logical.Response, error) {
panic("nope")
}

// APILockShouldBlockRequest is an unimplemented stub: it panics with "nope"
// whenever it is called and never produces a result.
// NOTE(review): presumably present only so that *systemViewWrapper satisfies
// logical.ExtendedSystemView — confirm against that interface.
func (s *systemViewWrapper) APILockShouldBlockRequest() (bool, error) {
panic("nope")
}

// GetPinnedPluginVersion ignores all of its arguments and always reports that
// no pinned version exists: it returns a nil *pluginutil.PinnedVersion together
// with pluginutil.ErrPinnedVersionNotFound.
func (s *systemViewWrapper) GetPinnedPluginVersion(ctx context.Context, pluginType consts.PluginType, pluginName string) (*pluginutil.PinnedVersion, error) {
return nil, pluginutil.ErrPinnedVersionNotFound
}

// LookupPluginVersion disregards the requested name, type and version and
// always answers with the same runner: a builtin database plugin named mockv5
// whose BuiltinFactory is New. The returned error is always nil.
func (s *systemViewWrapper) LookupPluginVersion(ctx context.Context, pluginName string, pluginType consts.PluginType, version string) (*pluginutil.PluginRunner, error) {
	runner := pluginutil.PluginRunner{
		Name:           mockv5,
		Type:           consts.PluginTypeDatabase,
		Builtin:        true,
		BuiltinFactory: New,
	}
	return &runner, nil
}

// getDbBackend builds a databaseBackend for tests together with the storage it
// was set up against. The config starts from logical.TestBackendConfig, has its
// System view wrapped by newSystemViewWrapper, and uses a fresh
// logical.InmemStorage as its StorageView. Any Setup error fails the test
// immediately via t.Fatal.
func getDbBackend(t *testing.T) (*databaseBackend, logical.Storage) {
	t.Helper()

	ctx := context.Background()

	conf := logical.TestBackendConfig()
	conf.System = newSystemViewWrapper(conf.System)
	conf.StorageView = &logical.InmemStorage{}

	// The backend is constructed and initialised by hand rather than through
	// a Factory, because the factory function kicks off threads that cause
	// racy tests.
	backend := Backend(conf)
	err := backend.Setup(ctx, conf)
	if err != nil {
		t.Fatal(err)
	}

	backend.schedule = &TestSchedule{}
	backend.credRotationQueue = queue.New()
	backend.populateQueue(ctx, conf.StorageView)

	return backend, conf.StorageView
}

// TestGetConnectionRaceCondition checks that GetConnection always returns the same instance, even when asked
// by multiple goroutines in parallel.
func TestGetConnectionRaceCondition(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

ctx := context.Background()
b, s := getDbBackend(t)
defer b.Cleanup(ctx)
configureDBMount(t, s)

goroutines := 16

wg := sync.WaitGroup{}
wg.Add(goroutines)
dbis := make([]*dbPluginInstance, goroutines)
errs := make([]error, goroutines)
for i := 0; i < goroutines; i++ {
go func(i int) {
defer wg.Done()
dbis[i], errs[i] = b.GetConnection(ctx, s, mockv5)
}(i)
}
wg.Wait()
for i := 0; i < goroutines; i++ {
if errs[i] != nil {
t.Fatal(errs[i])
}
if dbis[0] != dbis[i] {
t.Fatal("Error: database instances did not match")
}
}
}
3 changes: 3 additions & 0 deletions builtin/logical/database/mockv5.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ func (m MockDatabaseV5) Initialize(ctx context.Context, req v5.InitializeRequest
"req", req)

config := req.Config
if config == nil {
config = map[string]interface{}{}
}
config["from-plugin"] = "this value is from the plugin itself"

resp := v5.InitializeResponse{
Expand Down
11 changes: 6 additions & 5 deletions builtin/logical/database/rotation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
)

const (
mockv5 = "mockv5"
dbUser = "vaultstatictest"
dbUserDefaultPassword = "password"
testMinRotationWindowSeconds = 5
Expand Down Expand Up @@ -1446,7 +1447,7 @@ func TestStoredWALsCorrectlyProcessed(t *testing.T) {

rotationPeriodData := map[string]interface{}{
"username": "hashicorp",
"db_name": "mockv5",
"db_name": mockv5,
"rotation_period": "86400s",
}

Expand Down Expand Up @@ -1500,7 +1501,7 @@ func TestStoredWALsCorrectlyProcessed(t *testing.T) {
},
map[string]interface{}{
"username": "hashicorp",
"db_name": "mockv5",
"db_name": mockv5,
"rotation_schedule": "*/10 * * * * *",
},
},
Expand Down Expand Up @@ -1699,9 +1700,9 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase {
dbi := &dbPluginInstance{
database: dbw,
id: "foo-id",
name: "mockV5",
name: mockv5,
}
b.connections.Put("mockv5", dbi)
b.connections.Put(mockv5, dbi)

return mockDB
}
Expand All @@ -1710,7 +1711,7 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase {
// plugin init code paths, allowing us to use a manually populated mock DB object.
func configureDBMount(t *testing.T, storage logical.Storage) {
t.Helper()
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/mockv5"), &DatabaseConfig{
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/"+mockv5), &DatabaseConfig{
AllowedRoles: []string{"*"},
})
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions changelog/26147.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
secret/database: Fixed race condition where database mounts may leak connections
```
14 changes: 14 additions & 0 deletions helper/syncmap/syncmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@ func (m *SyncMap[K, V]) Put(k K, v V) V {
return oldV
}

// PutIfEmpty stores v under k only when the map holds no entry for k yet.
// On a successful insert it returns v and true. When k is already present the
// map is left untouched and the existing value is returned along with false.
// The map's lock is held for the whole check-and-insert.
func (m *SyncMap[K, V]) PutIfEmpty(k K, v V) (V, bool) {
	m.lock.Lock()
	defer m.lock.Unlock()

	if existing, present := m.data[k]; present {
		return existing, false
	}

	m.data[k] = v
	return v, true
}

// Clear deletes all entries from the map, and returns the previous map.
func (m *SyncMap[K, V]) Clear() map[K]V {
m.lock.Lock()
Expand Down
Loading