diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/MillBase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/MillBase.kt index aefc7f91e27..35799d3cd90 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/MillBase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/MillBase.kt @@ -131,6 +131,8 @@ abstract class MillBase( ) { abstract val endingStage: ComputationStage + abstract val prioritizedStagesToClaim: List + private val meter: Meter = openTelemetry.getMeter(this::class.qualifiedName!!) protected val cryptoWallClockDurationHistogram: DoubleHistogram = @@ -170,6 +172,7 @@ abstract class MillBase( computationType = this@MillBase.computationType owner = millId lockDuration = workLockDuration.toProtoDuration() + prioritizedStages += prioritizedStagesToClaim } val claimWorkResponse = try { diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2Mill.kt index 15883998113..f20ed3fa08d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2Mill.kt @@ -39,12 +39,12 @@ import org.wfanet.measurement.duchy.service.internal.computations.outputPathList import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason -import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType import org.wfanet.measurement.internal.duchy.ElGamalPublicKey import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest +import org.wfanet.measurement.internal.duchy.computationStage import org.wfanet.measurement.internal.duchy.config.RoleInComputation import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR @@ -144,10 +144,9 @@ class ReachFrequencyLiquidLegionsV2Mill( maximumAttempts, clock, ) { - override val endingStage: ComputationStage = - ComputationStage.newBuilder() - .apply { liquidLegionsSketchAggregationV2 = Stage.COMPLETE } - .build() + override val endingStage = Stage.COMPLETE.toProtocolStage() + + override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage()) private val actions = mapOf( diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt index 125cfc5552f..cca1e857a0d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt @@ -43,6 +43,7 @@ import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType import org.wfanet.measurement.internal.duchy.ElGamalPublicKey import org.wfanet.measurement.internal.duchy.computationDetails +import org.wfanet.measurement.internal.duchy.computationStage import org.wfanet.measurement.internal.duchy.config.RoleInComputation import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR @@ -134,6 +135,8 @@ class ReachOnlyLiquidLegionsV2Mill( ) { override val endingStage = Stage.COMPLETE.toProtocolStage() + override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage()) + private val actions = mapOf( Pair(Stage.INITIALIZATION_PHASE, AGGREGATOR) to ::initializationPhase, diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMill.kt index 1ea661de49f..ced4e43d6e0 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMill.kt @@ -135,6 +135,8 @@ class HonestMajorityShareShuffleMill( override val endingStage = Stage.COMPLETE.toProtocolStage() + override val prioritizedStagesToClaim = listOf(Stage.INITIALIZED.toProtocolStage()) + private val actions = mapOf( Pair(Stage.INITIALIZED, FIRST_NON_AGGREGATOR) to ::initializationPhase, diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt index 83bc8167650..6904caa8321 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt @@ -121,9 +121,16 @@ interface ComputationsDatabaseTransactor = listOf(), + ): String? /** * Transitions a computation to a new stage. diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt index 11f0ee220b3..b907c298ff6 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt @@ -359,8 +359,9 @@ private constructor( protocol: ComputationType, ownerId: String, lockDuration: Duration, + prioritizedStages: List, ): String? { - val claimed = + val editTokens = tokens.values .asSequence() .filter { it.globalComputationId !in claimedComputationIds } @@ -385,7 +386,13 @@ private constructor( globalId = it.globalComputationId, ) } - .firstOrNull() ?: return null + val prioritizedComputations = editTokens.filter { it.stage in prioritizedStages } + val claimed = + if (!prioritizedComputations.none()) { + prioritizedComputations.first() + } else { + editTokens.firstOrNull() ?: return null + } updateToken(claimed) { existing -> claimedComputationIds.add(existing.globalComputationId) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index 020eca9ca1c..1f77bbc7aed 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -170,6 +170,7 @@ class PostgresComputationsService( try { ClaimWork( request.computationType, + request.prioritizedStagesList, request.owner, lockDuration, clock, diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel index 13a48209006..10c15af0d81 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel @@ -18,6 +18,7 @@ kt_jvm_library( "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", ], diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt index b4e8340031d..7e95029b769 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt @@ -23,6 +23,8 @@ import org.wfanet.measurement.common.db.r2dbc.DatabaseClient import org.wfanet.measurement.common.db.r2dbc.ReadContext import org.wfanet.measurement.common.db.r2dbc.ResultRow import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.toProtoTime import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper @@ -349,10 +351,10 @@ class ComputationReader( readContext: ReadContext, protocol: Long, timestamp: Instant, + prioritizedStages: List, ): Flow { - val listUnclaimedTasksSql = - boundStatement( - """ + val baseSql = + """ SELECT c.ComputationId, c.GlobalComputationId, c.Protocol, c.ComputationStage, c.UpdateTime, c.CreationTime, cs.NextAttempt @@ -363,15 +365,45 @@ class ComputationReader( WHERE c.Protocol = $1 AND c.LockExpirationTime IS NOT NULL AND c.LockExpirationTime <= $2 - ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC - LIMIT 50; """ - ) { - bind("$1", protocol) - bind("$2", timestamp) - } - return readContext.executeQuery(listUnclaimedTasksSql).consume(::buildUnclaimedTaskQueryResult) + if (prioritizedStages.isEmpty()) { + val baseSqlWithOrder = + baseSql + + """ + ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC + LIMIT 50; + """ + val statement = + boundStatement(baseSqlWithOrder) { + bind("$1", protocol) + bind("$2", timestamp) + } + + return readContext.executeQuery(statement).consume(::buildUnclaimedTaskQueryResult) + } else { + val baseSqlWithOrder = + baseSql + + """ + ORDER BY + CASE WHEN c.ComputationStage + IN (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) THEN 0 + ELSE 1 END ASC, + c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC + LIMIT 50; + """ + val statement = + valuesListBoundStatement(valuesStartIndex = 2, paramCount = 1, baseSqlWithOrder) { + bind("$1", protocol) + bind("$2", timestamp) + for (stage in prioritizedStages) { + val longValue = + computationProtocolStagesEnumHelper.computationStageEnumToLongValues(stage).stage + addValuesBinding { bindValuesParam(0, longValue) } + } + } + return readContext.executeQuery(statement).consume(::buildUnclaimedTaskQueryResult) + } } /** diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt index e7a45aadf23..c0fd6bd3f45 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt @@ -25,6 +25,7 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException +import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.copy @@ -48,6 +49,7 @@ import org.wfanet.measurement.internal.duchy.copy */ class ClaimWork( private val protocol: ProtocolT, + private val prioritizedStages: List, private val ownerId: String, private val lockDuration: Duration, private val clock: Clock, @@ -63,6 +65,7 @@ class ClaimWork( transactionContext, protocolEnum, clock.instant().truncatedTo(ChronoUnit.MICROS), + prioritizedStages, ) // First the possible tasks to claim are selected from the computations table, then for each // item in the list we try to claim the lock in a transaction which will only succeed if the diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt index fda98ce9e5d..a79c9bba078 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt @@ -143,6 +143,7 @@ class GcpSpannerComputationsDatabaseTransactor< protocol: ProtocolT, ownerId: String, lockDuration: Duration, + prioritizedStages: List, ): String? { /** Claim a specific task represented by the results of running the above sql. */ suspend fun claimSpecificTask(result: UnclaimedTaskQueryResult): Boolean = @@ -159,7 +160,9 @@ class GcpSpannerComputationsDatabaseTransactor< } return UnclaimedTasksQuery( computationMutations.protocolEnumToLong(protocol), + prioritizedStages, computationMutations::longValuesToComputationStageEnum, + computationMutations::computationStageEnumToLongValues, clock.gcloudTimestamp(), ) .execute(databaseClient) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/UnclaimedTasksQuery.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/UnclaimedTasksQuery.kt index a32185fd147..31a6b0cce85 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/UnclaimedTasksQuery.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/UnclaimedTasksQuery.kt @@ -19,11 +19,15 @@ import com.google.cloud.spanner.Statement import com.google.cloud.spanner.Struct import org.wfanet.measurement.duchy.db.computation.ComputationStageLongValues import org.wfanet.measurement.duchy.deploy.gcloud.spanner.common.SqlBasedQuery +import org.wfanet.measurement.gcloud.spanner.appendClause +import org.wfanet.measurement.gcloud.spanner.statement /** Queries for computations which may be claimed at a timestamp. */ class UnclaimedTasksQuery( - val protocol: Long, - val parseStageEnum: (ComputationStageLongValues) -> StageT, + protocol: Long, + prioritizedStages: List, + val longValuesToComputationStageEnum: (ComputationStageLongValues) -> StageT, + computationStageEnumToLongValues: (StageT) -> ComputationStageLongValues, timestamp: Timestamp, ) : SqlBasedQuery> { companion object { @@ -44,25 +48,38 @@ class UnclaimedTasksQuery( AND c.LockExpirationTime IS NOT NULL AND c.UpdateTime IS NOT NULL AND c.LockExpirationTime <= @current_time - ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC - LIMIT 50 """ } override val sql: Statement = - Statement.newBuilder(parameterizedQueryString) - .bind("current_time") - .to(timestamp) - .bind("protocol") - .to(protocol) - .build() + statement(parameterizedQueryString) { + bind("current_time").to(timestamp) + bind("protocol").to(protocol) + if (prioritizedStages.isEmpty()) { + appendClause("ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC") + } else { + appendClause( + """ + ORDER BY + CASE WHEN c.ComputationStage IN UNNEST (@prioritized_stages) THEN 0 + ELSE 1 END ASC, + c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC + """ + .trimIndent() + ) + val prioritizedStageLongValues = + prioritizedStages.map(computationStageEnumToLongValues).map { it.stage } + bind("prioritized_stages").toInt64Array(prioritizedStageLongValues) + } + appendClause("LIMIT 50") + } override fun asResult(struct: Struct): UnclaimedTaskQueryResult = UnclaimedTaskQueryResult( computationId = struct.getLong("ComputationId"), globalId = struct.getString("GlobalComputationId"), computationStage = - parseStageEnum( + longValuesToComputationStageEnum( ComputationStageLongValues(struct.getLong("Protocol"), struct.getLong("ComputationStage")) ), creationTime = struct.getTimestamp("CreationTime"), diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt index 4ce8985e9c3..e5f299c3775 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt @@ -84,7 +84,12 @@ class ComputationsService( val lockDuration = if (request.hasLockDuration()) request.lockDuration.toDuration() else defaultLockDuration val claimed = - computationsDatabase.claimTask(request.computationType, request.owner, lockDuration) + computationsDatabase.claimTask( + request.computationType, + request.owner, + lockDuration, + request.prioritizedStagesList, + ) return if (claimed != null) { val token = computationsDatabase.readComputationToken(claimed)!! sendStatusUpdateToKingdom( diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt index e954cee31be..37e0cbd2650 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt @@ -840,6 +840,62 @@ abstract class ComputationsServiceTest { assertThat(claimWorkResponse).isEqualTo(ClaimWorkResponse.getDefaultInstance()) } + @Test + fun `claimWork returns computation of prioritized stage`() = runBlocking { + var computation1 = + service + .createComputation( + DEFAULT_CREATE_COMPUTATION_REQUEST.copy { + globalComputationId = "1" + requisitions.clear() + } + ) + .token + + computation1 = + service + .claimWork( + DEFAULT_CLAIM_WORK_REQUEST.copy { + this.lockDuration = Durations.fromSeconds(1) + prioritizedStages += Stage.INITIALIZATION_PHASE.toProtocolStage() + } + ) + .token + + computation1 = + service + .advanceComputationStage( + advanceComputationStageRequest { + token = computation1 + nextComputationStage = Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() + // To simplify the test, assume WAIT_REQUISITIONS_AND_KEY_SET is claimable. + afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE + } + ) + .token + + // create a new computation with prioritized stage. + val computation2 = + service + .createComputation( + DEFAULT_CREATE_COMPUTATION_REQUEST.copy { + globalComputationId = "2" + requisitions.clear() + } + ) + .token + + val response = + service.claimWork( + DEFAULT_CLAIM_WORK_REQUEST.copy { + this.lockDuration = Durations.fromSeconds(1) + prioritizedStages += Stage.INITIALIZATION_PHASE.toProtocolStage() + } + ) + + assertThat(response.token.globalComputationId).isEqualTo(computation2.globalComputationId) + } + @Test fun `getComputationIds returns empty response when no matching computations`() = runBlocking { service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) diff --git a/src/main/proto/wfa/measurement/internal/duchy/computations_service.proto b/src/main/proto/wfa/measurement/internal/duchy/computations_service.proto index 61422d12485..95fe2e0e7e6 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/computations_service.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/computations_service.proto @@ -230,6 +230,9 @@ message ClaimWorkRequest { // How long the work lock is held for as a result of this claim. If not // specified, a default value may be chosen by the server. google.protobuf.Duration lock_duration = 3; + + // Stages of Computations that have higher priority to claim. + repeated ComputationStage prioritized_stages = 4; } message ClaimWorkResponse { // The token of the computation being claimed. The response would be empty if diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2MillTest.kt index 00d5946b127..10e2074491f 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachFrequencyLiquidLegionsV2MillTest.kt @@ -96,6 +96,8 @@ import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata import org.wfanet.measurement.internal.duchy.computationToken import org.wfanet.measurement.internal.duchy.config.RoleInComputation import org.wfanet.measurement.internal.duchy.copy +import org.wfanet.measurement.internal.duchy.elGamalKeyPair +import org.wfanet.measurement.internal.duchy.elGamalPublicKey import org.wfanet.measurement.internal.duchy.protocol.CompleteExecutionPhaseOneAtAggregatorRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteExecutionPhaseOneAtAggregatorResponse import org.wfanet.measurement.internal.duchy.protocol.CompleteExecutionPhaseOneRequest @@ -137,6 +139,7 @@ import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseTwoA import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseTwoAtAggregatorResponse import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseTwoRequest import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseTwoResponse +import org.wfanet.measurement.internal.duchy.protocol.completeInitializationPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.completeSetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.copy import org.wfanet.measurement.internal.duchy.protocol.flagCountTupleNoiseGenerationParameters @@ -657,6 +660,45 @@ class ReachFrequencyLiquidLegionsV2MillTest { assertThat(fakeComputationDb.claimedComputationIds).isEmpty() } + @Test + fun `initialization phase has higher priority to be claimed`() = runBlocking { + fakeComputationDb.addComputation( + 1L, + EXECUTION_PHASE_ONE.toProtocolStage(), + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(newEmptyOutputBlobMetadata(1)), + ) + fakeComputationDb.addComputation( + 2L, + INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + ) + + var cryptoRequest = CompleteInitializationPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeInitializationPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + completeInitializationPhaseResponse { + elGamalKeyPair = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + + // Mill should claim computation1 of INITIALIZATION_PHASE. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[2]!!.computationStage) + .isEqualTo(WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage()) + assertThat(fakeComputationDb[1]!!.computationStage) + .isEqualTo(EXECUTION_PHASE_ONE.toProtocolStage()) + } + @Test fun `initialization phase`() = runBlocking { // Stage 0. preparing the database and set up mock diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt index c7215ad6f2f..c697255ac0b 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt @@ -609,6 +609,44 @@ class ReachOnlyLiquidLegionsV2MillTest { ) } + @Test + fun `initialization phase has higher priority to be claimed`() = runBlocking { + fakeComputationDb.addComputation( + 1L, + EXECUTION_PHASE.toProtocolStage(), + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(newEmptyOutputBlobMetadata(1)), + ) + fakeComputationDb.addComputation( + 2L, + INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + ) + + var cryptoRequest = CompleteReachOnlyInitializationPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlyInitializationPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + completeReachOnlyInitializationPhaseResponse { + elGamalKeyPair = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + + // Mill should claim computation1 of INITIALIZATION_PHASE. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[2]!!.computationStage) + .isEqualTo(WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage()) + assertThat(fakeComputationDb[1]!!.computationStage).isEqualTo(EXECUTION_PHASE.toProtocolStage()) + } + @Test fun `initialization phase`() = runBlocking { // Stage 0. preparing the database and set up mock diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt index 77d79344ed3..ddb50d25315 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/shareshuffle/HonestMajorityShareShuffleMillTest.kt @@ -564,6 +564,32 @@ class HonestMajorityShareShuffleMillTest { } } + @Test + fun `initialized phase has higher priority to be claimed`() = runBlocking { + val computationDetails = + getReachAndFrequencyHmssComputationDetails(RoleInComputation.FIRST_NON_AGGREGATOR) + fakeComputationDb.addComputation( + 1L, + Stage.SETUP_PHASE.toProtocolStage(), + computationDetails = computationDetails, + requisitions = REACH_AND_FREQUENCY_REQUISITIONS, + ) + + fakeComputationDb.addComputation( + 2L, + Stage.INITIALIZED.toProtocolStage(), + computationDetails = computationDetails, + requisitions = REACH_AND_FREQUENCY_REQUISITIONS, + ) + val mill = createHmssMill(DUCHY_ONE_ID) + mill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[2]!!.computationStage) + .isEqualTo(Stage.WAIT_TO_START.toProtocolStage()) + assertThat(fakeComputationDb[1]!!.computationStage) + .isEqualTo(Stage.SETUP_PHASE.toProtocolStage()) + } + @Test fun `initializationPhase sends params to Kingdom and advance stage`() = runBlocking { val computationParticipant = computationParticipant { diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt index 4887aec6077..4efe1cd1a63 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt @@ -678,6 +678,117 @@ class GcpSpannerComputationsDatabaseTransactorTest : assertNull(database.claimTask(FakeProtocol.ONE, "the-owner-of-the-lock", DEFAULT_LOCK_DURATION)) } + @Test + fun `claimTask returns prioritized computation`() = runBlocking { + testClock.tickSeconds("7_minutes_ago") + testClock.tickSeconds("6_minutes_ago", 60) + testClock.tickSeconds("5_minutes_ago", 60) + testClock.tickSeconds("TimeOfTest", 300) + val timestamp1 = testClock["7_minutes_ago"].toGcloudTimestamp() + val timestamp2 = testClock["6_minutes_ago"].toGcloudTimestamp() + val timestamp3 = testClock["5_minutes_ago"].toGcloudTimestamp() + + val computation1 = + computationMutations.insertComputation( + localId = 1, + creationTime = timestamp1, + updateTime = timestamp1, + globalId = "1", + protocol = FakeProtocol.ZERO, + stage = C, + lockOwner = WRITE_NULL_STRING, + lockExpirationTime = timestamp1, + details = FAKE_COMPUTATION_DETAILS, + ) + val computation1Stage = + computationMutations.insertComputationStage( + localId = 1, + stage = C, + nextAttempt = 1, + creationTime = timestamp1, + details = computationMutations.detailsFor(A, FAKE_COMPUTATION_DETAILS), + ) + val computation2 = + computationMutations.insertComputation( + localId = 2, + protocol = FakeProtocol.ZERO, + stage = E, + creationTime = timestamp2, + updateTime = timestamp2, + globalId = "2", + lockOwner = WRITE_NULL_STRING, + lockExpirationTime = timestamp2, + details = FAKE_COMPUTATION_DETAILS, + ) + val computation2Stage = + computationMutations.insertComputationStage( + localId = 2, + stage = E, + nextAttempt = 1, + creationTime = timestamp2, + details = computationMutations.detailsFor(A, FAKE_COMPUTATION_DETAILS), + ) + // This is the computation with prioritized stage. + val computation3 = + computationMutations.insertComputation( + localId = 3, + creationTime = timestamp3, + updateTime = timestamp3, + globalId = "3", + protocol = FakeProtocol.ZERO, + stage = A, + lockOwner = WRITE_NULL_STRING, + lockExpirationTime = timestamp3, + details = FAKE_COMPUTATION_DETAILS, + ) + val computation3Stage = + computationMutations.insertComputationStage( + localId = 3, + stage = A, + nextAttempt = 1, + creationTime = timestamp3, + details = computationMutations.detailsFor(A, FAKE_COMPUTATION_DETAILS), + ) + databaseClient.write( + listOf( + computation1, + computation1Stage, + computation2, + computation2Stage, + computation3, + computation3Stage, + ) + ) + + assertEquals( + "3", + database.claimTask( + FakeProtocol.ZERO, + "the-owner-of-the-lock", + DEFAULT_LOCK_DURATION, + listOf(A), + ), + ) + assertEquals( + "1", + database.claimTask( + FakeProtocol.ZERO, + "the-owner-of-the-lock", + DEFAULT_LOCK_DURATION, + listOf(A), + ), + ) + assertEquals( + "2", + database.claimTask( + FakeProtocol.ZERO, + "the-owner-of-the-lock", + DEFAULT_LOCK_DURATION, + listOf(A), + ), + ) + } + @Test fun `claim locked tasks`() = runBlocking { testClock.tickSeconds("5_minutes_ago", 60)