diff --git a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregator.java b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregator.java index d51567aa197..423c3744de5 100644 --- a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregator.java +++ b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregator.java @@ -79,11 +79,5 @@ static AttestationBitsAggregator of( boolean requiresCommitteeBits(); /** Creates an independent copy of this instance */ - default AttestationBitsAggregator copy() { - if (requiresCommitteeBits()) { - return new AttestationBitsAggregatorElectra( - getAggregationBits(), getCommitteeBits(), getCommitteesSize()); - } - return new AttestationBitsAggregatorPhase0(getAggregationBits()); - } + AttestationBitsAggregator copy(); } diff --git a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java index 13f5084d61c..a4b7c6d2893 100644 --- a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java +++ b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java @@ -52,6 +52,19 @@ class AttestationBitsAggregatorElectra implements AttestationBitsAggregator { parseAggregationBits(initialAggregationBits, this.committeeBits, this.committeesSize); } + private AttestationBitsAggregatorElectra( + final SszBitlistSchema aggregationBitsSchema, + final SszBitvectorSchema committeeBitsSchema, + final Int2IntMap committeesSize, + final Int2ObjectMap committeeAggregationBitsMap, + final BitSet committeeBits) { + this.aggregationBitsSchema = aggregationBitsSchema; + this.committeeBitsSchema = committeeBitsSchema; + this.committeesSize = Objects.requireNonNull(committeesSize, "committeesSize cannot be null"); + this.committeeBits = committeeBits; + this.committeeAggregationBitsMap = committeeAggregationBitsMap; + } + static AttestationBitsAggregator fromAttestationSchema( final AttestationSchema attestationSchema, final Int2IntMap committeesSize) { final SszBitlist emptyAggregationBits = attestationSchema.createEmptyAggregationBits(); @@ -76,13 +89,14 @@ private static Int2ObjectMap parseAggregationBits( committeeIndex = committeeIndices.nextSetBit(committeeIndex + 1)) { final int committeeSize = committeesSizeMap.getOrDefault(committeeIndex, 0); - if (committeeSize > 0) { - int sliceEnd = - Math.min(currentOffset + committeeSize, aggregationBits.getLastSetBitIndex() + 1); - final BitSet committeeBits = aggregationBits.getAsBitSet(currentOffset, sliceEnd); - result.put(committeeIndex, committeeBits); + if (committeeSize == 0) { + throw new IllegalArgumentException( + "Committee size for committee " + committeeIndex + " not found"); } - currentOffset += committeeSize; // Always advance by the declared committee size + final BitSet committeeBits = + aggregationBits.getAsBitSet(currentOffset, currentOffset + committeeSize); + result.put(committeeIndex, committeeBits); + currentOffset += committeeSize; } return result; } @@ -113,6 +127,16 @@ public void or(final Attestation other) { performMerge(otherCommitteeBits, otherParsedAggregationMap, false); } + private static Int2ObjectMap cloneCommitteeAggregationBitsMap( + final Int2ObjectMap committeeAggregationBitsMap) { + final Int2ObjectMap clonedMap = new Int2ObjectOpenHashMap<>(); + for (final Int2ObjectMap.Entry entry : + committeeAggregationBitsMap.int2ObjectEntrySet()) { + clonedMap.put(entry.getIntKey(), (BitSet) entry.getValue().clone()); + } + return clonedMap; + } + private boolean performMerge( final BitSet otherCommitteeBits, final Int2ObjectMap otherCommitteeAggregationBitsMap, @@ -124,11 +148,7 @@ private boolean performMerge( if (isAggregation) { // If aggregating, we need to work on copies - targetAggregationBitsMap = new Int2ObjectOpenHashMap<>(); - for (final Int2ObjectMap.Entry entry : - this.committeeAggregationBitsMap.int2ObjectEntrySet()) { - targetAggregationBitsMap.put(entry.getIntKey(), (BitSet) entry.getValue().clone()); - } + targetAggregationBitsMap = cloneCommitteeAggregationBitsMap(this.committeeAggregationBitsMap); } else { // if not aggregating, we can modify in place targetAggregationBitsMap = this.committeeAggregationBitsMap; @@ -256,7 +276,8 @@ public SszBitlist getAggregationBits() { public SszBitvector getCommitteeBits() { if (cachedCommitteeBits == null) { cachedCommitteeBits = - committeeBitsSchema.wrapBitSet(committeeBitsSchema.getLength(), this.committeeBits); + committeeBitsSchema.wrapBitSet( + committeeBitsSchema.getLength(), (BitSet) this.committeeBits.clone()); } return cachedCommitteeBits; } @@ -276,6 +297,16 @@ public boolean requiresCommitteeBits() { return true; } + @Override + public AttestationBitsAggregator copy() { + return new AttestationBitsAggregatorElectra( + aggregationBitsSchema, + committeeBitsSchema, + committeesSize, + cloneCommitteeAggregationBitsMap(committeeAggregationBitsMap), + (BitSet) committeeBits.clone()); + } + @Override public String toString() { long totalSetBits = 0; diff --git a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorPhase0.java b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorPhase0.java index 65acd1e3aef..87b9ec1d135 100644 --- a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorPhase0.java +++ b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorPhase0.java @@ -80,6 +80,11 @@ public boolean requiresCommitteeBits() { return false; } + @Override + public AttestationBitsAggregator copy() { + return new AttestationBitsAggregatorPhase0(aggregationBits); + } + @Override public String toString() { return MoreObjects.toStringHelper(this).add("aggregationBits", aggregationBits).toString(); diff --git a/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java b/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java index 586d159e421..34009aae4f6 100644 --- a/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java +++ b/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java @@ -18,6 +18,7 @@ import static org.mockito.Mockito.when; import it.unimi.dsi.fastutil.ints.Int2IntMap; +import it.unimi.dsi.fastutil.ints.Int2IntMap.Entry; import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; import java.util.List; import java.util.Optional; @@ -643,6 +644,55 @@ void isSuperSetOf6() { assertThat(aggregator.isSuperSetOf(otherAttestation.getAttestation())).isTrue(); } + @Test + void getAggregationBits_shouldBeConsistent_singleCommittee() { + final ValidatableAttestation initialAttestation = createAttestation(List.of(0), 0); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + assertThat(aggregator.getAggregationBits().size()).isEqualTo(committeeSizes.get(0)); + + assertThat(aggregator.getAggregationBits()) + .isEqualTo(initialAttestation.getAttestation().getAggregationBits()); + } + + @Test + void getAggregationBits_shouldBeConsistent_multiCommittee() { + final ValidatableAttestation initialAttestation = createAttestation(List.of(0, 1), 0, 3); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + assertThat(aggregator.getAggregationBits().size()) + .isEqualTo(committeeSizes.get(0) + committeeSizes.get(1)); + + assertThat(aggregator.getAggregationBits()) + .isEqualTo(initialAttestation.getAttestation().getAggregationBits()); + } + + @Test + void copy_shouldNotModifyOriginal() { + final ValidatableAttestation initialAttestation = createAttestation(List.of(0), 0); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + // check aggregator is initialized correctly + assertThat(aggregator.getCommitteeBits()) + .isEqualTo(initialAttestation.getAttestation().getCommitteeBitsRequired()); + assertThat(aggregator.getAggregationBits()) + .isEqualTo(initialAttestation.getAttestation().getAggregationBits()); + + final AttestationBitsAggregator copy = aggregator.copy(); + + assertThat(copy.getCommitteeBits()).isEqualTo(aggregator.getCommitteeBits()); + assertThat(copy.getAggregationBits()).isEqualTo(aggregator.getAggregationBits()); + assertThat(copy).isNotSameAs(aggregator); + + assertThat(copy.aggregateWith(createAttestation(List.of(1), 1).getAttestation())).isTrue(); + + // the original should not be modified + assertThat(aggregator.getCommitteeBits()) + .isEqualTo(initialAttestation.getAttestation().getCommitteeBitsRequired()); + assertThat(aggregator.getAggregationBits()) + .isEqualTo(initialAttestation.getAttestation().getAggregationBits()); + } + private ValidatableAttestation createAttestation(final String commBits, final String aggBits) { assertThat(commBits).matches(Pattern.compile("^[0-1]+$")); assertThat(aggBits).matches(Pattern.compile("^[0-1]+$")); @@ -664,7 +714,12 @@ private ValidatableAttestation createAttestation( final SszBitlist aggregationBits = attestationSchema .getAggregationBitsSchema() - .ofBits(committeeSizes.values().intStream().sum(), validators); + .ofBits( + committeeSizes.int2IntEntrySet().stream() + .filter(entry -> committeeIndices.contains(entry.getIntKey())) + .mapToInt(Entry::getIntValue) + .sum(), + validators); final Supplier committeeBits = () -> attestationSchema.getCommitteeBitsSchema().orElseThrow().ofBits(committeeIndices);