diff --git a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPool.java b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPool.java index b9e658cc645..5030b4cc6ce 100644 --- a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPool.java +++ b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPool.java @@ -15,6 +15,7 @@ import it.unimi.dsi.fastutil.ints.Int2IntMap; import java.util.Collection; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -26,6 +27,7 @@ import java.util.TreeMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Predicate; +import java.util.stream.Stream; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.tuweni.bytes.Bytes; @@ -57,6 +59,11 @@ public class AggregatingAttestationPool implements SlotEventsChannel { /** The valid attestation retention period is 64 slots in deneb */ static final long ATTESTATION_RETENTION_SLOTS = 64; + static final Comparator ATTESTATION_INCLUSION_COMPARATOR = + Comparator.comparingInt( + attestation -> attestation.getAggregationBits().getBitCount()) + .reversed(); + /** * Default maximum number of attestations to store in the pool. * @@ -238,22 +245,20 @@ public synchronized SszList getAttestationsForBlock( schemaDefinitions.getAttestationSchema().requiresCommitteeBits(); final AtomicInteger prevEpochCount = new AtomicInteger(0); + return dataHashBySlot // We can immediately skip any attestations from the block slot or later .headMap(stateAtBlockSlot.getSlot(), false) .descendingMap() .values() .stream() - .flatMap(Collection::stream) - .map(attestationGroupByDataHash::get) - .filter(Objects::nonNull) - .filter(group -> isValid(stateAtBlockSlot, group.getAttestationData())) - .filter(forkChecker::areAttestationsFromCorrectFork) - .flatMap(MatchingDataAttestationGroup::stream) - .map(ValidatableAttestation::getAttestation) - .filter( - attestation -> - attestation.requiresCommitteeBits() == blockRequiresAttestationsWithCommitteeBits) + .flatMap( + dataHashSetForSlot -> + streamAggregatesForDataHashesBySlot( + dataHashSetForSlot, + stateAtBlockSlot, + forkChecker, + blockRequiresAttestationsWithCommitteeBits)) .limit(attestationsSchema.getMaxLength()) .filter( attestation -> { @@ -267,6 +272,25 @@ public synchronized SszList getAttestationsForBlock( .collect(attestationsSchema.collector()); } + private Stream streamAggregatesForDataHashesBySlot( + final Set dataHashSetForSlot, + final BeaconState stateAtBlockSlot, + final AttestationForkChecker forkChecker, + final boolean blockRequiresAttestationsWithCommitteeBits) { + + return dataHashSetForSlot.stream() + .map(attestationGroupByDataHash::get) + .filter(Objects::nonNull) + .filter(group -> isValid(stateAtBlockSlot, group.getAttestationData())) + .filter(forkChecker::areAttestationsFromCorrectFork) + .flatMap(MatchingDataAttestationGroup::stream) + .map(ValidatableAttestation::getAttestation) + .filter( + attestation -> + attestation.requiresCommitteeBits() == blockRequiresAttestationsWithCommitteeBits) + .sorted(ATTESTATION_INCLUSION_COMPARATOR); + } + public synchronized List getAttestations( final Optional maybeSlot, final Optional maybeCommitteeIndex) { diff --git a/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPoolTest.java b/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPoolTest.java index 22d2277ff45..40588f52c70 100644 --- a/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPoolTest.java +++ b/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/AggregatingAttestationPoolTest.java @@ -63,7 +63,7 @@ class AggregatingAttestationPoolTest { public static final UInt64 SLOT = UInt64.valueOf(1234); - private static final int COMMITTEE_SIZE = 20; + private static final int COMMITTEE_SIZE = 130; private Spec spec; private SpecMilestone specMilestone; @@ -260,15 +260,61 @@ void getAttestationsForBlock_shouldIncludeMoreRecentAttestationsFirst() { @TestTemplate public void getAttestationsForBlock_shouldNotAddMoreAttestationsThanAllowedInBlock() { - final BeaconState state = dataStructureUtil.randomBeaconState(ONE); + final int allowed = + Math.toIntExact( + spec.atSlot(ONE) + .getSchemaDefinitions() + .getBeaconBlockBodySchema() + .getAttestationsSchema() + .getMaxLength()); + + final int validatorCount = allowed + 1; + final BeaconState state = dataStructureUtil.randomBeaconState(validatorCount, 100, ONE); final AttestationData attestationData = dataStructureUtil.randomAttestationData(ZERO); - final Attestation attestation1 = addAttestationFromValidators(attestationData, 1, 2, 3, 4); - final Attestation attestation2 = addAttestationFromValidators(attestationData, 2, 5); - // Won't be included because of the 2 attestation limit. - addAttestationFromValidators(attestationData, 2); - assertThat(aggregatingPool.getAttestationsForBlock(state, forkChecker)) - .containsExactly(attestation1, attestation2); + final int lastValidatorIndex = validatorCount - 1; + + // add non aggregatable attestations, more than allowed in block + for (int i = 0; i < validatorCount; i++) { + addAttestationFromValidators(attestationData, i, lastValidatorIndex); + } + + assertThat(aggregatingPool.getAttestationsForBlock(state, forkChecker)).hasSize(allowed); + } + + @TestTemplate + public void getAttestationsForBlock_shouldGivePriorityToBestAggregationForEachSlot() { + // let's test this on electra only, which has only 8 attestations for block + assumeThat(specMilestone).isGreaterThanOrEqualTo(ELECTRA); + assertThat( + spec.atSlot(ONE) + .getSchemaDefinitions() + .getBeaconBlockBodySchema() + .getAttestationsSchema() + .getMaxLength()) + .isEqualTo(8); + + final BeaconState state = dataStructureUtil.randomBeaconState(ONE); + + // let's prepare 2 different attestationData for the same slot + final AttestationData attestationData0 = dataStructureUtil.randomAttestationData(ZERO); + final AttestationData attestationData1 = dataStructureUtil.randomAttestationData(ZERO); + + // let's fill up the pool with non-aggregatable attestationsData0 + addAttestationFromValidators(attestationData0, 1, 2); + addAttestationFromValidators(attestationData0, 1, 3); + addAttestationFromValidators(attestationData0, 1, 4); + addAttestationFromValidators(attestationData0, 1, 5); + addAttestationFromValidators(attestationData0, 1, 6); + addAttestationFromValidators(attestationData0, 1, 7); + addAttestationFromValidators(attestationData0, 1, 8); + addAttestationFromValidators(attestationData0, 1, 9); + + // let's add a better aggregation for attestationData1 + final Attestation bestAttestation = addAttestationFromValidators(attestationData1, 11, 14, 15); + + assertThat(aggregatingPool.getAttestationsForBlock(state, forkChecker).get(0)) + .isEqualTo(bestAttestation); } @TestTemplate