Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,5 @@ static AttestationBitsAggregator of(
boolean requiresCommitteeBits();

/** Creates an independent copy of this instance */
default AttestationBitsAggregator copy() {
if (requiresCommitteeBits()) {
return new AttestationBitsAggregatorElectra(
getAggregationBits(), getCommitteeBits(), getCommitteesSize());
}
return new AttestationBitsAggregatorPhase0(getAggregationBits());
}
AttestationBitsAggregator copy();
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ class AttestationBitsAggregatorElectra implements AttestationBitsAggregator {
parseAggregationBits(initialAggregationBits, this.committeeBits, this.committeesSize);
}

private AttestationBitsAggregatorElectra(
final SszBitlistSchema<?> aggregationBitsSchema,
final SszBitvectorSchema<?> committeeBitsSchema,
final Int2IntMap committeesSize,
final Int2ObjectMap<BitSet> committeeAggregationBitsMap,
final BitSet committeeBits) {
this.aggregationBitsSchema = aggregationBitsSchema;
this.committeeBitsSchema = committeeBitsSchema;
this.committeesSize = Objects.requireNonNull(committeesSize, "committeesSize cannot be null");
this.committeeBits = committeeBits;
this.committeeAggregationBitsMap = committeeAggregationBitsMap;
}

static AttestationBitsAggregator fromAttestationSchema(
final AttestationSchema<?> attestationSchema, final Int2IntMap committeesSize) {
final SszBitlist emptyAggregationBits = attestationSchema.createEmptyAggregationBits();
Expand All @@ -76,13 +89,14 @@ private static Int2ObjectMap<BitSet> parseAggregationBits(
committeeIndex = committeeIndices.nextSetBit(committeeIndex + 1)) {

final int committeeSize = committeesSizeMap.getOrDefault(committeeIndex, 0);
if (committeeSize > 0) {
int sliceEnd =
Math.min(currentOffset + committeeSize, aggregationBits.getLastSetBitIndex() + 1);
final BitSet committeeBits = aggregationBits.getAsBitSet(currentOffset, sliceEnd);
result.put(committeeIndex, committeeBits);
if (committeeSize == 0) {
throw new IllegalArgumentException(
"Committee size for committee " + committeeIndex + " not found");
}
currentOffset += committeeSize; // Always advance by the declared committee size
final BitSet committeeBits =
aggregationBits.getAsBitSet(currentOffset, currentOffset + committeeSize);
result.put(committeeIndex, committeeBits);
currentOffset += committeeSize;
}
return result;
}
Expand Down Expand Up @@ -113,6 +127,16 @@ public void or(final Attestation other) {
performMerge(otherCommitteeBits, otherParsedAggregationMap, false);
}

private static Int2ObjectMap<BitSet> cloneCommitteeAggregationBitsMap(
final Int2ObjectMap<BitSet> committeeAggregationBitsMap) {
final Int2ObjectMap<BitSet> clonedMap = new Int2ObjectOpenHashMap<>();
for (final Int2ObjectMap.Entry<BitSet> entry :
committeeAggregationBitsMap.int2ObjectEntrySet()) {
clonedMap.put(entry.getIntKey(), (BitSet) entry.getValue().clone());
}
return clonedMap;
}

private boolean performMerge(
final BitSet otherCommitteeBits,
final Int2ObjectMap<BitSet> otherCommitteeAggregationBitsMap,
Expand All @@ -124,11 +148,7 @@ private boolean performMerge(

if (isAggregation) {
// If aggregating, we need to work on copies
targetAggregationBitsMap = new Int2ObjectOpenHashMap<>();
for (final Int2ObjectMap.Entry<BitSet> entry :
this.committeeAggregationBitsMap.int2ObjectEntrySet()) {
targetAggregationBitsMap.put(entry.getIntKey(), (BitSet) entry.getValue().clone());
}
targetAggregationBitsMap = cloneCommitteeAggregationBitsMap(this.committeeAggregationBitsMap);
} else {
// if not aggregating, we can modify in place
targetAggregationBitsMap = this.committeeAggregationBitsMap;
Expand Down Expand Up @@ -256,7 +276,8 @@ public SszBitlist getAggregationBits() {
public SszBitvector getCommitteeBits() {
if (cachedCommitteeBits == null) {
cachedCommitteeBits =
committeeBitsSchema.wrapBitSet(committeeBitsSchema.getLength(), this.committeeBits);
committeeBitsSchema.wrapBitSet(
committeeBitsSchema.getLength(), (BitSet) this.committeeBits.clone());
}
return cachedCommitteeBits;
}
Expand All @@ -276,6 +297,16 @@ public boolean requiresCommitteeBits() {
return true;
}

@Override
public AttestationBitsAggregator copy() {
return new AttestationBitsAggregatorElectra(
aggregationBitsSchema,
committeeBitsSchema,
committeesSize,
cloneCommitteeAggregationBitsMap(committeeAggregationBitsMap),
(BitSet) committeeBits.clone());
}

@Override
public String toString() {
long totalSetBits = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public boolean requiresCommitteeBits() {
return false;
}

@Override
public AttestationBitsAggregator copy() {
return new AttestationBitsAggregatorPhase0(aggregationBits);
}

@Override
public String toString() {
return MoreObjects.toStringHelper(this).add("aggregationBits", aggregationBits).toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.mockito.Mockito.when;

import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap.Entry;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -643,6 +644,55 @@ void isSuperSetOf6() {
assertThat(aggregator.isSuperSetOf(otherAttestation.getAttestation())).isTrue();
}

@Test
void getAggregationBits_shouldBeConsistent_singleCommittee() {
final ValidatableAttestation initialAttestation = createAttestation(List.of(0), 0);
final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation);

assertThat(aggregator.getAggregationBits().size()).isEqualTo(committeeSizes.get(0));

assertThat(aggregator.getAggregationBits())
.isEqualTo(initialAttestation.getAttestation().getAggregationBits());
}

@Test
void getAggregationBits_shouldBeConsistent_multiCommittee() {
final ValidatableAttestation initialAttestation = createAttestation(List.of(0, 1), 0, 3);
final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation);

assertThat(aggregator.getAggregationBits().size())
.isEqualTo(committeeSizes.get(0) + committeeSizes.get(1));

assertThat(aggregator.getAggregationBits())
.isEqualTo(initialAttestation.getAttestation().getAggregationBits());
}

@Test
void copy_shouldNotModifyOriginal() {
final ValidatableAttestation initialAttestation = createAttestation(List.of(0), 0);
final AttestationBitsAggregator aggregator = AttestationBitsAggregator.of(initialAttestation);

// check aggregator is initialized correctly
assertThat(aggregator.getCommitteeBits())
.isEqualTo(initialAttestation.getAttestation().getCommitteeBitsRequired());
assertThat(aggregator.getAggregationBits())
.isEqualTo(initialAttestation.getAttestation().getAggregationBits());

final AttestationBitsAggregator copy = aggregator.copy();

assertThat(copy.getCommitteeBits()).isEqualTo(aggregator.getCommitteeBits());
assertThat(copy.getAggregationBits()).isEqualTo(aggregator.getAggregationBits());
assertThat(copy).isNotSameAs(aggregator);

assertThat(copy.aggregateWith(createAttestation(List.of(1), 1).getAttestation())).isTrue();

// the original should not be modified
assertThat(aggregator.getCommitteeBits())
.isEqualTo(initialAttestation.getAttestation().getCommitteeBitsRequired());
assertThat(aggregator.getAggregationBits())
.isEqualTo(initialAttestation.getAttestation().getAggregationBits());
}

private ValidatableAttestation createAttestation(final String commBits, final String aggBits) {
assertThat(commBits).matches(Pattern.compile("^[0-1]+$"));
assertThat(aggBits).matches(Pattern.compile("^[0-1]+$"));
Expand All @@ -664,7 +714,12 @@ private ValidatableAttestation createAttestation(
final SszBitlist aggregationBits =
attestationSchema
.getAggregationBitsSchema()
.ofBits(committeeSizes.values().intStream().sum(), validators);
.ofBits(
committeeSizes.int2IntEntrySet().stream()
.filter(entry -> committeeIndices.contains(entry.getIntKey()))
.mapToInt(Entry::getIntValue)
.sum(),
validators);
final Supplier<SszBitvector> committeeBits =
() -> attestationSchema.getCommitteeBitsSchema().orElseThrow().ofBits(committeeIndices);

Expand Down