From 9b6d2acb88dfac6bccd1e3a4f9a2450c082ee174 Mon Sep 17 00:00:00 2001 From: hc-github-team-secure-vault-core <82990506+hc-github-team-secure-vault-core@users.noreply.github.com> Date: Wed, 6 Sep 2023 12:09:00 -0400 Subject: [PATCH] backport of commit f97822da31c1374df4bab4a90880a1f094fad313 (#22796) Co-authored-by: Victor Rodriguez --- vault/seal/seal.go | 338 +++++++++++++++++++++++++++---------- vault/seal/seal_wrapper.go | 73 ++++---- vault/seal_autoseal.go | 2 + 3 files changed, 286 insertions(+), 127 deletions(-) diff --git a/vault/seal/seal.go b/vault/seal/seal.go index 02213e36796d..7935200a9dac 100644 --- a/vault/seal/seal.go +++ b/vault/seal/seal.go @@ -161,6 +161,16 @@ func (sgi *SealGenerationInfo) UnmarshalJSON(b []byte) error { return nil } +// OldKey is used as a return value from Decrypt to indicate that the old +// key was used for decryption and that the value should be re-encrypted +// with the new key and saved. It is not returned as an error by any +// function. +var OldKey = errors.New("decrypted with old key") + +func IsOldKeyError(err error) bool { + return errors.Is(err, OldKey) +} + // Access is the embedded implementation of autoSeal that contains logic // specific to encrypting and decrypting data, or in this case keys. 
type Access interface { @@ -270,33 +280,38 @@ func NewAccessFromWrapper(logger hclog.Logger, wrapper wrapping.Wrapper, sealCon } func (a *access) GetAllSealWrappersByPriority() []*SealWrapper { - return copySealWrappers(a.wrappersByPriority, false) + return a.filterSealWrappers(enabledAndDisabled, healthyAndUnhealthy) } func (a *access) GetEnabledSealWrappersByPriority() []*SealWrapper { - return copySealWrappers(a.wrappersByPriority, true) + return a.filterSealWrappers(enabledOnly, healthyAndUnhealthy) } func (a *access) AllSealWrappersHealthy() bool { - for _, sw := range a.wrappersByPriority { - // Ignore disabled seals - if sw.Disabled { - continue - } - if !sw.IsHealthy() { - return false - } - } - return true + return len(a.wrappersByPriority) == len(a.filterSealWrappers(enabledAndDisabled, healthyOnly)) } -func copySealWrappers(sealWrappers []*SealWrapper, enabledOnly bool) []*SealWrapper { - ret := make([]*SealWrapper, 0, len(sealWrappers)) - for _, sw := range sealWrappers { - if enabledOnly && sw.Disabled { +type enabledFilter bool +type healthyFilter bool + +const ( + enabledOnly = enabledFilter(true) + enabledAndDisabled = !enabledOnly + healthyOnly = healthyFilter(true) + healthyAndUnhealthy = !healthyOnly +) + +func (a *access) filterSealWrappers(enabled enabledFilter, healthy healthyFilter) []*SealWrapper { + ret := make([]*SealWrapper, 0, len(a.wrappersByPriority)) + for _, sw := range a.wrappersByPriority { + switch { + case enabled == enabledOnly && sw.Disabled: + continue + case healthy == healthyOnly && !sw.IsHealthy(): continue + default: + ret = append(ret, sw) } - ret = append(ret, sw) } return ret } @@ -348,11 +363,10 @@ func (a *access) IsUpToDate(ctx context.Context, value *MultiWrapValue, forceKey a.logger.Error("error refreshing seal key IDs") return false, JoinSealWrapErrors("cannot determine key IDs of Access wrappers", errs) } - // TODO(SEALHA): What to do if there are partial failures? 
if len(errs) > 0 { msg := "could not determine key IDs of some Access wrappers" - a.logger.Warn(msg) - a.logger.Trace("partial failure refreshing seal key IDs", "err", JoinSealWrapErrors(msg, errs)) + a.logger.Error("partial failure refreshing seal key IDs", "err", JoinSealWrapErrors(msg, errs)) + return false, JoinSealWrapErrors(msg, errs) } a.keyIdSet.set(test) } @@ -360,43 +374,89 @@ func (a *access) IsUpToDate(ctx context.Context, value *MultiWrapValue, forceKey return a.keyIdSet.equal(value), nil } +const ( + // wrapperEncryptTimeout is the duration we will wait for seal wrappers to return from an encrypt call. + // After the timeout, we return any successful results and errors for the rest of the wrappers, so + // that a partial seal wrap entry can be recorded. + wrapperEncryptTimeout = 10 * time.Second + + // wrapperDecryptHighPriorityHeadStart is the duration we wait for the highest priority wrapper + // to return from a decrypt call before we try decrypting with any additional wrappers. + wrapperDecryptHighPriorityHeadStart = 2 * time.Second +) + // Encrypt uses the underlying seal to encrypt the plaintext and returns it. func (a *access) Encrypt(ctx context.Context, plaintext []byte, options ...wrapping.Option) (*MultiWrapValue, map[string]error) { - var slots []*wrapping.BlobInfo - errs := make(map[string]error) - - for _, sealWrapper := range a.GetEnabledSealWrappersByPriority() { - now := time.Now() - var encryptErr error - defer func(now time.Time) { - metrics.MeasureSince([]string{"seal", "encrypt", "time"}, now) - metrics.MeasureSince([]string{"seal", sealWrapper.Name, "encrypt", "time"}, now) - - if encryptErr != nil { - metrics.IncrCounter([]string{"seal", "encrypt", "error"}, 1) - metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt", "error"}, 1) + // Note that we do not encrypt with disabled wrappers. Disabled wrappers are only used to decrypt. 
+ enabledWrappersByPriority := a.filterSealWrappers(enabledOnly, healthyOnly) + if len(enabledWrappersByPriority) == 0 { + // If all seals are unhealthy, try any way since a seal may have recovered + enabledWrappersByPriority = a.filterSealWrappers(enabledOnly, healthyAndUnhealthy) + } + + type result struct { + name string + ciphertext *wrapping.BlobInfo + err error + } + resultCh := make(chan *result) + + encryptCtx, cancelEncryptCtx := context.WithTimeout(ctx, wrapperEncryptTimeout) + defer cancelEncryptCtx() + + // Start goroutines to encrypt the value using each of the wrappers. + for _, sealWrapper := range enabledWrappersByPriority { + go func(sealWrapper *SealWrapper) { + ciphertext, err := a.tryEncrypt(encryptCtx, sealWrapper, plaintext, options...) + resultCh <- &result{ + name: sealWrapper.Name, + ciphertext: ciphertext, + err: err, } - }(now) - - metrics.IncrCounter([]string{"seal", "encrypt"}, 1) - metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt"}, 1) + }(sealWrapper) + } - ciphertext, encryptErr := sealWrapper.Wrapper.Encrypt(ctx, plaintext, options...) 
- if encryptErr != nil { - a.logger.Warn("error encrypting with seal", "seal", sealWrapper.Name) - a.logger.Trace("error encrypting with seal", "seal", sealWrapper.Name, "err", encryptErr) + results := make(map[string]*result) +GATHER_RESULTS: + for { + select { + case result := <-resultCh: + results[result.name] = result + if len(results) == len(enabledWrappersByPriority) { + break GATHER_RESULTS + } + case <-encryptCtx.Done(): + break GATHER_RESULTS + case <-ctx.Done(): + cancelEncryptCtx() + break GATHER_RESULTS + } + } - errs[sealWrapper.Name] = encryptErr - sealWrapper.SetHealthy(false, now) + // Sort out the successful results from the errors + var slots []*wrapping.BlobInfo + errs := make(map[string]error) + for _, sealWrapper := range enabledWrappersByPriority { + if result, ok := results[sealWrapper.Name]; ok { + if result.err != nil { + errs[sealWrapper.Name] = result.err + } else { + slots = append(slots, result.ciphertext) + } } else { - a.logger.Trace("encrypted value using seal", "seal", sealWrapper.Name, "keyId", ciphertext.KeyInfo.KeyId) - - slots = append(slots, ciphertext) + if encryptCtx.Err() != nil { + errs[sealWrapper.Name] = encryptCtx.Err() + } else { + // Just being paranoid, encryptCtx.Err() should never be nil in this case + errs[sealWrapper.Name] = errors.New("context timeout exceeded") + } + // This failure did not happen on tryEncrypt, so we must log it here + a.logger.Trace("error encrypting with seal", "seal", sealWrapper.Name, "err", errs[sealWrapper.Name]) } } if len(slots) == 0 { - a.logger.Error("all seals failed to encrypt value") + a.logger.Error("failed to encrypt value using any seal wrappers") return nil, errs } @@ -407,12 +467,44 @@ func (a *access) Encrypt(ctx context.Context, plaintext []byte, options ...wrapp Slots: slots, } - // cache key IDs - a.keyIdSet.set(ret) + if len(errs) == 0 { + // cache key IDs + a.keyIdSet.set(ret) + } return ret, errs } +func (a *access) tryEncrypt(ctx context.Context, sealWrapper 
*SealWrapper, plaintext []byte, options ...wrapping.Option) (*wrapping.BlobInfo, error) { + now := time.Now() + var encryptErr error + defer func(now time.Time) { + metrics.MeasureSince([]string{"seal", "encrypt", "time"}, now) + metrics.MeasureSince([]string{"seal", sealWrapper.Name, "encrypt", "time"}, now) + + if encryptErr != nil { + metrics.IncrCounter([]string{"seal", "encrypt", "error"}, 1) + metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt", "error"}, 1) + } + }(now) + + metrics.IncrCounter([]string{"seal", "encrypt"}, 1) + metrics.IncrCounter([]string{"seal", sealWrapper.Name, "encrypt"}, 1) + + ciphertext, encryptErr := sealWrapper.Wrapper.Encrypt(ctx, plaintext, options...) + if encryptErr != nil { + a.logger.Warn("error encrypting with seal", "seal", sealWrapper.Name) + a.logger.Trace("error encrypting with seal", "seal", sealWrapper.Name, "err", encryptErr) + + sealWrapper.SetHealthy(false, now) + return nil, encryptErr + } + a.logger.Trace("encrypted value using seal", "seal", sealWrapper.Name, "keyId", ciphertext.KeyInfo.KeyId) + + sealWrapper.SetHealthy(true, now) + return ciphertext, nil +} + // Decrypt uses the underlying seal to decrypt the ciphertext and returns it. // Note that it is possible depending on the wrapper used that both pt and err // are populated. 
@@ -426,46 +518,85 @@ func (a *access) Decrypt(ctx context.Context, ciphertext *MultiWrapValue, option return nil, false, err } - // First, lets try the wrappers in order of priority and look for an exact key ID match - for _, sealWrapper := range a.GetAllSealWrappersByPriority() { - if keyId, err := sealWrapper.Wrapper.KeyId(ctx); err == nil { - if blobInfo, ok := blobInfoMap[keyId]; ok { - pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, blobInfo, options) - if oldKey { - a.logger.Trace("decrypted using OldKey", "seal", sealWrapper.Name) - return pt, false, err - } - if err == nil { - a.logger.Trace("decrypted value using seal", "seal", sealWrapper.Name) - return pt, isUpToDate, nil - } - // If there is an error, keep trying with the other wrappers - a.logger.Trace("error decrypting with seal, will try other seals", "seal", sealWrapper.Name, "keyId", keyId, "err", err) - } + wrappersByPriority := a.filterSealWrappers(enabledAndDisabled, healthyOnly) + if len(wrappersByPriority) == 0 { + // If all seals are unhealthy, try any way since a seal may have recovered + wrappersByPriority = a.filterSealWrappers(enabledAndDisabled, healthyAndUnhealthy) + } + + type result struct { + name string + pt []byte + oldKey bool + err error + } + resultCh := make(chan *result) + + decrypt := func(sealWrapper *SealWrapper) { + pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, blobInfoMap, options) + resultCh <- &result{ + name: sealWrapper.Name, + pt: pt, + oldKey: oldKey, + err: err, + } + } + + // Start goroutines to decrypt the value + for i, sealWrapper := range wrappersByPriority { + sealWrapper := sealWrapper + if i == 0 { + // start the highest priority wrapper right away + go decrypt(sealWrapper) + } else { + timer := time.AfterFunc(wrapperDecryptHighPriorityHeadStart, func() { + decrypt(sealWrapper) + }) + defer timer.Stop() } } - // No key ID match, so try each wrapper with all slots + // Gathering failures, but return right away if there is a successful result errs := 
make(map[string]error) - for _, sealWrapper := range a.GetAllSealWrappersByPriority() { - for _, blobInfo := range ciphertext.Slots { - pt, oldKey, err := a.tryDecrypt(ctx, sealWrapper, blobInfo, options) - if oldKey { - a.logger.Trace("decrypted using OldKey", "seal", sealWrapper.Name) - return pt, false, err - } - if err == nil { - a.logger.Trace("decrypted value using seal", "seal", sealWrapper.Name) - return pt, isUpToDate, nil +GATHER_RESULTS: + for { + select { + case result := <-resultCh: + switch { + case result.err != nil: + errs[result.name] = result.err + if len(errs) == len(wrappersByPriority) { + break GATHER_RESULTS + } + + case result.oldKey: + return result.pt, false, OldKey + + default: + return result.pt, isUpToDate, nil } - errs[sealWrapper.Name] = err + case <-ctx.Done(): + break GATHER_RESULTS } } - return nil, false, JoinSealWrapErrors("error decrypting seal wrapped value", errs) + // No wrapper was able to decrypt the value, return an error + + if len(errs) > 0 { + return nil, false, JoinSealWrapErrors("error decrypting seal wrapped value", errs) + } + + if ctx.Err() != nil { + return nil, false, ctx.Err() + } + // Just being paranoid, ctx.Err() should never be nil in this case + return nil, false, errors.New("context timeout exceeded") } -func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphertext *wrapping.BlobInfo, options []wrapping.Option) ([]byte, bool, error) { +// tryDecrypt returns the plaintext and a flag indicating whether the decryption was done by the "unwrapSeal" (see +// sealWrapMigration.Decrypt). 
+func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphertextByKeyId map[string]*wrapping.BlobInfo, options []wrapping.Option) ([]byte, bool, error) { + now := time.Now() var decryptErr error defer func(now time.Time) { metrics.MeasureSince([]string{"seal", "decrypt", "time"}, now) @@ -475,19 +606,52 @@ func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphe metrics.IncrCounter([]string{"seal", "decrypt", "error"}, 1) metrics.IncrCounter([]string{"seal", sealWrapper.Name, "decrypt", "error"}, 1) } - // TODO (multiseal): log an error? - }(time.Now()) + }(now) metrics.IncrCounter([]string{"seal", "decrypt"}, 1) metrics.IncrCounter([]string{"seal", sealWrapper.Name, "decrypt"}, 1) - pt, err := sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...) - isOldKey := false - if err != nil && err.Error() == "decrypted with old key" { - // This is for compatibility with sealWrapMigration - isOldKey = true + var pt []byte + + // First, let's look for an exact key ID match + var keyId string + if id, err := sealWrapper.Wrapper.KeyId(ctx); err == nil { + keyId = id + if ciphertext, ok := ciphertextByKeyId[keyId]; ok { + pt, decryptErr = sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...) + + sealWrapper.SetHealthy(decryptErr == nil || IsOldKeyError(decryptErr), now) + } + } + // If we don't get a result, try all the slots + if pt == nil && decryptErr == nil { + for _, ciphertext := range ciphertextByKeyId { + pt, decryptErr = sealWrapper.Wrapper.Decrypt(ctx, ciphertext, options...) + if decryptErr == nil { + // Note that we only update wrapper health for failures on exact key ID match, + // otherwise we would have false negatives. + sealWrapper.SetHealthy(true, now) + break + } + } + } + + switch { + case decryptErr != nil && IsOldKeyError(decryptErr): + // an OldKey error is not an actual error, it just means that the decryption was done + // by the "unwrapSeal" of a seal migration (see sealWrapMigration.Decrypt). 
+ a.logger.Trace("decrypted using OldKey", "seal_name", sealWrapper.Name) + return pt, true, nil + + case decryptErr != nil: + // Note that if there are more than one ciphertext, the error may be misleading... + a.logger.Trace("error decrypting with seal, this may be a harmless mismatch between wrapper and ciphertext", "seal_name", sealWrapper.Name, "keyId", keyId, "err", decryptErr) + return nil, false, decryptErr + + default: + a.logger.Trace("decrypted value using seal", "seal_name", sealWrapper.Name) + return pt, false, nil } - return pt, isOldKey, err } func JoinSealWrapErrors(msg string, errorMap map[string]error) error { diff --git a/vault/seal/seal_wrapper.go b/vault/seal/seal_wrapper.go index 474bf3815f8d..543407154b96 100644 --- a/vault/seal/seal_wrapper.go +++ b/vault/seal/seal_wrapper.go @@ -26,7 +26,9 @@ type SealWrapper struct { // Disabled indicates, when true indicates that this wrapper should only be used for decryption. Disabled bool - // hcLock protects lastHealthy, lastSeenHealthy, and healthy. Do not modify those fields directly, use setHealth instead. + // hcLock protects lastHealthy, lastSeenHealthy, and healthy. + // Do not modify those fields directly, use setHealth instead. + // Do not access these fields directly, use getHealth instead. 
hcLock sync.RWMutex lastHealthCheck time.Time lastSeenHealthy time.Time @@ -42,53 +44,36 @@ func NewSealWrapper(wrapper wrapping.Wrapper, priority int, name string, sealCon Disabled: disabled, } - ret.setHealth(true, time.Now(), ret.lastHealthCheck) + setHealth(ret, true, time.Now(), ret.lastHealthCheck) return ret } -func (sw *SealWrapper) rlock() func() { - sw.hcLock.RLock() - return sw.hcLock.RUnlock -} - -func (sw *SealWrapper) lock() func() { - sw.hcLock.Lock() - return sw.hcLock.Unlock -} - func (sw *SealWrapper) SetHealthy(healthy bool, checkTime time.Time) { - unlock := sw.lock() - defer unlock() - - wasHealthy := sw.healthy - lastHealthy := sw.lastSeenHealthy - if !wasHealthy && healthy { - lastHealthy = checkTime + if healthy { + setHealth(sw, true, checkTime, checkTime) + } else { + // do not update lastSeenHealthy + setHealth(sw, false, sw.lastHealthCheck, checkTime) } - - sw.setHealth(healthy, lastHealthy, checkTime) } func (sw *SealWrapper) IsHealthy() bool { - unlock := sw.rlock() - defer unlock() + healthy, _, _ := getHealth(sw) - return sw.healthy + return healthy } func (sw *SealWrapper) LastSeenHealthy() time.Time { - unlock := sw.rlock() - defer unlock() + _, lastSeenHealthy, _ := getHealth(sw) - return sw.lastSeenHealthy + return lastSeenHealthy } func (sw *SealWrapper) LastHealthCheck() time.Time { - unlock := sw.rlock() - defer unlock() + _, _, lastHealthCheck := getHealth(sw) - return sw.lastHealthCheck + return lastHealthCheck } var ( @@ -99,35 +84,43 @@ var ( ) func (sw *SealWrapper) CheckHealth(ctx context.Context, checkTime time.Time) error { - unlock := sw.lock() - defer unlock() - - // Assume the wrapper is unhealthy, if we make it to the end we'll set it to true - sw.setHealth(false, sw.lastSeenHealthy, checkTime) - testVal := fmt.Sprintf("Heartbeat %d", mathrand.Intn(1000)) ciphertext, err := sw.Wrapper.Encrypt(ctx, []byte(testVal), nil) if err != nil { + sw.SetHealthy(false, checkTime) return fmt.Errorf("failed to encrypt test 
value, seal wrapper may be unreachable: %w", err) } ctx, cancel := context.WithTimeout(ctx, HealthTestTimeout) defer cancel() plaintext, err := sw.Wrapper.Decrypt(ctx, ciphertext, nil) - if err != nil { + if err != nil && !IsOldKeyError(err) { + sw.SetHealthy(false, checkTime) return fmt.Errorf("failed to decrypt test value, seal wrapper may be unreachable: %w", err) } if !bytes.Equal([]byte(testVal), plaintext) { + sw.SetHealthy(false, checkTime) return errors.New("failed to decrypt health test value to expected result") } - sw.setHealth(true, checkTime, checkTime) + sw.SetHealthy(true, checkTime) return nil } -// setHealth sets the fields protected by sw.hcLock, callers *must* hold the write lock. -func (sw *SealWrapper) setHealth(healthy bool, lastSeenHealthy, lastHealthCheck time.Time) { +// getHealth is the only function allowed to inspect the health fields directly +func getHealth(sw *SealWrapper) (healthy bool, lastSeenHealthy time.Time, lastHealthCheck time.Time) { + sw.hcLock.RLock() + defer sw.hcLock.RUnlock() + + return sw.healthy, sw.lastSeenHealthy, sw.lastHealthCheck +} + +// setHealth is the only function allowed to mutate the health fields +func setHealth(sw *SealWrapper, healthy bool, lastSeenHealthy, lastHealthCheck time.Time) { + sw.hcLock.Lock() + defer sw.hcLock.Unlock() + sw.healthy = healthy sw.lastSeenHealthy = lastSeenHealthy sw.lastHealthCheck = lastHealthCheck diff --git a/vault/seal_autoseal.go b/vault/seal_autoseal.go index 01135e287a05..6a3a225c7f15 100644 --- a/vault/seal_autoseal.go +++ b/vault/seal_autoseal.go @@ -470,6 +470,8 @@ func (d *autoSeal) StartHealthCheck() { ctx, cancel := context.WithTimeout(ctx, seal.HealthTestTimeout) defer cancel() + d.logger.Trace("performing a seal health check") + allHealthy := true allUnhealthy := true for _, sealWrapper := range d.Access.GetAllSealWrappersByPriority() {