Skip to content

Commit

Permalink
Detect and bypass cycles during token revocation
Browse files Browse the repository at this point in the history
Fixes #4803
  • Loading branch information
Jim Kalafut committed Sep 13, 2018
1 parent e31cdb7 commit 77cb849
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 15 deletions.
38 changes: 28 additions & 10 deletions vault/token_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ func NewTokenStore(ctx context.Context, logger log.Logger, c *Core, config *logi

// Initialize the store
t := &TokenStore{
view: view,
cubbyholeDestroyer: destroyCubbyhole,
logger: logger,
tokenLocks: locksutil.CreateLocks(),
tokensPendingDeletion: &sync.Map{},
saltLock: sync.RWMutex{},
core: c,
view: view,
cubbyholeDestroyer: destroyCubbyhole,
logger: logger,
tokenLocks: locksutil.CreateLocks(),
tokensPendingDeletion: &sync.Map{},
saltLock: sync.RWMutex{},
core: c,
identityPoliciesDeriverFunc: c.fetchEntityAndDerivedPolicies,
tidyLock: new(uint32),
quitContext: c.activeContext,
Expand Down Expand Up @@ -1200,16 +1200,31 @@ func (ts *TokenStore) revokeTree(ctx context.Context, id string) error {
// Updated to be non-recursive and revoke child tokens
// before parent tokens(DFS).
func (ts *TokenStore) revokeTreeSalted(ctx context.Context, saltedID string) error {
var dfs []string
dfs = append(dfs, saltedID)
dfs := []string{saltedID}
seenIDs := map[string]struct{}{
saltedID: struct{}{},
}

for l := len(dfs); l > 0; l = len(dfs) {
id := dfs[0]
path := parentPrefix + id + "/"
children, err := ts.view.List(ctx, path)

childrenRaw, err := ts.view.List(ctx, path)
if err != nil {
return errwrap.Wrapf("failed to scan for children: {{err}}", err)
}

// Filter the child list to remove any items that have ever been in the dfs queue.
// This is a robustness check, as a parent/child cycle can lead to an OOM crash.
children := make([]string, 0, len(childrenRaw))
for _, child := range childrenRaw {
if _, seen := seenIDs[child]; !seen {
children = append(children, child)
} else {
ts.Logger().Warn("token cycle found", "token", child)
}
}

// If the length of the children array is zero,
// then we are at a leaf node.
if len(children) == 0 {
Expand All @@ -1231,6 +1246,9 @@ func (ts *TokenStore) revokeTreeSalted(ctx context.Context, saltedID string) err
// If we make it here, there are children and they must
// be prepended.
dfs = append(children, dfs...)
for _, child := range children {
seenIDs[child] = struct{}{}
}
}
}

Expand Down
34 changes: 29 additions & 5 deletions vault/token_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -954,17 +954,41 @@ func TestTokenStore_Revoke_Orphan(t *testing.T) {
// This was the original function name, and now it just calls
// the non recursive version for a variety of depths.
func TestTokenStore_RevokeTree(t *testing.T) {
testTokenStore_RevokeTree_NonRecursive(t, 1)
testTokenStore_RevokeTree_NonRecursive(t, 2)
testTokenStore_RevokeTree_NonRecursive(t, 10)
testTokenStore_RevokeTree_NonRecursive(t, 1, false)
testTokenStore_RevokeTree_NonRecursive(t, 2, false)
testTokenStore_RevokeTree_NonRecursive(t, 10, false)

// corrupted trees with cycles
testTokenStore_RevokeTree_NonRecursive(t, 1, true)
testTokenStore_RevokeTree_NonRecursive(t, 10, true)
}

// Revokes a given Token Store tree non recursively.
// The second parameter refers to the depth of the tree.
func testTokenStore_RevokeTree_NonRecursive(t testing.TB, depth uint64) {
func testTokenStore_RevokeTree_NonRecursive(t testing.TB, depth uint64, injectCycles bool) {
c, _, _ := TestCoreUnsealed(t)
ts := c.tokenStore
root, children := buildTokenTree(t, ts, depth)

if injectCycles {
// Make the root the parent of itself
saltedRoot, _ := ts.SaltID(context.Background(), root.ID)
le := &logical.StorageEntry{Key: fmt.Sprintf("parent/%s/%s", saltedRoot, saltedRoot)}

if err := ts.view.Put(context.Background(), le); err != nil {
t.Fatalf("err: %v", err)
}

// Make a deep child the parent of a shallow child
shallow, _ := ts.SaltID(context.Background(), children[0].ID)
deep, _ := ts.SaltID(context.Background(), children[len(children)-1].ID)
le = &logical.StorageEntry{Key: fmt.Sprintf("parent/%s/%s", deep, shallow)}

if err := ts.view.Put(context.Background(), le); err != nil {
t.Fatalf("err: %v", err)
}
}

err := ts.revokeTree(context.Background(), "")

if err.Error() != "cannot tree-revoke blank token" {
Expand Down Expand Up @@ -998,7 +1022,7 @@ func BenchmarkTokenStore_RevokeTree(b *testing.B) {
for _, depth := range benchmarks {
b.Run(fmt.Sprintf("Tree of Depth %d", depth), func(b *testing.B) {
for i := 0; i < b.N; i++ {
testTokenStore_RevokeTree_NonRecursive(b, depth)
testTokenStore_RevokeTree_NonRecursive(b, depth, false)
}
})
}
Expand Down

0 comments on commit 77cb849

Please sign in to comment.