Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prioritizedStages for duchy claimTask #1673

Merged
merged 3 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ abstract class MillBase(
) {
abstract val endingStage: ComputationStage

abstract val prioritizedStagesToClaim: List<ComputationStage>

private val meter: Meter = openTelemetry.getMeter(this::class.qualifiedName!!)

protected val cryptoWallClockDurationHistogram: DoubleHistogram =
Expand Down Expand Up @@ -158,6 +160,7 @@ abstract class MillBase(
computationType = [email protected]
owner = millId
lockDuration = workLockDuration.toProtoDuration()
prioritizedStages += prioritizedStagesToClaim
}
val claimWorkResponse =
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ import org.wfanet.measurement.duchy.service.internal.computations.outputPathList
import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader
import org.wfanet.measurement.duchy.toProtocolStage
import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason
import org.wfanet.measurement.internal.duchy.ComputationStage
import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType
import org.wfanet.measurement.internal.duchy.ElGamalPublicKey
import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest
import org.wfanet.measurement.internal.duchy.computationStage
import org.wfanet.measurement.internal.duchy.config.RoleInComputation
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR
Expand Down Expand Up @@ -148,10 +148,9 @@ class ReachFrequencyLiquidLegionsV2Mill(
maximumAttempts,
clock,
) {
override val endingStage: ComputationStage =
ComputationStage.newBuilder()
.apply { liquidLegionsSketchAggregationV2 = Stage.COMPLETE }
.build()
override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage())

private val actions =
mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType
import org.wfanet.measurement.internal.duchy.ElGamalPublicKey
import org.wfanet.measurement.internal.duchy.computationDetails
import org.wfanet.measurement.internal.duchy.computationStage
import org.wfanet.measurement.internal.duchy.config.RoleInComputation
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.AGGREGATOR
import org.wfanet.measurement.internal.duchy.config.RoleInComputation.NON_AGGREGATOR
Expand Down Expand Up @@ -139,6 +140,8 @@ class ReachOnlyLiquidLegionsV2Mill(
) {
override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZATION_PHASE.toProtocolStage())

private val actions =
mapOf(
Pair(Stage.INITIALIZATION_PHASE, AGGREGATOR) to ::initializationPhase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class HonestMajorityShareShuffleMill(

override val endingStage = Stage.COMPLETE.toProtocolStage()

override val prioritizedStagesToClaim = listOf(Stage.INITIALIZED.toProtocolStage())

private val actions =
mapOf(
Pair(Stage.INITIALIZED, FIRST_NON_AGGREGATOR) to ::initializationPhase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ interface ComputationsDatabaseTransactor<ProtocolT, StageT, StageDetailsT, Compu
* @param ownerId The identifier of the worker process that will own the lock.
* @return global computation id of work that was claimed. When null, no work was claimed.
*/
suspend fun claimTask(protocol: ProtocolT, ownerId: String, lockDuration: Duration): String?
suspend fun claimTask(
protocol: ProtocolT,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<StageT> = listOf(),
): String?

/**
* Transitions a computation to a new stage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ private constructor(
protocol: ComputationType,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<ComputationStage>,
): String? {
val claimed =
val editTokens =
tokens.values
.asSequence()
.filter { it.globalComputationId !in claimedComputationIds }
Expand All @@ -385,7 +386,13 @@ private constructor(
globalId = it.globalComputationId,
)
}
.firstOrNull() ?: return null
val prioritizedComputations = editTokens.filter { it.stage in prioritizedStages }
val claimed =
if (!prioritizedComputations.none()) {
prioritizedComputations.first()
} else {
editTokens.firstOrNull() ?: return null
}

updateToken(claimed) { existing ->
claimedComputationIds.add(existing.globalComputationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ class PostgresComputationsService(
try {
ClaimWork(
request.computationType,
request.prioritizedStagesList,
request.owner,
lockDuration,
clock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,24 +349,49 @@ class ComputationReader(
readContext: ReadContext,
protocol: Long,
timestamp: Instant,
prioritizedStages: List<ComputationStage>,
): Flow<UnclaimedTaskQueryResult> {
val listUnclaimedTasksSql =
boundStatement(
"""
val baseSql =
"""
SELECT c.ComputationId, c.GlobalComputationId,
c.Protocol, c.ComputationStage, c.UpdateTime,
c.CreationTime, cs.NextAttempt
FROM Computations AS c
JOIN ComputationStages AS cs
ON c.ComputationId = cs.ComputationId
AND c.ComputationStage = cs.ComputationStage
WHERE c.Protocol = $1
WHERE c.Protocol = ${'$'}1
AND c.LockExpirationTime IS NOT NULL
AND c.LockExpirationTime <= $2
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
AND c.LockExpirationTime <= ${'$'}2
"""
) {

val baseSqlWithOrder =
if (prioritizedStages.isEmpty()) {
baseSql +
"""
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
"""
} else {
// Binding list of String into the IN clause does not work as expected with r2dbc library.
// Hence, manually joining targeting stages into a comma separated string and stub it into
// the
// query.
val stagesString =
prioritizedStages
.map { computationProtocolStagesEnumHelper.computationStageEnumToLongValues(it).stage }
.toList()
.joinToString(",")
baseSql +
"""
ORDER BY CASE WHEN c.ComputationStage IN ($stagesString) THEN 0 ELSE 1 END ASC,
c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50;
"""
}

val listUnclaimedTasksSql =
boundStatement(baseSqlWithOrder) {
bind("$1", protocol)
bind("$2", timestamp)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper
import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader
import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader
import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException
import org.wfanet.measurement.internal.duchy.ComputationStage
import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails
import org.wfanet.measurement.internal.duchy.ComputationToken
import org.wfanet.measurement.internal.duchy.copy
Expand All @@ -48,6 +49,7 @@ import org.wfanet.measurement.internal.duchy.copy
*/
class ClaimWork<ProtocolT, StageT>(
private val protocol: ProtocolT,
private val prioritizedStages: List<ComputationStage>,
private val ownerId: String,
private val lockDuration: Duration,
private val clock: Clock,
Expand All @@ -63,6 +65,7 @@ class ClaimWork<ProtocolT, StageT>(
transactionContext,
protocolEnum,
clock.instant().truncatedTo(ChronoUnit.MICROS),
prioritizedStages,
)
// First the possible tasks to claim are selected from the computations table, then for each
// item in the list we try to claim the lock in a transaction which will only succeed if the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class GcpSpannerComputationsDatabaseTransactor<
protocol: ProtocolT,
ownerId: String,
lockDuration: Duration,
prioritizedStages: List<StageT>,
): String? {
/** Claim a specific task represented by the results of running the above sql. */
suspend fun claimSpecificTask(result: UnclaimedTaskQueryResult<StageT>): Boolean =
Expand All @@ -157,8 +158,11 @@ class GcpSpannerComputationsDatabaseTransactor<
lockDuration,
)
}
val prioritizedStageLongValues =
prioritizedStages.map(computationMutations::computationStageEnumToLongValues).map { it.stage }
return UnclaimedTasksQuery(
computationMutations.protocolEnumToLong(protocol),
prioritizedStageLongValues,
computationMutations::longValuesToComputationStageEnum,
clock.gcloudTimestamp(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ import com.google.cloud.spanner.Statement
import com.google.cloud.spanner.Struct
import org.wfanet.measurement.duchy.db.computation.ComputationStageLongValues
import org.wfanet.measurement.duchy.deploy.gcloud.spanner.common.SqlBasedQuery
import org.wfanet.measurement.gcloud.spanner.appendClause
import org.wfanet.measurement.gcloud.spanner.statement

/** Queries for computations which may be claimed at a timestamp. */
class UnclaimedTasksQuery<StageT>(
val protocol: Long,
protocol: Long,
prioritizedStageLongValues: List<Long>,
val parseStageEnum: (ComputationStageLongValues) -> StageT,
timestamp: Timestamp,
) : SqlBasedQuery<UnclaimedTaskQueryResult<StageT>> {
Expand All @@ -44,18 +47,27 @@ class UnclaimedTasksQuery<StageT>(
AND c.LockExpirationTime IS NOT NULL
AND c.UpdateTime IS NOT NULL
AND c.LockExpirationTime <= @current_time
ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
LIMIT 50
"""
}

override val sql: Statement =
Statement.newBuilder(parameterizedQueryString)
.bind("current_time")
.to(timestamp)
.bind("protocol")
.to(protocol)
.build()
statement(parameterizedQueryString) {
bind("current_time").to(timestamp)
bind("protocol").to(protocol)
if (prioritizedStageLongValues.isEmpty()) {
appendClause("ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC")
} else {
appendClause(
"""
ORDER BY CASE WHEN c.ComputationStage IN UNNEST (@prioritized_stages) THEN 0 ELSE 1 END ASC,
c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC
"""
.trimIndent()
)
bind("prioritized_stages").toInt64Array(prioritizedStageLongValues)
}
appendClause("LIMIT 50")
}

override fun asResult(struct: Struct): UnclaimedTaskQueryResult<StageT> =
UnclaimedTaskQueryResult(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ class ComputationsService(
val lockDuration =
if (request.hasLockDuration()) request.lockDuration.toDuration() else defaultLockDuration
val claimed =
computationsDatabase.claimTask(request.computationType, request.owner, lockDuration)
computationsDatabase.claimTask(
request.computationType,
request.owner,
lockDuration,
request.prioritizedStagesList,
)
return if (claimed != null) {
val token = computationsDatabase.readComputationToken(claimed)!!
sendStatusUpdateToKingdom(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,62 @@ abstract class ComputationsServiceTest<T : ComputationsCoroutineImplBase> {
assertThat(claimWorkResponse).isEqualTo(ClaimWorkResponse.getDefaultInstance())
}

@Test
fun `claimWork returns computation of prioritized stage`() = runBlocking {
var computation1 =
service
.createComputation(
DEFAULT_CREATE_COMPUTATION_REQUEST.copy {
globalComputationId = "1"
requisitions.clear()
}
)
.token

computation1 =
service
.claimWork(
DEFAULT_CLAIM_WORK_REQUEST.copy {
this.lockDuration = Durations.fromSeconds(1)
prioritizedStages += Stage.INITIALIZATION_PHASE.toProtocolStage()
}
)
.token

computation1 =
service
.advanceComputationStage(
advanceComputationStageRequest {
token = computation1
nextComputationStage = Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage()
// To simplify the test, assume WAIT_REQUISITIONS_AND_KEY_SET is claimable.
afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE
}
)
.token

// create a new computation with prioritized stage.
val computation2 =
service
.createComputation(
DEFAULT_CREATE_COMPUTATION_REQUEST.copy {
globalComputationId = "2"
requisitions.clear()
}
)
.token

val response =
service.claimWork(
DEFAULT_CLAIM_WORK_REQUEST.copy {
this.lockDuration = Durations.fromSeconds(1)
prioritizedStages += Stage.INITIALIZATION_PHASE.toProtocolStage()
}
)

assertThat(response.token.globalComputationId).isEqualTo(computation2.globalComputationId)
}

@Test
fun `getComputationIds returns empty response when no matching computations`() = runBlocking {
service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ message ClaimWorkRequest {
// How long the work lock is held for as a result of this claim. If not
// specified, a default value may be chosen by the server.
google.protobuf.Duration lock_duration = 3;

// Stages of Computations that have higher priority to claim.
repeated ComputationStage prioritized_stages = 4;
}
message ClaimWorkResponse {
// The token of the computation being claimed. The response would be empty if
Expand Down
Loading
Loading