diff --git a/pkg/wallet/dsbackend.go b/pkg/wallet/dsbackend.go index 0e5881d721..7f6b5a6bf1 100644 --- a/pkg/wallet/dsbackend.go +++ b/pkg/wallet/dsbackend.go @@ -7,6 +7,7 @@ import ( "reflect" "strings" "sync" + "sync/atomic" "github.com/awnumar/memguard" "github.com/filecoin-project/go-address" @@ -50,7 +51,7 @@ type DSBackend struct { password *memguard.Enclave unLocked map[address.Address]*key.KeyInfo - state int + state atomic.Int64 } var _ Backend = (*DSBackend)(nil) @@ -198,11 +199,8 @@ func (backend *DSBackend) putKeyInfo(ctx context.Context, ki *key.KeyInfo) error if err := backend.ds.Put(ctx, ds.NewKey(key.Address.String()), keyJSON); err != nil { return errors.Wrapf(err, "failed to store new address: %s", key.Address.String()) } - - backend.lk.Lock() backend.cache[addr] = struct{}{} backend.unLocked[addr] = ki - backend.lk.Unlock() return nil } @@ -285,9 +283,7 @@ func (backend *DSBackend) getKey(ctx context.Context, addr address.Address, pass } func (backend *DSBackend) LockWallet(ctx context.Context) error { - backend.lk.Lock() - defer backend.lk.Unlock() - if backend.state == Lock { + if backend.state.Load() == Lock { return fmt.Errorf("already locked") } @@ -296,24 +292,24 @@ func (backend *DSBackend) LockWallet(ctx context.Context) error { } for _, addr := range backend.Addresses(ctx) { + backend.lk.Lock() delete(backend.unLocked, addr) + backend.lk.Unlock() } backend.cleanPassword() - backend.state = Lock + backend.state.Store(Lock) return nil } // UnLockWallet unlock wallet with password, decrypt local key in db and save to protected memory func (backend *DSBackend) UnLockWallet(ctx context.Context, password []byte) error { - backend.lk.Lock() defer func() { - backend.lk.Unlock() for i := range password { password[i] = 0 } }() - if backend.state == Unlock { + if backend.state.Load() == Unlock { return fmt.Errorf("already unlocked") } @@ -327,17 +323,17 @@ func (backend *DSBackend) UnLockWallet(ctx context.Context, password []byte) err return err } + backend.lk.Lock() backend.unLocked[addr] = ki + backend.lk.Unlock() } - backend.state = Unlock + backend.state.Store(Unlock) return nil } // SetPassword set password for wallet , and wallet used this password to encrypt private key func (backend *DSBackend) SetPassword(ctx context.Context, password []byte) error { - backend.lk.Lock() - defer backend.lk.Unlock() if backend.password != nil { return ErrRepeatPassword } @@ -347,13 +343,12 @@ func (backend *DSBackend) SetPassword(ctx context.Context, password []byte) erro if err != nil { return err } - + backend.lk.Lock() backend.unLocked[addr] = ki - } - if backend.state == undetermined { - backend.state = Unlock + backend.lk.Unlock() } + backend.state.CompareAndSwap(undetermined, Unlock) backend.setPassword(password) return nil @@ -366,9 +361,7 @@ func (backend *DSBackend) HasPassword() bool { // WalletState return wallet state(lock/unlock) func (backend *DSBackend) WalletState(ctx context.Context) int { - backend.lk.Lock() - defer backend.lk.Unlock() - return backend.state + return int(backend.state.Load()) } func (backend *DSBackend) setPassword(password []byte) {