Skip to content

Commit

Permalink
feat: Fulfill hmss requisition via the share shuffle library. (#1748)
Browse files Browse the repository at this point in the history
  • Loading branch information
ple13 authored Aug 21, 2024
1 parent f21a035 commit 2ba3a5a
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 171 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import org.wfanet.measurement.consent.client.dataprovider.signRandomSeed
*/
class FulfillRequisitionRequestBuilder(
private val requisition: Requisition,
private val requisitionNonce: Long,
private val frequencyVector: FrequencyVector,
private val dataProviderCertificateKey: DataProviderCertificateKey,
private val signingKeyHandle: SigningKeyHandle,
Expand Down Expand Up @@ -128,7 +129,7 @@ class FulfillRequisitionRequestBuilder(
header = header {
name = requisition.name
requisitionFingerprint = computeRequisitionFingerprint(requisition)
nonce = requisition.nonce
this.nonce = requisitionNonce
protocolConfig = requisition.protocolConfig
this.honestMajorityShareShuffle = honestMajorityShareShuffle {
secretSeed = encryptedSignedShareSeed
Expand Down Expand Up @@ -185,12 +186,14 @@ class FulfillRequisitionRequestBuilder(
/** A convenience function for building the Sequence of Requests. */
fun build(
requisition: Requisition,
requisitionNonce: Long,
frequencyVector: FrequencyVector,
dataProviderCertificateKey: DataProviderCertificateKey,
signingKeyHandle: SigningKeyHandle,
): Sequence<FulfillRequisitionRequest> =
FulfillRequisitionRequestBuilder(
requisition,
requisitionNonce,
frequencyVector,
dataProviderCertificateKey,
signingKeyHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,19 @@ interface VidIndexMap {
* The order of iteration is undefined.
*/
operator fun iterator(): Iterator<VidIndexMapEntry>

companion object {
val EMPTY: VidIndexMap =
object : VidIndexMap {
override fun get(vid: Long): Int = throw VidNotFoundException(vid)

override val size: Long = 0L

override val populationSpec: PopulationSpec
get() = PopulationSpec.getDefaultInstance()

override fun iterator(): Iterator<VidIndexMapEntry> =
emptyList<VidIndexMapEntry>().iterator()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/common/identity",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/privacybudgetmanagement/testing",
"//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:edp_simulator",
"//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:population_spec_converter",
"//src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider:synthetic_generator_event_query",
"@wfa_common_jvm//imports/java/io/grpc:core",
"@wfa_common_jvm//imports/java/org/junit",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCorouti
import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpec
import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticPopulationSpec
import org.wfanet.measurement.api.v2alpha.event_templates.testing.TestEvent
import org.wfanet.measurement.api.v2alpha.populationSpec
import org.wfanet.measurement.common.Health
import org.wfanet.measurement.common.identity.withPrincipalName
import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler
Expand All @@ -51,9 +52,10 @@ import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.InMemory
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.PrivacyBucketFilter
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.PrivacyBudgetManager
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.testing.TestPrivacyBucketMapper
import org.wfanet.measurement.eventdataprovider.shareshuffle.v2alpha.InMemoryVidIndexMap
import org.wfanet.measurement.loadtest.dataprovider.EdpSimulator
import org.wfanet.measurement.loadtest.dataprovider.SyntheticGeneratorEventQuery
import org.wfanet.measurement.loadtest.dataprovider.VidToIndexMapGenerator
import org.wfanet.measurement.loadtest.dataprovider.toPopulationSpec

/** An in process EDP simulator. */
class InProcessEdpSimulator(
Expand Down Expand Up @@ -82,15 +84,12 @@ class InProcessEdpSimulator(
private val delegate: EdpSimulator

init {
val vidRangeStart = syntheticPopulationSpec.vidRange.start
val vidRangeEndExclusive = syntheticPopulationSpec.vidRange.endExclusive
val vidIndexMap =
val populationSpec = syntheticPopulationSpec.toPopulationSpec()
val hmssVidIndexMap =
if (honestMajorityShareShuffleSupported) {
VidToIndexMapGenerator.generateMapping(
(vidRangeStart until vidRangeEndExclusive).asSequence()
)
InMemoryVidIndexMap.build(populationSpec)
} else {
emptyMap()
null
}

delegate =
Expand All @@ -117,10 +116,7 @@ class InProcessEdpSimulator(
},
eventQuery =
object :
SyntheticGeneratorEventQuery(
SyntheticGenerationSpecs.SYNTHETIC_POPULATION_SPEC_SMALL,
TestEvent.getDescriptor(),
) {
SyntheticGeneratorEventQuery(syntheticPopulationSpec, TestEvent.getDescriptor()) {
override fun getSyntheticDataSpec(eventGroup: EventGroup) = syntheticDataSpec
},
throttler = MinimumIntervalThrottler(Clock.systemUTC(), Duration.ofMillis(1000)),
Expand All @@ -132,7 +128,7 @@ class InProcessEdpSimulator(
100.0f,
),
trustedCertificates = trustedCertificates,
vidToIndexMap = vidIndexMap,
hmssVidIndexMap = hmssVidIndexMap,
knownEventGroupMetadataTypes = listOf(SyntheticEventGroupSpec.getDescriptor().file),
random = random,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,14 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/privacybudgetmanagement:privacy_budget_manager",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/privacybudgetmanagement/api/v2alpha:privacy_query_mapper",
"//src/main/kotlin/org/wfanet/measurement/eventdataprovider/shareshuffle/v2alpha:shareshuffle",
"//src/main/kotlin/org/wfanet/measurement/loadtest/common:sample_vids",
"//src/main/kotlin/org/wfanet/measurement/loadtest/config:privacy_budgets",
"//src/main/kotlin/org/wfanet/measurement/loadtest/config:test_identifiers",
"//src/main/kotlin/org/wfanet/measurement/loadtest/config:vid_sampling",
"//src/main/proto/halo_cmm/uk/pilot:event_kt_jvm_proto",
"//src/main/proto/wfa/any_sketch:frequency_vector_kt_jvm_proto",
"//src/main/proto/wfa/any_sketch:sketch_kt_jvm_proto",
"//src/main/proto/wfa/frequency_count:frequency_vector_kt_jvm_proto",
"//src/main/proto/wfa/frequency_count:secret_share_kt_jvm_proto",
"//src/main/proto/wfa/frequency_count:secret_share_methods_kt_jvm__proto",
"//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto",
Expand Down Expand Up @@ -286,6 +287,7 @@ kt_jvm_library(
srcs = ["SyntheticGeneratorEdpSimulatorRunner.kt"],
deps = [
":edp_simulator_runner",
":population_spec_converter",
":synthetic_generator_event_query",
"@wfa_common_jvm//imports/java/picocli",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
Expand Down Expand Up @@ -327,3 +329,14 @@ kt_jvm_library(
"@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider",
],
)

kt_jvm_library(
name = "population_spec_converter",
srcs = [
"PopulationSpecConverter.kt",
],
deps = [
"//src/main/proto/wfa/measurement/api/v2alpha:population_spec_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:simulator_synthetic_data_spec_kt_jvm_proto",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import kotlin.random.Random
import kotlin.random.asJavaRandom
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
Expand All @@ -47,11 +48,6 @@ import org.wfanet.anysketch.Sketch
import org.wfanet.anysketch.SketchConfig
import org.wfanet.anysketch.crypto.ElGamalPublicKey as AnySketchElGamalPublicKey
import org.wfanet.anysketch.crypto.elGamalPublicKey as anySketchElGamalPublicKey
import org.wfanet.frequencycount.FrequencyVector
import org.wfanet.frequencycount.SecretShare
import org.wfanet.frequencycount.SecretShareGeneratorAdapter
import org.wfanet.frequencycount.frequencyVector
import org.wfanet.frequencycount.secretShareGeneratorRequest
import org.wfanet.measurement.api.v2alpha.Certificate
import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub
import org.wfanet.measurement.api.v2alpha.CustomDirectMethodologyKt.variance
Expand All @@ -64,7 +60,6 @@ import org.wfanet.measurement.api.v2alpha.DeterministicDistribution
import org.wfanet.measurement.api.v2alpha.DifferentialPrivacyParams
import org.wfanet.measurement.api.v2alpha.ElGamalPublicKey
import org.wfanet.measurement.api.v2alpha.EncryptedMessage
import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey
import org.wfanet.measurement.api.v2alpha.EventAnnotationsProto
import org.wfanet.measurement.api.v2alpha.EventGroup
import org.wfanet.measurement.api.v2alpha.EventGroupKey
Expand Down Expand Up @@ -106,7 +101,7 @@ import org.wfanet.measurement.api.v2alpha.fulfillRequisitionRequest
import org.wfanet.measurement.api.v2alpha.getEventGroupRequest
import org.wfanet.measurement.api.v2alpha.getMeasurementConsumerRequest
import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest
import org.wfanet.measurement.api.v2alpha.randomSeed
import org.wfanet.measurement.api.v2alpha.populationSpec
import org.wfanet.measurement.api.v2alpha.replaceDataAvailabilityIntervalRequest
import org.wfanet.measurement.api.v2alpha.replaceDataProviderCapabilitiesRequest
import org.wfanet.measurement.api.v2alpha.unpack
Expand All @@ -124,8 +119,6 @@ import org.wfanet.measurement.common.throttler.Throttler
import org.wfanet.measurement.common.toProtoTime
import org.wfanet.measurement.consent.client.dataprovider.computeRequisitionFingerprint
import org.wfanet.measurement.consent.client.dataprovider.encryptMetadata
import org.wfanet.measurement.consent.client.dataprovider.encryptRandomSeed
import org.wfanet.measurement.consent.client.dataprovider.signRandomSeed
import org.wfanet.measurement.consent.client.dataprovider.verifyElGamalPublicKey
import org.wfanet.measurement.consent.client.measurementconsumer.verifyEncryptionPublicKey
import org.wfanet.measurement.dataprovider.DataProviderData
Expand All @@ -142,6 +135,9 @@ import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.PrivacyB
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.Reference
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.api.v2alpha.PrivacyQueryMapper.getDirectAcdpQuery
import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.api.v2alpha.PrivacyQueryMapper.getMpcAcdpQuery
import org.wfanet.measurement.eventdataprovider.shareshuffle.v2alpha.FrequencyVectorBuilder
import org.wfanet.measurement.eventdataprovider.shareshuffle.v2alpha.FulfillRequisitionRequestBuilder
import org.wfanet.measurement.eventdataprovider.shareshuffle.v2alpha.VidIndexMap
import org.wfanet.measurement.loadtest.common.sampleVids
import org.wfanet.measurement.loadtest.config.TestIdentifiers.SIMULATOR_EVENT_GROUP_REFERENCE_ID_PREFIX
import org.wfanet.measurement.loadtest.dataprovider.MeasurementResults.computeImpression
Expand All @@ -163,12 +159,12 @@ class EdpSimulator(
private val privacyBudgetManager: PrivacyBudgetManager,
private val trustedCertificates: Map<ByteString, X509Certificate>,
/**
* EDP uses the vidToIndexMap to fulfill the requisitions for the honest majority share shuffle
* EDP uses the vidIndexMap to fulfill the requisitions for the honest majority share shuffle
* protocol.
*
* When the vidToIndexMap is empty, the honest majority share shuffle protocol is not supported.
* When the vidIndexMap is empty, the honest majority share shuffle protocol is not supported.
*/
private val vidToIndexMap: Map<Long, IndexedValue> = emptyMap(),
private val hmssVidIndexMap: VidIndexMap? = null,
/**
* Known protobuf types for [EventGroupMetadataDescriptor]s.
*
Expand Down Expand Up @@ -196,7 +192,7 @@ class EdpSimulator(
val supportedProtocols = buildSet {
add(ProtocolConfig.Protocol.ProtocolCase.LIQUID_LEGIONS_V2)
add(ProtocolConfig.Protocol.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2)
if (vidToIndexMap.isNotEmpty()) {
if (hmssVidIndexMap != null) {
add(ProtocolConfig.Protocol.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE)
}
}
Expand All @@ -220,7 +216,7 @@ class EdpSimulator(
name = edpData.name
capabilities =
DataProviderKt.capabilities {
honestMajorityShareShuffleSupported = vidToIndexMap.isNotEmpty()
honestMajorityShareShuffleSupported = (hmssVidIndexMap != null)
}
}
)
Expand Down Expand Up @@ -953,31 +949,6 @@ class EdpSimulator(
}
}

private fun generateHmssSketch(
vidToIndexMap: Map<Long, IndexedValue>,
measurementSpec: MeasurementSpec,
eventGroupSpecs: Iterable<EventQuery.EventGroupSpec>,
): IntArray {
logger.info("Generating HMSS Sketch...")
val maximumFrequency =
if (measurementSpec.hasReachAndFrequency()) measurementSpec.reachAndFrequency.maximumFrequency
else 1

val sketch =
FrequencyVectorGenerator(vidToIndexMap, eventQuery, measurementSpec.vidSamplingInterval)
.generate(eventGroupSpecs)
.map { if (it > maximumFrequency) maximumFrequency else it }
.toIntArray()

logger.log(Level.INFO) { "Registers Size:\n${sketch.size}" }

if (logSketchDetails) {
logShareShuffleSketchDetails(sketch)
}

return sketch
}

private fun encryptLiquidLegionsV2Sketch(
sketch: Sketch,
ellipticCurveId: Int,
Expand Down Expand Up @@ -1138,48 +1109,16 @@ class EdpSimulator(
}
)
}
try {
requisitionFulfillmentStubsByDuchyId.values.first().fulfillRequisition(requests)
} catch (e: StatusException) {
throw Exception("Error fulfilling requisition ${requisition.name}", e)
}
fulfillRequisition(requisitionFulfillmentStubsByDuchyId.values.first(), requisition, requests)
}

private suspend fun fulfillRequisition(
requisitionFulfillmentStub: RequisitionFulfillmentCoroutineStub,
requisition: Requisition,
requisitionFingerprint: ByteString,
nonce: Long,
encryptedSignedSeed: EncryptedMessage,
shareVector: FrequencyVector,
requests: Flow<FulfillRequisitionRequest>,
) {
logger.info("Fulfilling requisition ${requisition.name}...")
val requests: Flow<FulfillRequisitionRequest> = flow {
logger.info { "Emitting FulfillRequisitionRequests..." }
emit(
fulfillRequisitionRequest {
header = header {
name = requisition.name
this.requisitionFingerprint = requisitionFingerprint
this.nonce = nonce
this.honestMajorityShareShuffle = honestMajorityShareShuffle {
secretSeed = encryptedSignedSeed
registerCount = shareVector.dataList.size.toLong()
dataProviderCertificate = edpData.certificateKey.toName()
}
}
}
)
emitAll(
shareVector.toByteString().asBufferedFlow(RPC_CHUNK_SIZE_BYTES).map {
fulfillRequisitionRequest { bodyChunk = bodyChunk { this.data = it } }
}
)
}
try {
val duchyId = getDuchyWithoutPublicKey(requisition)
val requisitionFulfillmentStub =
requisitionFulfillmentStubsByDuchyId[duchyId]
?: throw Exception("Requisition fulfillment stub not found for $duchyId.")
requisitionFulfillmentStub.fulfillRequisition(requests)
} catch (e: StatusException) {
throw Exception("Error fulfilling requisition ${requisition.name}", e)
Expand Down Expand Up @@ -1231,41 +1170,33 @@ class EdpSimulator(
requisition.duchiesCount - 1,
)

val frequencyVector =
try {
generateHmssSketch(vidToIndexMap, measurementSpec, eventGroupSpecs)
} catch (e: EventFilterValidationException) {
logger.log(
Level.WARNING,
"RequisitionFulfillmentWorkflow failed due to invalid event filter",
e,
)
throw RequisitionRefusalException(
Requisition.Refusal.Justification.SPEC_INVALID,
"Invalid event filter (${e.code}): ${e.code.description}",
)
logger.info("Generating sampled frequency vector for HMSS...")
val frequencyVectorBuilder =
FrequencyVectorBuilder(hmssVidIndexMap!!.populationSpec, measurementSpec, strict = false)
for (eventGroupSpec in eventGroupSpecs) {
eventQuery.getUserVirtualIds(eventGroupSpec).forEach {
frequencyVectorBuilder.increment(hmssVidIndexMap!![it])
}

val secretShareGeneratorRequest = secretShareGeneratorRequest {
data += frequencyVector.toList()
ringModulus = protocolConfig.ringModulus
}

val secretShare =
SecretShare.parseFrom(
SecretShareGeneratorAdapter.generateSecretShares(secretShareGeneratorRequest.toByteArray())
)

val shareSeed = randomSeed { data = secretShare.shareSeed.key.concat(secretShare.shareSeed.iv) }
val signedShareSeed =
signRandomSeed(shareSeed, edpData.signingKeyHandle, edpData.signingKeyHandle.defaultAlgorithm)
val publicKey =
EncryptionPublicKey.parseFrom(getEncryptionKeyForShareSeed(requisition).message.value)
val shareSeedCiphertext = encryptRandomSeed(signedShareSeed, publicKey)
val sampledFrequencyVector = frequencyVectorBuilder.build()
logger.log(Level.INFO) { "Sampled frequency vector size:\n${sampledFrequencyVector.dataCount}" }

val shareVector = frequencyVector { data += secretShare.shareVectorList }
val requests =
FulfillRequisitionRequestBuilder.build(
requisition,
nonce,
sampledFrequencyVector,
edpData.certificateKey,
edpData.signingKeyHandle,
)
.asFlow()

fulfillRequisition(requisition, requisitionFingerprint, nonce, shareSeedCiphertext, shareVector)
val duchyId = getDuchyWithoutPublicKey(requisition)
val requisitionFulfillmentStub =
requisitionFulfillmentStubsByDuchyId[duchyId]
?: throw Exception("Requisition fulfillment stub not found for $duchyId.")
fulfillRequisition(requisitionFulfillmentStub, requisition, requests)
}

private fun Requisition.getCombinedPublicKey(curveId: Int): AnySketchElGamalPublicKey {
Expand Down
Loading

0 comments on commit 2ba3a5a

Please sign in to comment.