Skip to content

Commit

Permalink
Add enabledHmssMeasurementConsumers into kingdom config (#1688)
Browse files Browse the repository at this point in the history
Also merged R/F HMSS and Reach HMSS feature flags.
  • Loading branch information
renjiezh committed Jul 18, 2024
1 parent e9fd799 commit d73baa6
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.wfanet.measurement.common.toJson
import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey
import org.wfanet.measurement.internal.duchy.config.ProtocolsSetupConfig
import org.wfanet.measurement.internal.kingdom.DuchyIdConfig
import org.wfanet.measurement.internal.kingdom.HmssProtocolConfigConfig
import org.wfanet.measurement.internal.kingdom.Llv2ProtocolConfigConfig
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
import org.wfanet.measurement.loadtest.resourcesetup.EntityContent
Expand Down Expand Up @@ -74,6 +75,11 @@ val RO_LLV2_PROTOCOL_CONFIG_CONFIG: Llv2ProtocolConfigConfig =
"ro_llv2_protocol_config_config.textproto",
Llv2ProtocolConfigConfig.getDefaultInstance(),
)
val HMSS_PROTOCOL_CONFIG_CONFIG: HmssProtocolConfigConfig =
loadTextProto(
"hmss_protocol_config_config.textproto",
HmssProtocolConfigConfig.getDefaultInstance(),
)

val ALL_DUCHY_NAMES = DUCHY_ID_CONFIG.duchiesList.map { it.externalDuchyId }
val ALL_DUCHIES =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.wfanet.measurement.common.testing.ProviderRule
import org.wfanet.measurement.common.testing.chainRulesSequentially
import org.wfanet.measurement.config.DuchyCertConfig
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.service.DataServices
Expand Down Expand Up @@ -236,6 +237,10 @@ class InProcessCmmsComponents(
setOf("aggregator"),
2,
)
HmssProtocolConfig.setForTest(
HMSS_PROTOCOL_CONFIG_CONFIG.protocolConfig,
setOf("worker1", "worker2", "aggregator"),
)
DuchyInfo.initializeFromConfig(
loadTextProto("duchy_cert_config.textproto", DuchyCertConfig.getDefaultInstance())
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ private fun run(
internalDataProvidersStub,
v2alphaFlags.directNoiseMechanisms,
reachOnlyLlV2Enabled = v2alphaFlags.reachOnlyLlV2Enabled,
reachOnlyHmssEnabled = v2alphaFlags.hmssForReachEnabled,
reachAndFrequencyHmssEnabled = v2alphaFlags.hmssForRfEnabled,
hmssEnabled = v2alphaFlags.hmssEnabled,
hmssEnabledMeasurementConsumers = v2alphaFlags.hmssEnabledMeasurementConsumers,
)
.withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup)
.withApiKeyAuthenticationServerInterceptor(internalApiKeysCoroutineStub),
Expand Down Expand Up @@ -247,27 +247,26 @@ private class V2alphaFlags {
private set

@set:CommandLine.Option(
names = ["--enable-hmss-for-rf"],
description =
[
"whether to enable to Honest Majority Share Shuffle protocol for ReachAndFrequency Measurement"
],
names = ["--enable-hmss"],
description = ["whether to enable the Honest Majority Share Shuffle protocol"],
negatable = true,
required = false,
defaultValue = "false",
)
var hmssForRfEnabled by Delegates.notNull<Boolean>()
var hmssEnabled by Delegates.notNull<Boolean>()
private set

@set:CommandLine.Option(
names = ["--enable-hmss-for-reach"],
@CommandLine.Option(
names = ["--hmss-enabled-measurement-consumers"],
description =
["whether to enable to Honest Majority Share Shuffle protocol for Reach Measurement"],
negatable = true,
[
"MeasurementConsumer names who force to enable HMSS protocol" +
" regardless the --enable-hmss flag."
],
required = false,
defaultValue = "false",
defaultValue = "",
)
var hmssForReachEnabled by Delegates.notNull<Boolean>()
lateinit var hmssEnabledMeasurementConsumers: List<String>
private set

@CommandLine.Option(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ class MeasurementsService(
private val internalDataProvidersStub: InternalDataProvidersCoroutineStub,
private val noiseMechanisms: List<NoiseMechanism>,
private val reachOnlyLlV2Enabled: Boolean = false,
// TODO(@renjiez): merge the two options below once implementing reach-only HMSS.
private val reachAndFrequencyHmssEnabled: Boolean = false,
private val reachOnlyHmssEnabled: Boolean = false,
private val hmssEnabled: Boolean = false,
private val hmssEnabledMeasurementConsumers: List<String> = emptyList(),
) : MeasurementsCoroutineImplBase() {

override suspend fun getMeasurement(request: GetMeasurementRequest): Measurement {
Expand Down Expand Up @@ -472,6 +471,7 @@ class MeasurementsService(
private fun buildInternalProtocolConfig(
measurementSpec: MeasurementSpec,
dataProviderCapabilities: Collection<InternalDataProvider.Capabilities>,
measurementConsumerName: String,
): InternalProtocolConfig {
val dataProvidersCount = dataProviderCapabilities.size
val internalNoiseMechanisms = noiseMechanisms.map { it.toInternal() }
Expand All @@ -493,7 +493,7 @@ class MeasurementsService(
}
} else {
if (
reachOnlyHmssEnabled &&
(measurementConsumerName in hmssEnabledMeasurementConsumers || hmssEnabled) &&
dataProviderCapabilities.all { it.honestMajorityShareShuffleSupported }
) {
protocolConfig {
Expand Down Expand Up @@ -533,7 +533,7 @@ class MeasurementsService(
}
} else {
if (
reachAndFrequencyHmssEnabled &&
(measurementConsumerName in hmssEnabledMeasurementConsumers || hmssEnabled) &&
dataProviderCapabilities.all { it.honestMajorityShareShuffleSupported }
) {
protocolConfig {
Expand Down Expand Up @@ -634,7 +634,7 @@ class MeasurementsService(
measurement.toInternal(
measurementConsumerCertificateKey,
dataProviderValues,
buildInternalProtocolConfig(measurementSpec, dataProviderCapabilities),
buildInternalProtocolConfig(measurementSpec, dataProviderCapabilities, parentKey.toName()),
)

val requestId = this.requestId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class MeasurementsServiceTest {
}

private lateinit var service: MeasurementsService
private lateinit var hmssEnabledService: MeasurementsService

@Before
fun initService() {
Expand All @@ -241,7 +242,14 @@ class MeasurementsServiceTest {
DataProvidersGrpcKt.DataProvidersCoroutineStub(grpcTestServerRule.channel),
NOISE_MECHANISMS,
reachOnlyLlV2Enabled = true,
reachAndFrequencyHmssEnabled = true,
)

hmssEnabledService =
MeasurementsService(
MeasurementsGrpcKt.MeasurementsCoroutineStub(grpcTestServerRule.channel),
DataProvidersGrpcKt.DataProvidersCoroutineStub(grpcTestServerRule.channel),
NOISE_MECHANISMS,
hmssEnabled = true,
)
}

Expand Down Expand Up @@ -804,7 +812,7 @@ class MeasurementsServiceTest {
}

withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) {
runBlocking { service.createMeasurement(request) }
runBlocking { hmssEnabledService.createMeasurement(request) }
}

verifyProtoArgument(
Expand Down

0 comments on commit d73baa6

Please sign in to comment.