diff --git a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java index 361b12caa40..13f5084d61c 100644 --- a/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java +++ b/ethereum/statetransition/src/main/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectra.java @@ -15,229 +15,255 @@ import com.google.common.base.MoreObjects; import it.unimi.dsi.fastutil.ints.Int2IntMap; -import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; +import java.util.ArrayList; import java.util.BitSet; -import java.util.stream.IntStream; +import java.util.List; +import java.util.Objects; import tech.pegasys.teku.infrastructure.ssz.collections.SszBitlist; import tech.pegasys.teku.infrastructure.ssz.collections.SszBitvector; +import tech.pegasys.teku.infrastructure.ssz.schema.collections.SszBitlistSchema; +import tech.pegasys.teku.infrastructure.ssz.schema.collections.SszBitvectorSchema; import tech.pegasys.teku.spec.datastructures.operations.Attestation; import tech.pegasys.teku.spec.datastructures.operations.AttestationSchema; class AttestationBitsAggregatorElectra implements AttestationBitsAggregator { - private SszBitlist aggregationBits; - private SszBitvector committeeBits; - private Int2IntMap committeeBitsStartingPositions; + + private final SszBitlistSchema aggregationBitsSchema; + private final SszBitvectorSchema committeeBitsSchema; private final Int2IntMap committeesSize; + private Int2ObjectMap committeeAggregationBitsMap; + private BitSet committeeBits; + + private SszBitlist cachedAggregationBits = null; + private SszBitvector cachedCommitteeBits = null; + AttestationBitsAggregatorElectra( - final SszBitlist aggregationBits, - final SszBitvector committeeBits, + final SszBitlist initialAggregationBits, + final SszBitvector initialCommitteeBits, final Int2IntMap committeesSize) { - this.aggregationBits = aggregationBits; - this.committeeBits = committeeBits; - this.committeesSize = committeesSize; - this.committeeBitsStartingPositions = calculateCommitteeStartingPositions(committeeBits); + this.aggregationBitsSchema = initialAggregationBits.getSchema(); + this.committeeBitsSchema = initialCommitteeBits.getSchema(); + this.committeesSize = Objects.requireNonNull(committeesSize, "committeesSize cannot be null"); + this.committeeBits = initialCommitteeBits.getAsBitSet(); + this.committeeAggregationBitsMap = + parseAggregationBits(initialAggregationBits, this.committeeBits, this.committeesSize); } static AttestationBitsAggregator fromAttestationSchema( final AttestationSchema attestationSchema, final Int2IntMap committeesSize) { + final SszBitlist emptyAggregationBits = attestationSchema.createEmptyAggregationBits(); + final SszBitvector emptyCommitteeBits = + attestationSchema + .createEmptyCommitteeBits() + .orElseThrow( + () -> new IllegalStateException("Electra schema must provide committee bits")); return new AttestationBitsAggregatorElectra( - attestationSchema.createEmptyAggregationBits(), - attestationSchema.createEmptyCommitteeBits().orElseThrow(), - committeesSize); + emptyAggregationBits, emptyCommitteeBits, committeesSize); + } + + private static Int2ObjectMap parseAggregationBits( + final SszBitlist aggregationBits, + final BitSet committeeIndices, + final Int2IntMap committeesSizeMap) { + final Int2ObjectMap result = new Int2ObjectOpenHashMap<>(); + + int currentOffset = 0; + for (int committeeIndex = committeeIndices.nextSetBit(0); + committeeIndex >= 0; + committeeIndex = committeeIndices.nextSetBit(committeeIndex + 1)) { + + final int committeeSize = committeesSizeMap.getOrDefault(committeeIndex, 0); + if (committeeSize > 0) { + int sliceEnd = + Math.min(currentOffset + committeeSize, aggregationBits.getLastSetBitIndex() + 1); + final BitSet committeeBits = aggregationBits.getAsBitSet(currentOffset, sliceEnd); + result.put(committeeIndex, committeeBits); + } + currentOffset += committeeSize; // Always advance by the declared committee size + } + return result; } @Override public void or(final AttestationBitsAggregator other) { - or(other.getCommitteeBits(), other.getAggregationBits(), false); + if (!(other instanceof AttestationBitsAggregatorElectra otherElectra)) { + throw new IllegalArgumentException( + "AttestationBitsAggregatorElectra.or requires an argument of the same type."); + } + + performMerge(otherElectra.committeeBits, otherElectra.committeeAggregationBitsMap, false); } @Override public boolean aggregateWith(final Attestation other) { - return or(other.getCommitteeBitsRequired(), other.getAggregationBits(), true); + final BitSet otherCommitteeBits = other.getCommitteeBitsRequired().getAsBitSet(); + final Int2ObjectMap otherParsedAggregationMap = + parseAggregationBits(other.getAggregationBits(), otherCommitteeBits, this.committeesSize); + return performMerge(otherCommitteeBits, otherParsedAggregationMap, true); } @Override public void or(final Attestation other) { - or(other.getCommitteeBitsRequired(), other.getAggregationBits(), false); + final BitSet otherCommitteeBits = other.getCommitteeBitsRequired().getAsBitSet(); + final Int2ObjectMap otherParsedAggregationMap = + parseAggregationBits(other.getAggregationBits(), otherCommitteeBits, this.committeesSize); + performMerge(otherCommitteeBits, otherParsedAggregationMap, false); } - private static class CannotAggregateException extends RuntimeException {} - - private boolean or( - final SszBitvector otherCommitteeBits, - final SszBitlist otherAggregatedBits, + private boolean performMerge( + final BitSet otherCommitteeBits, + final Int2ObjectMap otherCommitteeAggregationBitsMap, final boolean isAggregation) { + final BitSet mergedCommitteeBits = (BitSet) this.committeeBits.clone(); + mergedCommitteeBits.or(otherCommitteeBits); - if (otherCommitteeBits.equals(committeeBits)) { - // If the committee bits are the same, we can directly combine the aggregation bits - if (isAggregation && aggregationBits.intersects(otherAggregatedBits)) { - return false; + final Int2ObjectMap targetAggregationBitsMap; + + if (isAggregation) { + // If aggregating, we need to work on copies + targetAggregationBitsMap = new Int2ObjectOpenHashMap<>(); + for (final Int2ObjectMap.Entry entry : + this.committeeAggregationBitsMap.int2ObjectEntrySet()) { + targetAggregationBitsMap.put(entry.getIntKey(), (BitSet) entry.getValue().clone()); } - aggregationBits = aggregationBits.or(otherAggregatedBits); - return true; + } else { + // if not aggregating, we can modify in place + targetAggregationBitsMap = this.committeeAggregationBitsMap; } - final SszBitvector combinedCommitteeBits = committeeBits.or(otherCommitteeBits); - - final Int2IntMap otherCommitteeBitsStartingPositions = - calculateCommitteeStartingPositions(otherCommitteeBits); - final Int2IntMap aggregatedCommitteeBitsStartingPositions = - calculateCommitteeStartingPositions(combinedCommitteeBits); - - // create an aggregation bit big as last boundary for last committee bit - final int lastCommitteeIndex = combinedCommitteeBits.getLastSetBitIndex(); - final int lastCommitteeStartingPosition = - aggregatedCommitteeBitsStartingPositions.get(lastCommitteeIndex); - final int combinedAggregationBitsSize = - lastCommitteeStartingPosition + committeesSize.get(lastCommitteeIndex); - - final BitSet combinedAggregationIndices = new BitSet(combinedAggregationBitsSize); - - // let's go over all aggregated committees to calculate indices for the combined aggregation - // bits - try { - combinedCommitteeBits - .streamAllSetBits() - .forEach( - committeeIndex -> { - int committeeSize = committeesSize.get(committeeIndex); - int destinationStart = aggregatedCommitteeBitsStartingPositions.get(committeeIndex); - - SszBitlist source1 = null, maybeSource2 = null; - int source1StartingPosition = 0, source2StartingPosition = 0; - - if (committeeBitsStartingPositions.containsKey(committeeIndex)) { - source1 = aggregationBits; - source1StartingPosition = committeeBitsStartingPositions.get(committeeIndex); - } - if (otherCommitteeBitsStartingPositions.containsKey(committeeIndex)) { - if (source1 != null) { - maybeSource2 = otherAggregatedBits; - source2StartingPosition = - otherCommitteeBitsStartingPositions.get(committeeIndex); - } else { - source1 = otherAggregatedBits; - source1StartingPosition = - otherCommitteeBitsStartingPositions.get(committeeIndex); - } - } - - // Now that we know: - // 1. which aggregationBits (this or other or both) will contribute to the result - // 2. the offset of the committee for each contributing aggregation bits - // We can go over the committee and calculate the combined aggregate bits - for (int positionInCommittee = 0; - positionInCommittee < committeeSize; - positionInCommittee++) { - if (orSingleBit( - positionInCommittee, - source1, - source1StartingPosition, - maybeSource2, - source2StartingPosition, - isAggregation)) { - combinedAggregationIndices.set(destinationStart + positionInCommittee); - } - } - }); - } catch (final CannotAggregateException __) { - return false; - } + for (int committeeIndex = mergedCommitteeBits.nextSetBit(0); + committeeIndex >= 0; + committeeIndex = mergedCommitteeBits.nextSetBit(committeeIndex + 1)) { - committeeBits = combinedCommitteeBits; - aggregationBits = - aggregationBits - .getSchema() - .wrapBitSet(combinedAggregationBitsSize, combinedAggregationIndices); - committeeBitsStartingPositions = aggregatedCommitteeBitsStartingPositions; + final boolean inThis = this.committeeBits.get(committeeIndex); + final boolean inOther = otherCommitteeBits.get(committeeIndex); - return true; - } + if (inThis && inOther) { + final BitSet otherAggregationBitsForCommittee = + otherCommitteeAggregationBitsMap.get(committeeIndex); + final BitSet targetAggregationBitsForCommittee = + targetAggregationBitsMap.get(committeeIndex); - private boolean orSingleBit( - final int positionInCommittee, - final SszBitlist source1, - final int source1StartingPosition, - final SszBitlist maybeSource2, - final int source2StartingPosition, - final boolean isAggregation) { - - final boolean source1Bit = source1.getBit(source1StartingPosition + positionInCommittee); + if (isAggregation) { + // For intersection check, use the original bits of 'this' + final BitSet thisAggregationBitsForCommittee = + this.committeeAggregationBitsMap.get(committeeIndex); + if (thisAggregationBitsForCommittee != null + && thisAggregationBitsForCommittee.intersects(otherAggregationBitsForCommittee)) { + return false; + } + } - if (maybeSource2 == null) { - return source1Bit; - } + targetAggregationBitsForCommittee.or(otherAggregationBitsForCommittee); - final boolean source2Bit = maybeSource2.getBit(source2StartingPosition + positionInCommittee); + } else if (inOther) { + // Committee only in 'other'. + final BitSet otherDataForCommittee = otherCommitteeAggregationBitsMap.get(committeeIndex); - if (isAggregation && source1Bit && source2Bit) { - throw new CannotAggregateException(); + targetAggregationBitsMap.put(committeeIndex, (BitSet) otherDataForCommittee.clone()); + } + // Committee only in 'this', do nothing. } - return source1Bit || source2Bit; - } + this.committeeBits = mergedCommitteeBits; + if (isAggregation) { + this.committeeAggregationBitsMap = targetAggregationBitsMap; + } - private Int2IntMap calculateCommitteeStartingPositions(final SszBitvector committeeBits) { - final Int2IntMap committeeBitsStartingPositions = new Int2IntOpenHashMap(); - final int[] currentOffset = {0}; - committeeBits - .streamAllSetBits() - .forEach( - index -> { - committeeBitsStartingPositions.put(index, currentOffset[0]); - currentOffset[0] += committeesSize.get(index); - }); - - return committeeBitsStartingPositions; + invalidateCache(); + return true; } @Override public boolean isSuperSetOf(final Attestation other) { - if (committeeBits.equals(other.getCommitteeBitsRequired())) { - return aggregationBits.isSuperSetOf(other.getAggregationBits()); - } + final BitSet otherInternalCommitteeBits = other.getCommitteeBitsRequired().getAsBitSet(); - if (!committeeBits.isSuperSetOf(other.getCommitteeBitsRequired())) { + final BitSet committeeIntersection = (BitSet) this.committeeBits.clone(); + committeeIntersection.and(otherInternalCommitteeBits); + if (!committeeIntersection.equals(otherInternalCommitteeBits)) { return false; } - final SszBitvector otherCommitteeBits = other.getCommitteeBitsRequired(); + final Int2ObjectMap otherCommitteeAggregationBitsMap = + parseAggregationBits( + other.getAggregationBits(), otherInternalCommitteeBits, this.committeesSize); - final Int2IntMap otherCommitteeBitsStartingPositions = - calculateCommitteeStartingPositions(otherCommitteeBits); + for (int committeeIndex = otherInternalCommitteeBits.nextSetBit(0); + committeeIndex >= 0; + committeeIndex = otherInternalCommitteeBits.nextSetBit(committeeIndex + 1)) { - final SszBitvector commonCommittees = committeeBits.and(otherCommitteeBits); + final BitSet thisAggregationBitsForCommittee = + this.committeeAggregationBitsMap.get(committeeIndex); + final BitSet otherAggregationBitsForCommittee = + otherCommitteeAggregationBitsMap.get(committeeIndex); - return commonCommittees - .streamAllSetBits() - .mapToObj( - committeeIndex -> { - int committeeSize = committeesSize.get(committeeIndex); - - final int startingPosition = committeeBitsStartingPositions.get(committeeIndex); - final int otherStartingPosition = - otherCommitteeBitsStartingPositions.get(committeeIndex); + if (thisAggregationBitsForCommittee == null) { + return false; + } + if (otherAggregationBitsForCommittee == null || otherAggregationBitsForCommittee.isEmpty()) { + continue; + } - return IntStream.range(0, committeeSize) - .anyMatch( - positionInCommittee -> - other - .getAggregationBits() - .getBit(otherStartingPosition + positionInCommittee) - && !aggregationBits.getBit(startingPosition + positionInCommittee)); - }) - .noneMatch(aBitFoundInOtherButNotInThis -> aBitFoundInOtherButNotInThis); + final BitSet otherBitsNotInThis = (BitSet) otherAggregationBitsForCommittee.clone(); + otherBitsNotInThis.andNot(thisAggregationBitsForCommittee); + if (!otherBitsNotInThis.isEmpty()) { + return false; + } + } + return true; } @Override public SszBitlist getAggregationBits() { - return aggregationBits; + if (cachedAggregationBits != null) { + return cachedAggregationBits; + } + final List committeeIndicesInOrder = new ArrayList<>(); + for (int i = committeeBits.nextSetBit(0); i >= 0; i = committeeBits.nextSetBit(i + 1)) { + committeeIndicesInOrder.add(i); + } + // No explicit sort needed as BitSet.nextSetBit() iterates in ascending order. + + int totalBitlistSize = 0; + for (final int committeeIndex : committeeIndicesInOrder) { + totalBitlistSize += this.committeesSize.getOrDefault(committeeIndex, 0); + } + + final BitSet combinedAggregationBits = new BitSet(totalBitlistSize); + int currentOffset = 0; + for (final int committeeIndex : committeeIndicesInOrder) { + final BitSet committeeBitsData = this.committeeAggregationBitsMap.get(committeeIndex); + final int committeeSize = this.committeesSize.getOrDefault(committeeIndex, 0); + + if (committeeBitsData != null && committeeSize > 0) { + for (int bitIndex = committeeBitsData.nextSetBit(0); + bitIndex >= 0 && bitIndex < committeeSize; + bitIndex = committeeBitsData.nextSetBit(bitIndex + 1)) { + combinedAggregationBits.set(currentOffset + bitIndex); + } + } + currentOffset += committeeSize; + } + cachedAggregationBits = + aggregationBitsSchema.wrapBitSet(totalBitlistSize, combinedAggregationBits); + return cachedAggregationBits; } @Override public SszBitvector getCommitteeBits() { - return committeeBits; + if (cachedCommitteeBits == null) { + cachedCommitteeBits = + committeeBitsSchema.wrapBitSet(committeeBitsSchema.getLength(), this.committeeBits); + } + return cachedCommitteeBits; + } + + private void invalidateCache() { + this.cachedAggregationBits = null; + this.cachedCommitteeBits = null; } @Override @@ -252,11 +278,19 @@ public boolean requiresCommitteeBits() { @Override public String toString() { + long totalSetBits = 0; + if (committeeAggregationBitsMap != null) { + for (final BitSet bitSet : committeeAggregationBitsMap.values()) { + if (bitSet != null) { + totalSetBits += bitSet.cardinality(); + } + } + } return MoreObjects.toStringHelper(this) - .add("aggregationBits", aggregationBits) - .add("committeeBits", committeeBits) - .add("committeesSize", committeesSize) - .add("committeeBitsStartingPositions", committeeBitsStartingPositions) + .add("committeeBits", committeeBits.cardinality()) + .add("committeeAggregationBitsMap", totalSetBits) + .add("committeesSize", committeesSize.size()) + .add("cached", cachedAggregationBits != null || cachedCommitteeBits != null) .toString(); } } diff --git a/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java b/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java index c3b6fc2174e..586d159e421 100644 --- a/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java +++ b/ethereum/statetransition/src/test/java/tech/pegasys/teku/statetransition/attestation/utils/AttestationBitsAggregatorElectraTest.java @@ -232,6 +232,34 @@ void aggregateOnMultipleOverlappingCommitteeBitsButWithSomeOfAggregationOverlapp assertThat(aggregator.getAggregationBits().streamAllSetBits()).containsExactly(0, 1, 2, 3, 8); } + @Test + void aggregateSingleAttestationFillUp() { + /* + 01|234 <- committee 0 and 1 indices + 10|100 <- bits + */ + final ValidatableAttestation initialAttestation = createAttestation(List.of(0, 1), 0, 2); + + /* + 123 <- committee 1 indices + 001 <- bits + */ + final ValidatableAttestation otherAttestation = createAttestation(List.of(1), 2); + + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + // calculate the or + aggregator.or(otherAttestation.getAttestation()); + + /* + 01|234 <- committee 0 and 1 indices + 10|101 <- bits + */ + + assertThat(aggregator.getCommitteeBits().streamAllSetBits()).containsExactly(0, 1); + assertThat(aggregator.getAggregationBits().streamAllSetBits()).containsExactly(0, 2, 4); + } + @Test void aggregateOnMultipleOverlappingCommitteeBitsButWithSomeOfAggregationOverlappingWhenNoCheck2() { @@ -393,6 +421,107 @@ void bigAggregation() { .isEqualTo(result.getAttestation().getAggregationBits()); } + @Test + void orIntoEmptyAggregator() { + final AttestationBitsAggregator aggregator = + AttestationBitsAggregator.fromEmptyFromAttestationSchema( + attestationSchema, Optional.of(committeeSizes)); + + final ValidatableAttestation otherAttestation = createAttestation(List.of(1), 2); + + aggregator.or(otherAttestation.getAttestation()); + + assertThat(aggregator.getCommitteeBits().streamAllSetBits()).containsExactly(1); + assertThat(aggregator.getAggregationBits().streamAllSetBits()).containsExactly(2); + } + + @Test + void orWithNewCommittee() { + final ValidatableAttestation initialAttestation = createAttestation(List.of(0), 1); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + final ValidatableAttestation otherAttestation = createAttestation(List.of(1), 2); + + aggregator.or(otherAttestation.getAttestation()); + + assertThat(aggregator.getCommitteeBits().streamAllSetBits()).containsExactlyInAnyOrder(0, 1); + assertThat(aggregator.getAggregationBits().streamAllSetBits()).containsExactlyInAnyOrder(1, 4); + } + + @Test + void orWithExistingCommitteeAddNewBits() { + + final ValidatableAttestation initialAttestation = createAttestation(List.of(1), 0); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + final ValidatableAttestation otherAttestation = createAttestation(List.of(1), 2); + + aggregator.or(otherAttestation.getAttestation()); + + assertThat(aggregator.getCommitteeBits().streamAllSetBits()).containsExactly(1); + assertThat(aggregator.getAggregationBits().streamAllSetBits()).containsExactlyInAnyOrder(0, 2); + } + + @Test + void orWithExistingCommitteeOverlapAndNewBits() { + final ValidatableAttestation initialAttestation = createAttestation(List.of(1), 0, 1); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + final ValidatableAttestation otherAttestation = createAttestation(List.of(1), 1, 2); + + aggregator.or(otherAttestation.getAttestation()); + + assertThat(aggregator.getCommitteeBits().streamAllSetBits()).containsExactly(1); + assertThat(aggregator.getAggregationBits().streamAllSetBits()) + .containsExactlyInAnyOrder(0, 1, 2); + } + + @Test + void orWithStrictSubsetAttestation() { + final ValidatableAttestation initialAttestation = createAttestation(List.of(1), 0, 2); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + final ValidatableAttestation otherAttestation = createAttestation(List.of(1), 0); + + aggregator.or(otherAttestation.getAttestation()); + + assertThat(aggregator.getCommitteeBits().streamAllSetBits()).containsExactly(1); + assertThat(aggregator.getAggregationBits().streamAllSetBits()).containsExactlyInAnyOrder(0, 2); + } + + @Test + void orWithMultipleCommitteesMixedNewAndExisting() { + final ValidatableAttestation initialAttestation = createAttestation(List.of(0, 1), 0, 3); + final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation); + + final ValidatableAttestation otherAttestation = createAttestation(List.of(1, 2), 2, 5); + + aggregator.or(otherAttestation.getAttestation()); + + assertThat(aggregator.getCommitteeBits().streamAllSetBits()).containsExactlyInAnyOrder(0, 1, 2); + assertThat(aggregator.getAggregationBits().streamAllSetBits()) + .containsExactlyInAnyOrder(0, 3, 4, 7); + } + + @Test + void orAggregatorWithAggregator() { + final ValidatableAttestation att1Data = createAttestation(List.of(0), 0); + final AttestationBitsAggregator aggregator1 = AttestationBitsAggregator.of(att1Data); + + final ValidatableAttestation att2Data = createAttestation(List.of(0, 1), 1, 2); + final AttestationBitsAggregator aggregator2 = AttestationBitsAggregator.of(att2Data); + + aggregator1.or(aggregator2); + + assertThat(aggregator1.getCommitteeBits().streamAllSetBits()).containsExactlyInAnyOrder(0, 1); + assertThat(aggregator1.getAggregationBits().streamAllSetBits()) + .containsExactlyInAnyOrder(0, 1, 2); + + // aggregator2 should remain unchanged + assertThat(aggregator2.getCommitteeBits().streamAllSetBits()).containsExactlyInAnyOrder(0, 1); + assertThat(aggregator2.getAggregationBits().streamAllSetBits()).containsExactlyInAnyOrder(1, 2); + } + @Test void isSuperSetOf1() { diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlist.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlist.java index cf99656e69f..ca0c769fa0d 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlist.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlist.java @@ -14,6 +14,7 @@ package tech.pegasys.teku.infrastructure.ssz.collections; import it.unimi.dsi.fastutil.ints.IntList; +import java.util.BitSet; import java.util.stream.IntStream; import javax.annotation.Nullable; import tech.pegasys.teku.infrastructure.ssz.collections.impl.SszBitlistImpl; @@ -43,6 +44,12 @@ default boolean isWritableSupported() { // Bitlist methods + BitSet getAsBitSet(); + + BitSet getAsBitSet(int start, int end); + + int getLastSetBitIndex(); + /** * Performs a logical OR of this bit list with the bit list argument. * diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java index 64dfe8bb388..edc51be9bc9 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvector.java @@ -14,6 +14,7 @@ package tech.pegasys.teku.infrastructure.ssz.collections; import it.unimi.dsi.fastutil.ints.IntList; +import java.util.BitSet; import java.util.stream.IntStream; import tech.pegasys.teku.infrastructure.ssz.primitive.SszBit; import tech.pegasys.teku.infrastructure.ssz.schema.collections.SszBitvectorSchema; @@ -48,6 +49,8 @@ default boolean isWritableSupported() { /** Returns the number of bits set to {@code true} in this {@code SszBitlist}. */ int getBitCount(); + BitSet getAsBitSet(); + @Override default boolean isSet(final int i) { return i < size() && getBit(i); diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitlistImpl.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitlistImpl.java index 134a704458c..a0c15c60cde 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitlistImpl.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitlistImpl.java @@ -51,12 +51,10 @@ public BitlistImpl(final int size, final long maxSize, final int... bitIndices) } } - public BitlistImpl(final int size, final long maxSize, final BitSet bitSet) { + public static BitlistImpl wrapBitSet(final int size, final long maxSize, final BitSet bitSet) { checkArgument(size >= 0, "Negative size"); checkArgument(maxSize >= size, "maxSize should be >= size"); - this.size = size; - this.data = bitSet; - this.maxSize = maxSize; + return new BitlistImpl(size, bitSet, maxSize); } private BitlistImpl(final int size, final BitSet data, final long maxSize) { @@ -65,6 +63,18 @@ private BitlistImpl(final int size, final BitSet data, final long maxSize) { this.maxSize = maxSize; } + public BitSet getAsBitSet() { + return (BitSet) data.clone(); + } + + public BitSet getAsBitSet(final int start, final int end) { + return data.get(start, end); + } + + public int getLastSetBitIndex() { + return data.length() - 1; + } + /** * Returns new instance of this BitlistImpl with set bits from the other BitlistImpl * diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java index df546bb9b4a..e9ea75160da 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/BitvectorImpl.java @@ -43,6 +43,12 @@ public static BitvectorImpl fromBytes(final Bytes bytes, final int size) { return new BitvectorImpl(bitset, size); } + public static BitvectorImpl wrapBitSet(final BitSet bitSet, final int size) { + final int length = bitSet.length(); + checkArgument(length <= size, "BitSet length (%s) exceeds the size (%s)", length, size); + return new BitvectorImpl(bitSet, size); + } + public static int sszSerializationLength(final int size) { return bitsCeilToBytes(size); } @@ -112,6 +118,10 @@ public boolean getBit(final int i) { return data.get(i); } + public BitSet getAsBitSet() { + return (BitSet) data.clone(); + } + public int getSize() { return size; } diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitlistImpl.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitlistImpl.java index f7f88442498..2ea5d624924 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitlistImpl.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitlistImpl.java @@ -75,7 +75,7 @@ public static SszBitlistImpl ofBits( public static SszBitlistImpl wrapBitSet( final SszBitlistSchema schema, final int size, final BitSet bitSet) { - return new SszBitlistImpl(schema, new BitlistImpl(size, schema.getMaxLength(), bitSet)); + return new SszBitlistImpl(schema, BitlistImpl.wrapBitSet(size, schema.getMaxLength(), bitSet)); } private final BitlistImpl value; @@ -90,6 +90,21 @@ public SszBitlistImpl(final SszListSchema schema, final BitlistImpl v this.value = value; } + @Override + public BitSet getAsBitSet() { + return value.getAsBitSet(); + } + + @Override + public BitSet getAsBitSet(final int start, final int end) { + return value.getAsBitSet(start, end); + } + + @Override + public int getLastSetBitIndex() { + return value.getLastSetBitIndex(); + } + @SuppressWarnings("unchecked") @Override public SszBitlistSchema getSchema() { diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java index 6eeb39d6bf3..e4ae96ffec6 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/collections/impl/SszBitvectorImpl.java @@ -17,6 +17,7 @@ import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntList; +import java.util.BitSet; import java.util.stream.IntStream; import org.apache.tuweni.bytes.Bytes; import tech.pegasys.teku.infrastructure.ssz.cache.IntCache; @@ -35,6 +36,11 @@ public static SszBitvectorImpl ofBits(final SszBitvectorSchema schema, final return new SszBitvectorImpl(schema, new BitvectorImpl(schema.getLength(), bits)); } + public static SszBitvectorImpl wrapBitSet( + final SszBitvectorSchema schema, final int size, final BitSet bitSet) { + return new SszBitvectorImpl(schema, BitvectorImpl.wrapBitSet(bitSet, size)); + } + public static SszBitvector fromBytes( final SszBitvectorSchema schema, final Bytes value, final int size) { return new SszBitvectorImpl(schema, BitvectorImpl.fromBytes(value, size)); @@ -58,6 +64,11 @@ public SszBitvectorImpl(final SszBitvectorSchema schema, final BitvectorImpl this.value = value; } + @Override + public BitSet getAsBitSet() { + return value.getAsBitSet(); + } + @SuppressWarnings("unchecked") @Override public SszBitvectorSchema getSchema() { diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitlistSchema.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitlistSchema.java index 5c3bb6d5d2f..19e4e549cce 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitlistSchema.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitlistSchema.java @@ -33,13 +33,9 @@ default SszBitlistT empty() { SszBitlistT ofBits(int size, int... setBitIndices); /** - * Creates a SszBitlist by wrapping a given bitSet. This is an optimized constructor that DOES NOT - * clone the bitSet. It used in aggregating attestation pool. DO NOT MUTATE the after the + * Creates an SszBitlist by wrapping a given bitSet. This is an optimized constructor that DOES + * NOT clone the bitSet. It is used in aggregating attestation pool. DO NOT MUTATE after the * wrapping!! SszBitlist is supposed to be immutable. - * - * @param size size of the SszBitlist - * @param bitSet data backing the ssz - * @return SszBitlist instance */ SszBitlistT wrapBitSet(int size, BitSet bitSet); diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitvectorSchema.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitvectorSchema.java index 7c60370b68e..78ea66e0c9a 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitvectorSchema.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/SszBitvectorSchema.java @@ -13,6 +13,7 @@ package tech.pegasys.teku.infrastructure.ssz.schema.collections; +import java.util.BitSet; import java.util.stream.StreamSupport; import org.apache.tuweni.bytes.Bytes; import tech.pegasys.teku.infrastructure.ssz.collections.SszBitvector; @@ -28,6 +29,13 @@ static SszBitvectorSchema create(final long length) { SszBitvectorT ofBits(int... setBitIndices); + /** + * Creates an SszBitvector by wrapping a given bitSet. This is an optimized constructor that DOES + * NOT clone the bitSet. It is used in aggregating attestation pool. DO NOT MUTATE after the + * wrapping!! SszBitvector is supposed to be immutable. + */ + SszBitvectorT wrapBitSet(int size, BitSet bitSet); + default SszBitvectorT fromBytes(final Bytes bitvectorBytes) { return sszDeserialize(bitvectorBytes); } diff --git a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/impl/SszBitvectorSchemaImpl.java b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/impl/SszBitvectorSchemaImpl.java index 92341b82000..c400d429f68 100644 --- a/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/impl/SszBitvectorSchemaImpl.java +++ b/infrastructure/ssz/src/main/java/tech/pegasys/teku/infrastructure/ssz/schema/collections/impl/SszBitvectorSchemaImpl.java @@ -16,6 +16,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static tech.pegasys.teku.infrastructure.ssz.schema.json.SszPrimitiveTypeDefinitions.sszSerializedType; +import java.util.BitSet; import java.util.List; import java.util.stream.IntStream; import tech.pegasys.teku.infrastructure.json.types.DeserializableTypeDefinition; @@ -54,6 +55,11 @@ public SszBitvector ofBits(final int... setBitIndices) { return SszBitvectorImpl.ofBits(this, setBitIndices); } + @Override + public SszBitvector wrapBitSet(final int size, final BitSet bitSet) { + return SszBitvectorImpl.wrapBitSet(this, size, bitSet); + } + @Override public SszBitvector createFromElements(final List elements) { return ofBits(IntStream.range(0, elements.size()).filter(i -> elements.get(i).get()).toArray()); diff --git a/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlistTest.java b/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlistTest.java index 9917ed7d338..3dd5280f79c 100644 --- a/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlistTest.java +++ b/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitlistTest.java @@ -31,6 +31,7 @@ import org.junit.jupiter.params.provider.MethodSource; import tech.pegasys.teku.infrastructure.crypto.Hash; import tech.pegasys.teku.infrastructure.ssz.SszCollection; +import tech.pegasys.teku.infrastructure.ssz.SszDataAssert; import tech.pegasys.teku.infrastructure.ssz.SszList; import tech.pegasys.teku.infrastructure.ssz.SszTestUtils; import tech.pegasys.teku.infrastructure.ssz.impl.AbstractSszPrimitive; @@ -125,12 +126,9 @@ void testSszRoundtrip(final SszBitlist bitlist1) { @ParameterizedTest @MethodSource("bitlistArgs") - void testWrapBitSet(final SszBitlist bitlist1) { - final BitSet bits = new BitSet(bitlist1.size()); - - bitlist1.streamAllSetBits().forEach(bits::set); - - final SszBitlist bitlist2 = bitlist1.getSchema().wrapBitSet(bitlist1.size(), bits); + void testBitSetRoundtrip(final SszBitlist bitlist1) { + final SszBitlist bitlist2 = + bitlist1.getSchema().wrapBitSet(bitlist1.size(), bitlist1.getAsBitSet()); for (int i = 0; i < bitlist1.size(); i++) { assertThat(bitlist2.getBit(i)).isEqualTo(bitlist1.getBit(i)); @@ -143,6 +141,35 @@ void testWrapBitSet(final SszBitlist bitlist1) { assertThat(bitlist2.sszSerialize()).isEqualTo(bitlist1.sszSerialize()); } + @Test + void getAsBitSet_withFullStartEnd() { + final SszBitlist list = random(SCHEMA, 100); + + final BitSet fullSlice = list.getAsBitSet(0, 100); + + final SszBitlist newList = SCHEMA.wrapBitSet(100, fullSlice); + + assertThat(newList.getAsBitSet()).isEqualTo(fullSlice); + SszDataAssert.assertThatSszData(newList).isEqualByAllMeansTo(list); + } + + @Test + void getAsBitSet_withSubsetStartEnd() { + final SszBitlist list = random(SCHEMA, 100); + + final BitSet slice = list.getAsBitSet(5, 55); + + final SszBitlist newList = SCHEMA.wrapBitSet(50, slice); + + assertThat(newList.getAsBitSet()).isEqualTo(slice); + IntStream.range(0, 50) + .forEach( + i -> + assertThat(newList.getBit(i)) + .describedAs("Bit %s should be equal", i) + .isEqualTo(list.getBit(i + 5))); + } + @Test void wrapBitSet_shouldDropBitsIfBitSetIsLarger() { final BitSet bitSet = new BitSet(100); diff --git a/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java b/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java index e12b138a7d1..0b0b1c92b29 100644 --- a/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java +++ b/infrastructure/ssz/src/test/java/tech/pegasys/teku/infrastructure/ssz/collections/SszBitvectorTest.java @@ -261,4 +261,15 @@ void testFromHexString(final SszBitvector bitvector) { SszBitvectorImpl.fromHexString(bitvector.getSchema(), hexString, (int) size); SszDataAssert.assertThatSszData(result).isEqualByAllMeansTo(bitvector); } + + @ParameterizedTest + @MethodSource("bitvectorArgs") + void testBitSetRoundtrip(final SszBitvector bitvector) { + + final SszBitvectorSchema schema = bitvector.getSchema(); + + final SszBitvector newVector = schema.wrapBitSet(bitvector.size(), bitvector.getAsBitSet()); + + SszDataAssert.assertThatSszData(newVector).isEqualByAllMeansTo(bitvector); + } }