diff --git a/server/accounts.go b/server/accounts.go
index e4ab5e0941f..3c5bed95e69 100644
--- a/server/accounts.go
+++ b/server/accounts.go
@@ -3598,15 +3598,16 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
 		clients := map[*client]struct{}{}
 		// We need to check all accounts that have an import claim from this account.
 		awcsti := map[string]struct{}{}
+
+		// We must only allow one goroutine to go through here, otherwise we could deadlock
+		// due to locking two accounts in succession.
+		s.mu.Lock()
 		s.accounts.Range(func(k, v any) bool {
 			acc := v.(*Account)
 			// Move to the next if this account is actually account "a".
 			if acc.Name == a.Name {
 				return true
 			}
-			// TODO: checkStreamImportAuthorized() stack should not be trying
-			// to lock "acc". If we find that to be needed, we will need to
-			// rework this to ensure we don't lock acc.
 			acc.mu.Lock()
 			for _, im := range acc.imports.streams {
 				if im != nil && im.acc.Name == a.Name {
@@ -3621,6 +3622,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
 			acc.mu.Unlock()
 			return true
 		})
+		s.mu.Unlock()
 		// Now walk clients.
 		for c := range clients {
 			c.processSubsOnConfigReload(awcsti)
@@ -3628,15 +3630,15 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
 	}
 	// Now check if service exports have changed.
 	if !a.checkServiceExportsEqual(old) || signersChanged || serviceTokenExpirationChanged {
+		// We must only allow one goroutine to go through here, otherwise we could deadlock
+		// due to locking two accounts in succession.
+		s.mu.Lock()
 		s.accounts.Range(func(k, v any) bool {
 			acc := v.(*Account)
 			// Move to the next if this account is actually account "a".
 			if acc.Name == a.Name {
 				return true
 			}
-			// TODO: checkServiceImportAuthorized() stack should not be trying
-			// to lock "acc". If we find that to be needed, we will need to
-			// rework this to ensure we don't lock acc.
 			acc.mu.Lock()
 			for _, sis := range acc.imports.services {
 				for _, si := range sis {
@@ -3646,6 +3648,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
 						// Make sure we should still be tracking latency and if we
 						// are allowed to trace.
 						if !si.response {
+							a.mu.RLock()
 							if se := a.getServiceExport(si.to); se != nil {
 								if si.latency != nil {
 									si.latency = se.latency
@@ -3653,6 +3656,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
 								// Update allow trace.
 								si.atrc = se.atrc
 							}
+							a.mu.RUnlock()
 						}
 					}
 				}
@@ -3660,6 +3664,7 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
 			acc.mu.Unlock()
 			return true
 		})
+		s.mu.Unlock()
 	}
 
 	// Now make sure we shutdown the old service import subscriptions.
diff --git a/server/jwt_test.go b/server/jwt_test.go
index 38d40722453..d5b6b1f5f79 100644
--- a/server/jwt_test.go
+++ b/server/jwt_test.go
@@ -7117,3 +7117,118 @@ func TestJWTImportsOnServerRestartAndClientsReconnect(t *testing.T) {
 		receive(t)
 	}
 }
+
+// TestJWTUpdateAccountClaimsStreamAndServiceImportDeadlock updates the claims of
+// several mutually-importing accounts in parallel, once for stream exports and
+// once for service exports, and expects every update to run to completion.
+func TestJWTUpdateAccountClaimsStreamAndServiceImportDeadlock(t *testing.T) {
+	for _, exportType := range []jwt.ExportType{jwt.Stream, jwt.Service} {
+		t.Run(exportType.String(), func(t *testing.T) {
+			s := opTrustBasicSetup()
+			defer s.Shutdown()
+			buildMemAccResolver(s)
+
+			// Get operator.
+			okp, err := nkeys.FromSeed(oSeed)
+			require_NoError(t, err)
+
+			type Acc struct {
+				pub string
+				ac  *jwt.AccountClaims
+				a   *Account
+				c   *client
+			}
+
+			// Create accounts.
+			var accs []*Acc
+			numAccounts := 10
+			for i := 0; i < numAccounts; i++ {
+				aKp, err := nkeys.CreateAccount()
+				require_NoError(t, err)
+				aPub, err := aKp.PublicKey()
+				require_NoError(t, err)
+				aAC := jwt.NewAccountClaims(aPub)
+				aJWT, err := aAC.Encode(okp)
+				require_NoError(t, err)
+				addAccountToMemResolver(s, aPub, aJWT)
+
+				aAcc, err := s.LookupAccount(aPub)
+				require_NoError(t, err)
+
+				aAcc.mu.Lock()
+				c := aAcc.internalClient()
+				aAcc.mu.Unlock()
+				aAcc.addClient(c)
+
+				accs = append(accs, &Acc{aPub, aAC, aAcc, c})
+			}
+
+			addImportExport := func(i int, acc *Acc) {
+				localSubject := fmt.Sprintf("%s.%d", acc.pub, i)
+				acc.ac.Exports.Add(&jwt.Export{Subject: jwt.Subject(localSubject), Type: exportType})
+				for _, oAcc := range accs {
+					if acc.pub == oAcc.pub {
+						continue
+					}
+					externalSubject := fmt.Sprintf("%s.%d", oAcc.pub, i)
+					acc.ac.Imports.Add(&jwt.Import{Account: oAcc.pub, Subject: jwt.Subject(externalSubject), Type: exportType})
+				}
+			}
+			test := func(i int) {
+				var start sync.WaitGroup
+				var release sync.WaitGroup
+				var finish sync.WaitGroup
+				start.Add(numAccounts)
+				release.Add(1)
+				finish.Add(numAccounts)
+
+				// Add imports/exports to every account and update in parallel, should not deadlock.
+				for _, acc := range accs {
+					acc := acc
+					go func() {
+						defer finish.Done()
+						addImportExport(i, acc)
+						// Named claimsJWT so the imported jwt package is not shadowed.
+						claimsJWT, err := acc.ac.Encode(okp)
+						if err != nil {
+							// Signal readiness before bailing out so start.Wait() cannot hang.
+							// Report with t.Errorf: fatal helpers (FailNow) must only run on
+							// the goroutine that is running the test function.
+							start.Done()
+							t.Errorf("encoding claims for account %q: %v", acc.pub, err)
+							return
+						}
+						addAccountToMemResolver(s, acc.pub, claimsJWT)
+						start.Done()
+
+						release.Wait()
+						s.UpdateAccountClaims(acc.a, acc.ac)
+					}()
+				}
+
+				start.Wait()
+
+				// Lock all clients, once we release below we'll get all claim updates
+				// in the same place after initial checks.
+				for _, acc := range accs {
+					acc.c.mu.Lock()
+				}
+				release.Done()
+
+				// Wait some time for them to reach that point and be blocked on the client lock.
+				time.Sleep(time.Second)
+				for _, acc := range accs {
+					acc.c.mu.Unlock()
+				}
+
+				// Eventually all goroutines should finish.
+				finish.Wait()
+			}
+
+			// Repeat test multiple times, increasing the amount of imports/exports along the way.
+			for i := 0; i < 30; i++ {
+				test(i)
+			}
+		})
+	}
+}