diff --git a/tss/node/tsslib/common/tss.go b/tss/node/tsslib/common/tss.go index 34c558780..00d633043 100644 --- a/tss/node/tsslib/common/tss.go +++ b/tss/node/tsslib/common/tss.go @@ -11,52 +11,58 @@ import ( "github.com/binance-chain/tss-lib/tss" "github.com/libp2p/go-libp2p/core/peer" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/mantlenetworkio/mantle/l2geth/crypto" abnormal2 "github.com/mantlenetworkio/mantle/tss/node/tsslib/abnormal" "github.com/mantlenetworkio/mantle/tss/node/tsslib/conversion" "github.com/mantlenetworkio/mantle/tss/node/tsslib/messages" "github.com/mantlenetworkio/mantle/tss/node/tsslib/p2p" - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" ) type TssCommon struct { - conf TssConfig - logger zerolog.Logger - partyLock *sync.Mutex - partyInfo *abnormal2.PartyInfo - PartyIDtoP2PID map[string]peer.ID - unConfirmedMsgLock *sync.Mutex - unConfirmedMessages map[string]*LocalCacheItem - localPeerID string - broadcastChannel chan *messages.BroadcastMsgChan - TssMsg chan *p2p.Message - P2PPeersLock *sync.RWMutex - P2PPeers []peer.ID // most of tss message are broadcast, we store the peers ID to avoid iterating - msgID string - privateKey *ecdsa.PrivateKey - taskDone chan struct{} - abnormalMgr *abnormal2.Manager - finishedPeers map[string]bool - culprits []*tss.PartyID - culpritsLock *sync.RWMutex + conf TssConfig + logger zerolog.Logger + + partyLock sync.RWMutex + partyInfo *abnormal2.PartyInfo + + partyIDtoP2PIDMap *sync.Map // map[string]peer.ID + unConfirmedMessagesMap *sync.Map // map[string]*LocalCacheItem + + localPeerID string + broadcastChannel chan *messages.BroadcastMsgChan + TssMsg chan *p2p.Message + + P2PPeersLock sync.RWMutex + P2PPeers []peer.ID + + msgID string + privateKey *ecdsa.PrivateKey + taskDone chan struct{} + abnormalMgr *abnormal2.Manager + finishedPeers map[string]bool + + culpritsLock sync.RWMutex + culprits []*tss.PartyID + cachedWireBroadcastMsgLists *sync.Map cachedWireUnicastMsgLists *sync.Map - thresHold int + threshHold int } func NewTssCommon(peerID string, broadcastChannel chan *messages.BroadcastMsgChan, conf TssConfig, msgID string, privKey *ecdsa.PrivateKey, thresHold int) *TssCommon { return &TssCommon{ conf: conf, logger: log.With().Str("module", "tsscommon").Logger(), - partyLock: &sync.Mutex{}, + partyLock: sync.RWMutex{}, partyInfo: nil, - PartyIDtoP2PID: make(map[string]peer.ID), - unConfirmedMsgLock: &sync.Mutex{}, - unConfirmedMessages: make(map[string]*LocalCacheItem), + partyIDtoP2PIDMap: &sync.Map{}, + unConfirmedMessagesMap: &sync.Map{}, broadcastChannel: broadcastChannel, TssMsg: make(chan *p2p.Message), - P2PPeersLock: &sync.RWMutex{}, + P2PPeersLock: sync.RWMutex{}, P2PPeers: nil, msgID: msgID, localPeerID: peerID, @@ -64,10 +70,10 @@ func NewTssCommon(peerID string, broadcastChannel chan *messages.BroadcastMsgCha taskDone: make(chan struct{}), abnormalMgr: abnormal2.NewAbnormalManager(), finishedPeers: make(map[string]bool), - culpritsLock: &sync.RWMutex{}, + culpritsLock: sync.RWMutex{}, cachedWireBroadcastMsgLists: &sync.Map{}, cachedWireUnicastMsgLists: &sync.Map{}, - thresHold: thresHold, + threshHold: thresHold, } } @@ -147,8 +153,23 @@ func (t *TssCommon) GetTaskDone() chan struct{} { return t.taskDone } -func (t *TssCommon) GetThresHold() int { - return t.thresHold +func (t *TssCommon) GetThreshHold() int { + return t.threshHold +} + +func (t *TssCommon) InsertPartyIDtoP2PID(newMap map[string]peer.ID) { + for k, v := range newMap { + t.partyIDtoP2PIDMap.Store(k, v) + } +} + +func (t *TssCommon) GetPartyIDtoP2PID() map[string]peer.ID { + result := make(map[string]peer.ID) + t.partyIDtoP2PIDMap.Range(func(key, value any) bool { + result[key.(string)] = value.(peer.ID) + return true + }) + return result } func (t *TssCommon) GetAbnormalMgr() *abnormal2.Manager { @@ -162,8 +183,8 @@ func (t *TssCommon) SetPartyInfo(partyInfo *abnormal2.PartyInfo) { } func (t *TssCommon) getPartyInfo() *abnormal2.PartyInfo { - t.partyLock.Lock() - defer t.partyLock.Unlock() + t.partyLock.RLock() + defer t.partyLock.RUnlock() return t.partyInfo } @@ -248,14 +269,14 @@ func (t *TssCommon) updateLocal(wireMsg *messages.WireMessage) error { return fmt.Errorf("get message from unknown party %s", partyID.Id) } - dataOwnerPeerID, ok := t.PartyIDtoP2PID[wireMsg.Routing.From.Id] + dataOwnerPeerID, ok := t.partyIDtoP2PIDMap.Load(wireMsg.Routing.From.Id) if !ok { t.logger.Error().Msg("fail to find the peer ID of this party") return errors.New("fail to find the peer") } // here we log down this peer as the latest unicast peer if !wireMsg.Routing.IsBroadcast { - t.abnormalMgr.SetLastUnicastPeer(dataOwnerPeerID, wireMsg.RoundInfo) + t.abnormalMgr.SetLastUnicastPeer(dataOwnerPeerID.(peer.ID), wireMsg.RoundInfo) } var bulkMsg BulkWireMsg @@ -306,13 +327,20 @@ func (t *TssCommon) updateLocal(wireMsg *messages.WireMessage) error { } return false } - t.culpritsLock.RLock() - if len(t.culprits) != 0 && partyInlist(partyID, t.culprits) { - t.logger.Error().Msgf("the malicious party (party ID:%s) try to send incorrect message to me (party ID:%s)", partyID.Id, localMsgParty.PartyID().Id) - t.culpritsLock.RUnlock() - return errors.New("tss share verification failed") + err = func() error { + t.culpritsLock.RLock() + defer t.culpritsLock.RUnlock() + if len(t.culprits) != 0 && partyInlist(partyID, t.culprits) { + return errors.New("tss share verification failed") + } + return nil + }() + if err != nil { + t.logger.Err(err).Msgf("the malicious party (party ID:%s) try to send incorrect message to me (party ID:%s)", + partyID.Id, localMsgParty.PartyID().Id) + return err } - t.culpritsLock.RUnlock() + job := newJob(localMsgParty, bulkMsg.WiredBulkMsg, round.MsgIdentifier, partyID, bulkMsg.Routing.IsBroadcast) tssJobChan <- job } @@ -361,12 +389,12 @@ func (t *TssCommon) sendBulkMsg(wiredMsgType string, tssMsgType messages.TSSMess t.P2PPeersLock.RUnlock() } else { for _, each := range r.To { - peerID, ok := t.PartyIDtoP2PID[each.Id] + peerID, ok := t.partyIDtoP2PIDMap.Load(each.Id) if !ok { t.logger.Error().Msg("error in find the P2P ID") continue } - peerIDs = append(peerIDs, peerID) + peerIDs = append(peerIDs, peerID.(peer.ID)) } } t.renderToP2P(&messages.BroadcastMsgChan{ @@ -437,19 +465,13 @@ func (t *TssCommon) applyShare(localCacheItem *LocalCacheItem, key string, msgTy } t.logger.Debug().Msgf("remove key: %s", key) // the information had been confirmed by all party , we don't need it anymore - t.removeKey(key) + t.unConfirmedMessagesMap.Delete(key) return nil } -func (t *TssCommon) removeKey(key string) { - t.unConfirmedMsgLock.Lock() - defer t.unConfirmedMsgLock.Unlock() - delete(t.unConfirmedMessages, key) -} - func (t *TssCommon) hashCheck(localCacheItem *LocalCacheItem, threshold int) error { dataOwner := localCacheItem.Msg.Routing.From - dataOwnerP2PID, ok := t.PartyIDtoP2PID[dataOwner.Id] + dataOwnerP2PID, ok := t.partyIDtoP2PIDMap.Load(dataOwner.Id) if !ok { t.logger.Warn().Msgf("error in find the data Owner P2PID\n") return errors.New("error in find the data Owner P2PID") @@ -464,7 +486,7 @@ func (t *TssCommon) hashCheck(localCacheItem *LocalCacheItem, threshold int) err targetHashValue := localCacheItem.Hash for P2PID := range localCacheItem.ConfirmedList { - if P2PID == dataOwnerP2PID.String() { + if P2PID == dataOwnerP2PID.(peer.ID).String() { t.logger.Warn().Msgf("we detect that the data owner try to send the hash for his own message\n") delete(localCacheItem.ConfirmedList, P2PID) return abnormal2.ErrHashFromOwner @@ -602,15 +624,15 @@ func (t *TssCommon) processTSSMsg(wireMsg *messages.WireMessage, msgType message return fmt.Errorf("fail to calculate hash of the wire message: %w", err) } localCacheItem := t.TryGetLocalCacheItem(key) - if nil == localCacheItem { - t.logger.Debug().Msgf("++%s doesn't exist yet,add a new one", key) + if localCacheItem == nil { + t.logger.Debug().Msgf("%s doesn't exist yet,add a new one", key) localCacheItem = NewLocalCacheItem(wireMsg, msgHash) - t.updateLocalUnconfirmedMessages(key, localCacheItem) + t.unConfirmedMessagesMap.Store(key, localCacheItem) } else { // this means we received the broadcast confirm message from other party first - t.logger.Debug().Msgf("==%s exist", key) + t.logger.Debug().Msgf("%s exist", key) if localCacheItem.Msg == nil { - t.logger.Debug().Msgf("==%s exist, set message", key) + t.logger.Debug().Msgf("%s exist, set message", key) localCacheItem.Msg = wireMsg localCacheItem.Hash = msgHash } @@ -624,17 +646,9 @@ func (t *TssCommon) processTSSMsg(wireMsg *messages.WireMessage, msgType message } func (t *TssCommon) TryGetLocalCacheItem(key string) *LocalCacheItem { - t.unConfirmedMsgLock.Lock() - defer t.unConfirmedMsgLock.Unlock() - localCacheItem, ok := t.unConfirmedMessages[key] + localCacheItem, ok := t.unConfirmedMessagesMap.Load(key) if !ok { return nil } - return localCacheItem -} - -func (t *TssCommon) updateLocalUnconfirmedMessages(key string, cacheItem *LocalCacheItem) { - t.unConfirmedMsgLock.Lock() - defer t.unConfirmedMsgLock.Unlock() - t.unConfirmedMessages[key] = cacheItem + return localCacheItem.(*LocalCacheItem) } diff --git a/tss/node/tsslib/conversion/conversion.go b/tss/node/tsslib/conversion/conversion.go index cdf54de38..57fe73d0c 100644 --- a/tss/node/tsslib/conversion/conversion.go +++ b/tss/node/tsslib/conversion/conversion.go @@ -5,15 +5,16 @@ import ( "encoding/hex" "errors" "fmt" + "math/big" + "sort" + "strconv" + "github.com/binance-chain/tss-lib/crypto" "github.com/binance-chain/tss-lib/tss" "github.com/btcsuite/btcd/btcec" ethcrypto "github.com/ethereum/go-ethereum/crypto" crypto2 "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" - "math/big" - "sort" - "strconv" ) func GetParties(keys []string, localPartyKey string) ([]*tss.PartyID, *tss.PartyID, error) { @@ -52,15 +53,16 @@ func SetupPartyIDMap(partiesID []*tss.PartyID) map[string]*tss.PartyID { return partyIDMap } -func SetupIDMaps(parties map[string]*tss.PartyID, partyIDtoP2PID map[string]peer.ID) error { +func GeneratePartyIDtoP2PIDMaps(parties map[string]*tss.PartyID) (map[string]peer.ID, error) { + partyIDtoP2PID := make(map[string]peer.ID) for id, party := range parties { peerID, err := GetPeerIDFromPartyID(party) if err != nil { - return err + return nil, err } partyIDtoP2PID[id] = peerID } - return nil + return partyIDtoP2PID, nil } func GetPeerIDFromPartyID(partyID *tss.PartyID) (peer.ID, error) { @@ -107,10 +109,10 @@ func BytesToHashString(msg []byte) (string, error) { return hex.EncodeToString(h.Sum(nil)), nil } -func GetTssPubKey(pubKeyPoint *crypto.ECPoint) (string, []byte,[]byte, error) { +func GetTssPubKey(pubKeyPoint *crypto.ECPoint) (string, []byte, []byte, error) { // we check whether the point is on curve according to Kudelski report if pubKeyPoint == nil || !isOnCurve(pubKeyPoint.X(), pubKeyPoint.Y()) { - return "", nil,nil, errors.New("invalid points") + return "", nil, nil, errors.New("invalid points") } tssPubKey := btcec.PublicKey{ Curve: btcec.S256(), @@ -123,7 +125,7 @@ func GetTssPubKey(pubKeyPoint *crypto.ECPoint) (string, []byte,[]byte, error) { pubKeyBytes := tssPubKey.SerializeUncompressed() pubKeyBytes = pubKeyBytes[1:] - return pubKeyStr, address,pubKeyBytes, nil + return pubKeyStr, address, pubKeyBytes, nil } func PartyIDtoPubKey(party *tss.PartyID) (string, error) { diff --git a/tss/node/tsslib/keygen/tss_keygen.go b/tss/node/tsslib/keygen/tss_keygen.go index d4ca53e00..aba35333d 100644 --- a/tss/node/tsslib/keygen/tss_keygen.go +++ b/tss/node/tsslib/keygen/tss_keygen.go @@ -90,7 +90,7 @@ func (tKeyGen *TssKeyGen) GenerateNewKey(keygenReq Request) (*bcrypto.ECPoint, e keyGenLocalStateItem := storage2.KeygenLocalState{ ParticipantKeys: tKeyGen.ParticipantKeys, LocalPartyKey: tKeyGen.localNodePubKey, - Threshold: tKeyGen.tssCommonStruct.GetThresHold(), + Threshold: tKeyGen.tssCommonStruct.GetThreshHold(), } if err != nil { @@ -108,11 +108,12 @@ func (tKeyGen *TssKeyGen) GenerateNewKey(keygenReq Request) (*bcrypto.ECPoint, e abnormalMgr := tKeyGen.tssCommonStruct.GetAbnormalMgr() keyGenParty := keygen.NewLocalParty(params, outCh, endCh, *tKeyGen.preParams) partyIDMap := conversion.SetupPartyIDMap(partiesID) - err1 := conversion.SetupIDMaps(partyIDMap, tKeyGen.tssCommonStruct.PartyIDtoP2PID) - if err1 != nil { - tKeyGen.logger.Error().Msgf("error in creating mapping between partyID and P2P ID") + partyIDtoP2PIDMaps, err := conversion.GeneratePartyIDtoP2PIDMaps(partyIDMap) + if err != nil { + tKeyGen.logger.Err(err).Msgf("error in creating mapping between partyID and P2P ID") return nil, err } + tKeyGen.tssCommonStruct.InsertPartyIDtoP2PID(partyIDtoP2PIDMaps) // we never run multi keygen, so the moniker is set to default empty value partyInfo := &abnormal.PartyInfo{ Party: keyGenParty, @@ -122,7 +123,7 @@ func (tKeyGen *TssKeyGen) GenerateNewKey(keygenReq Request) (*bcrypto.ECPoint, e tKeyGen.tssCommonStruct.SetPartyInfo(partyInfo) abnormalMgr.SetPartyInfo(keyGenParty, partyIDMap) tKeyGen.tssCommonStruct.P2PPeersLock.Lock() - tKeyGen.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeyGen.tssCommonStruct.PartyIDtoP2PID, tKeyGen.tssCommonStruct.GetLocalPeerID()) + tKeyGen.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeyGen.tssCommonStruct.GetPartyIDtoP2PID(), tKeyGen.tssCommonStruct.GetLocalPeerID()) tKeyGen.tssCommonStruct.P2PPeersLock.Unlock() var keyGenWg sync.WaitGroup keyGenWg.Add(2) diff --git a/tss/node/tsslib/keysign/tss_keysign.go b/tss/node/tsslib/keysign/tss_keysign.go index 8060d1ac8..8bbff6827 100644 --- a/tss/node/tsslib/keysign/tss_keysign.go +++ b/tss/node/tsslib/keysign/tss_keysign.go @@ -90,21 +90,18 @@ func (tKeySign *TssKeySign) SignMessage(msgToSign []byte, localStateItem storage tKeySign.logger.Info().Msgf("message: (%s) keysign parties: %+v", m.String(), parties) eachLocalPartyID.Moniker = moniker tKeySign.localParties = nil - params := tss.NewParameters(btcec.S256(), ctx, eachLocalPartyID, len(partiesID), tKeySign.GetTssCommonStruct().GetThresHold()) + params := tss.NewParameters(btcec.S256(), ctx, eachLocalPartyID, len(partiesID), tKeySign.GetTssCommonStruct().GetThreshHold()) keySignParty := signing.NewLocalParty(m, params, localStateItem.LocalData, outCh, endCh) abnormalMgr := tKeySign.tssCommonStruct.GetAbnormalMgr() partyIDMap := conversion.SetupPartyIDMap(partiesID) - err = conversion.SetupIDMaps(partyIDMap, tKeySign.tssCommonStruct.PartyIDtoP2PID) + partyIDtoP2PIDMap, err := conversion.GeneratePartyIDtoP2PIDMaps(partyIDMap) if err != nil { - tKeySign.logger.Error().Err(err).Msgf("error in creating mapping between partyID and P2P ID") - return nil, err - } - err = conversion.SetupIDMaps(partyIDMap, abnormalMgr.PartyIDtoP2PID) - if err != nil { - tKeySign.logger.Error().Err(err).Msgf("error in creating mapping between partyID and P2P ID") + tKeySign.logger.Err(err).Msgf("error in creating mapping between partyID and P2P ID") return nil, err } + tKeySign.tssCommonStruct.InsertPartyIDtoP2PID(partyIDtoP2PIDMap) + abnormalMgr.PartyIDtoP2PID = partyIDtoP2PIDMap tKeySign.tssCommonStruct.SetPartyInfo(&abnormal.PartyInfo{ Party: keySignParty, @@ -114,7 +111,7 @@ func (tKeySign *TssKeySign) SignMessage(msgToSign []byte, localStateItem storage abnormalMgr.SetPartyInfo(keySignParty, partyIDMap) tKeySign.tssCommonStruct.P2PPeersLock.Lock() - tKeySign.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeySign.tssCommonStruct.PartyIDtoP2PID, tKeySign.tssCommonStruct.GetLocalPeerID()) + tKeySign.tssCommonStruct.P2PPeers = conversion.GetPeersID(tKeySign.tssCommonStruct.GetPartyIDtoP2PID(), tKeySign.tssCommonStruct.GetLocalPeerID()) tKeySign.tssCommonStruct.P2PPeersLock.Unlock() var keySignWg sync.WaitGroup keySignWg.Add(2)