diff --git a/consensus/consensus.go b/consensus/consensus.go index 115a28a76e..db43b3db58 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -68,7 +68,7 @@ type ChainHeaderReader interface { } type VotePool interface { - FetchVoteByBlockHash(blockHash common.Hash) []*types.VoteEnvelope + FetchVotesByBlockHash(blockHash common.Hash) []*types.VoteEnvelope } // ChainReader defines a small collection of methods needed to access the local diff --git a/consensus/parlia/parlia.go b/consensus/parlia/parlia.go index d9166568af..c5e6e49f7c 100644 --- a/consensus/parlia/parlia.go +++ b/consensus/parlia/parlia.go @@ -85,6 +85,8 @@ const ( // `finalityRewardInterval` should be smaller than `inMemorySnapshots`, otherwise, it will result in excessive computation. finalityRewardInterval = 200 + + kAncestorGenerationDepth = 2 ) var ( @@ -452,8 +454,17 @@ func (p *Parlia) getParent(chain consensus.ChainHeaderReader, header *types.Head return parent, nil } +// trimParents safely removes last element if exists. +func trimParents(parents []*types.Header) []*types.Header { + if len(parents) > 1 { + return parents[:len(parents)-1] + } + return nil +} + // verifyVoteAttestation checks whether the vote attestation in the header is valid. func (p *Parlia) verifyVoteAttestation(chain consensus.ChainHeaderReader, header *types.Header, parents []*types.Header) error { + // === Step 1: Extract attestation === epochLength, err := p.epochLength(chain, header, parents) if err != nil { return err @@ -471,21 +482,15 @@ func (p *Parlia) verifyVoteAttestation(chain consensus.ChainHeaderReader, header if len(attestation.Extra) > types.MaxAttestationExtraLength { return fmt.Errorf("invalid attestation, too large extra length: %d", len(attestation.Extra)) } + if attestation.Data.SourceNumber >= attestation.Data.TargetNumber { + return errors.New("invalid attestation, SourceNumber not lower than TargetNumber") + } - // Get parent block + // === Step 2: Verify source block === parent, err := p.getParent(chain, header, parents) if err != nil { return err } - - // The target block should be direct parent. - targetNumber := attestation.Data.TargetNumber - targetHash := attestation.Data.TargetHash - if targetNumber != parent.Number.Uint64() || targetHash != parent.Hash() { - return fmt.Errorf("invalid attestation, target mismatch, expected block: %d, hash: %s; real block: %d, hash: %s", - parent.Number.Uint64(), parent.Hash(), targetNumber, targetHash) - } - // The source block should be the highest justified block. sourceNumber := attestation.Data.SourceNumber sourceHash := attestation.Data.SourceHash @@ -502,17 +507,34 @@ func (p *Parlia) verifyVoteAttestation(chain consensus.ChainHeaderReader, header justifiedBlockNumber, justifiedBlockHash, sourceNumber, sourceHash) } - // The snapshot should be the targetNumber-1 block's snapshot. - if len(parents) > 1 { - parents = parents[:len(parents)-1] - } else { - parents = nil + // === Step 3: Verify target block === + targetNumber := attestation.Data.TargetNumber + targetHash := attestation.Data.TargetHash + match := false + ancestor := parent + ancestorParents := trimParents(parents) + for range p.GetAncestorGenerationDepth(header) { + if targetNumber == ancestor.Number.Uint64() && targetHash == ancestor.Hash() { + match = true + break + } + + ancestor, err = p.getParent(chain, ancestor, ancestorParents) + if err != nil { + return err + } + ancestorParents = trimParents(ancestorParents) } - snap, err := p.snapshot(chain, parent.Number.Uint64()-1, parent.ParentHash, parents) + if !match { + return fmt.Errorf("invalid attestation, target mismatch, real block: %d, hash: %s", targetNumber, targetHash) + } + + // === Step 4: Check quorum === + // The snapshot should be the targetNumber-1 block's snapshot. + snap, err := p.snapshot(chain, ancestor.Number.Uint64()-1, ancestor.ParentHash, ancestorParents) if err != nil { return err } - // Filter out valid validator from attestation. validators := snap.validators() validatorsBitSet := bitset.From([]uint64{uint64(attestation.VoteAddressSet)}) @@ -531,13 +553,12 @@ func (p *Parlia) verifyVoteAttestation(chain consensus.ChainHeaderReader, header } votedAddrs = append(votedAddrs, voteAddr) } - // The valid voted validators should be no less than 2/3 validators. if len(votedAddrs) < cmath.CeilDiv(len(snap.Validators)*2, 3) { return errors.New("invalid attestation, not enough validators voted") } - // Verify the aggregated signature. + // === Step 5: Signature verification === aggSig, err := bls.SignatureFromBytes(attestation.AggSignature[:]) if err != nil { return fmt.Errorf("BLS signature converts failed: %v", err) @@ -1025,44 +1046,61 @@ func (p *Parlia) prepareTurnLength(chain consensus.ChainHeaderReader, header *ty return nil } +// assembleVoteAttestation collects votes and assembles the vote attestation into the block header. func (p *Parlia) assembleVoteAttestation(chain consensus.ChainHeaderReader, header *types.Header) error { - if !p.chainConfig.IsLuban(header.Number) || header.Number.Uint64() < 2 { - return nil - } - - if p.VotePool == nil { + // === Step 1: Preconditions === + if !p.chainConfig.IsLuban(header.Number) || header.Number.Uint64() < 3 || p.VotePool == nil { return nil } - // Fetch direct parent's votes + // === Step 2: Find target header with quorum votes === parent := chain.GetHeaderByHash(header.ParentHash) if parent == nil { return errors.New("parent not found") } - snap, err := p.snapshot(chain, parent.Number.Uint64()-1, parent.ParentHash, nil) + justifiedBlockNumber, justifiedBlockHash, err := p.GetJustifiedNumberAndHash(chain, []*types.Header{parent}) if err != nil { - return err + return errors.New("unexpected error when getting the highest justified number and hash") + } + var ( + votes []*types.VoteEnvelope + targetHeader = parent + targetHeaderParentSnap *Snapshot + ) + for range p.GetAncestorGenerationDepth(header) { + snap, err := p.snapshot(chain, targetHeader.Number.Uint64()-1, targetHeader.ParentHash, nil) + if err != nil { + return err + } + votes = p.VotePool.FetchVotesByBlockHash(targetHeader.Hash()) + quorum := cmath.CeilDiv(len(snap.Validators)*2, 3) + if len(votes) >= quorum { + targetHeaderParentSnap = snap + break + } + + targetHeader = chain.GetHeaderByHash(targetHeader.ParentHash) + if targetHeader == nil { + return errors.New("parent not found") + } + if targetHeader.Number.Uint64() <= justifiedBlockNumber { + break + } } - votes := p.VotePool.FetchVoteByBlockHash(parent.Hash()) - if len(votes) < cmath.CeilDiv(len(snap.Validators)*2, 3) { + if targetHeaderParentSnap == nil { return nil } - // Prepare vote attestation - // Prepare vote data - justifiedBlockNumber, justifiedBlockHash, err := p.GetJustifiedNumberAndHash(chain, []*types.Header{parent}) - if err != nil { - return errors.New("unexpected error when getting the highest justified number and hash") - } + // === Step 3: Build vote attestation === attestation := &types.VoteAttestation{ Data: &types.VoteData{ SourceNumber: justifiedBlockNumber, SourceHash: justifiedBlockHash, - TargetNumber: parent.Number.Uint64(), - TargetHash: parent.Hash(), + TargetNumber: targetHeader.Number.Uint64(), + TargetHash: targetHeader.Hash(), }, } - // Check vote data from votes + // Validate vote data consistency for _, vote := range votes { if vote.Data.Hash() != attestation.Data.Hash() { return fmt.Errorf("vote check error, expected: %v, real: %v", attestation.Data, vote) @@ -1070,10 +1108,10 @@ func (p *Parlia) assembleVoteAttestation(chain consensus.ChainHeaderReader, head } // Prepare aggregated vote signature voteAddrSet := make(map[types.BLSPublicKey]struct{}, len(votes)) - signatures := make([][]byte, 0, len(votes)) - for _, vote := range votes { + signatures := make([][]byte, len(votes)) + for i, vote := range votes { voteAddrSet[vote.VoteAddress] = struct{}{} - signatures = append(signatures, vote.Signature[:]) + signatures[i] = vote.Signature[:] } sigs, err := bls.MultipleSignaturesFromBytes(signatures) if err != nil { @@ -1081,28 +1119,25 @@ func (p *Parlia) assembleVoteAttestation(chain consensus.ChainHeaderReader, head } copy(attestation.AggSignature[:], bls.AggregateSignatures(sigs).Marshal()) // Prepare vote address bitset. - for _, valInfo := range snap.Validators { + for _, valInfo := range targetHeaderParentSnap.Validators { if _, ok := voteAddrSet[valInfo.VoteAddress]; ok { attestation.VoteAddressSet |= 1 << (valInfo.Index - 1) // Index is offset by 1 } } - validatorsBitSet := bitset.From([]uint64{uint64(attestation.VoteAddressSet)}) - if validatorsBitSet.Count() < uint(len(signatures)) { - log.Warn(fmt.Sprintf("assembleVoteAttestation, check VoteAddress Set failed, expected:%d, real:%d", len(signatures), validatorsBitSet.Count())) + bitsetCount := bitset.From([]uint64{uint64(attestation.VoteAddressSet)}).Count() + if bitsetCount < uint(len(signatures)) { + log.Warn(fmt.Sprintf("assembleVoteAttestation, check VoteAddress Set failed, expected:%d, real:%d", len(signatures), bitsetCount)) return errors.New("invalid attestation, check VoteAddress Set failed") } - // Append attestation to header extra field. + // === Step 4: Encode & insert into header extra === buf := new(bytes.Buffer) - err = rlp.Encode(buf, attestation) - if err != nil { - return err + if err = rlp.Encode(buf, attestation); err != nil { + return fmt.Errorf("attestation: failed to encode: %w", err) } - - // Insert vote attestation into header extra ahead extra seal. extraSealStart := len(header.Extra) - extraSeal extraSealBytes := header.Extra[extraSealStart:] - header.Extra = append(header.Extra[0:extraSealStart], buf.Bytes()...) + header.Extra = append(header.Extra[:extraSealStart], buf.Bytes()...) header.Extra = append(header.Extra, extraSealBytes...) return nil @@ -1662,7 +1697,7 @@ func (p *Parlia) Delay(chain consensus.ChainReader, header *types.Header, leftOv // The blocking time should be no more than half of period when snap.TurnLength == 1 timeForMining := time.Duration(snap.BlockInterval) * time.Millisecond / 2 if !snap.lastBlockInOneTurn(header.Number.Uint64()) { - timeForMining = time.Duration(snap.BlockInterval) * time.Millisecond * 4 / 5 + timeForMining = time.Duration(snap.BlockInterval) * time.Millisecond } if delay > timeForMining { delay = timeForMining @@ -2371,6 +2406,15 @@ func (p *Parlia) detectNewVersionWithFork(chain consensus.ChainHeaderReader, hea } } +// TODO(Nathan): use kAncestorGenerationDepth directly instead of this func once Fermi hardfork passed +func (p *Parlia) GetAncestorGenerationDepth(header *types.Header) uint64 { + if p.chainConfig.IsFermi(header.Number, header.Time) { + return kAncestorGenerationDepth + } + + return 1 +} + // chain context type chainContext struct { Chain consensus.ChainHeaderReader diff --git a/core/vote/vote_manager.go b/core/vote/vote_manager.go index f231447792..c3db8f0881 100644 --- a/core/vote/vote_manager.go +++ b/core/vote/vote_manager.go @@ -158,10 +158,10 @@ func (voteManager *VoteManager) loop() { if err != nil { log.Debug("failed to get BlockInterval when voting") } - nextBlockMinedTime := time.UnixMilli(int64((curHead.MilliTimestamp() + blockInterval))) - timeForBroadcast := 50 * time.Millisecond // enough to broadcast a vote - if time.Now().Add(timeForBroadcast).After(nextBlockMinedTime) { - log.Warn("too late to vote", "Head.Time(Second)", curHead.Time, "Now(Millisecond)", time.Now().UnixMilli()) + voteAssembledTime := time.UnixMilli(int64((curHead.MilliTimestamp() + p.GetAncestorGenerationDepth(curHead)*blockInterval))) + timeForBroadcast := 50 * time.Millisecond // enough to broadcast a vote in the same region + if time.Now().Add(timeForBroadcast).After(voteAssembledTime) { + log.Warn("too late to vote", "Head.Time(Millisecond)", curHead.MilliTimestamp(), "Now(Millisecond)", time.Now().UnixMilli()) continue } } diff --git a/core/vote/vote_pool.go b/core/vote/vote_pool.go index ee4dace62f..c7a0cf7b3d 100644 --- a/core/vote/vote_pool.go +++ b/core/vote/vote_pool.go @@ -345,7 +345,7 @@ func (pool *VotePool) GetVotes() []*types.VoteEnvelope { return votesRes } -func (pool *VotePool) FetchVoteByBlockHash(blockHash common.Hash) []*types.VoteEnvelope { +func (pool *VotePool) FetchVotesByBlockHash(blockHash common.Hash) []*types.VoteEnvelope { pool.mu.RLock() defer pool.mu.RUnlock() if _, ok := pool.curVotes[blockHash]; ok { diff --git a/eth/handler_test.go b/eth/handler_test.go index 03cbe7803d..a09285443d 100644 --- a/eth/handler_test.go +++ b/eth/handler_test.go @@ -376,7 +376,7 @@ func (t *testVotePool) PutVote(vote *types.VoteEnvelope) { t.voteFeed.Send(core.NewVoteEvent{Vote: vote}) } -func (t *testVotePool) FetchVoteByBlockHash(blockHash common.Hash) []*types.VoteEnvelope { +func (t *testVotePool) FetchVotesByBlockHash(blockHash common.Hash) []*types.VoteEnvelope { panic("implement me") }