diff --git a/yarn-project/prover-client/src/mocks/test_context.ts b/yarn-project/prover-client/src/mocks/test_context.ts index 82d278906c13..42ecc4985c0e 100644 --- a/yarn-project/prover-client/src/mocks/test_context.ts +++ b/yarn-project/prover-client/src/mocks/test_context.ts @@ -20,6 +20,8 @@ import * as fs from 'fs/promises'; import { type MockProxy, mock } from 'jest-mock-extended'; import { ProvingOrchestrator } from '../orchestrator/orchestrator.js'; +import { CircuitProverAgent } from '../prover-pool/circuit-prover-agent.js'; +import { ProverPool } from '../prover-pool/prover-pool.js'; import { type BBProverConfig } from '../prover/bb_prover.js'; import { type CircuitProver } from '../prover/interface.js'; import { TestCircuitProver } from '../prover/test_circuit_prover.js'; @@ -35,6 +37,7 @@ export class TestContext { public globalVariables: GlobalVariables, public actualDb: MerkleTreeOperations, public prover: CircuitProver, + public proverPool: ProverPool, public orchestrator: ProvingOrchestrator, public blockNumber: number, public directoriesToCleanup: string[], @@ -43,6 +46,7 @@ export class TestContext { static async new( logger: DebugLogger, + proverCount = 4, createProver: (bbConfig: BBProverConfig) => Promise = _ => Promise.resolve(new TestCircuitProver(new WASMSimulator())), blockNumber = 3, @@ -82,7 +86,10 @@ export class TestContext { localProver = await createProver(bbConfig); } - const orchestrator = await ProvingOrchestrator.new(actualDb, localProver); + const proverPool = new ProverPool(proverCount, i => new CircuitProverAgent(localProver, 10, `${i}`)); + const orchestrator = new ProvingOrchestrator(actualDb, proverPool.queue); + + await proverPool.start(); return new this( publicExecutor, @@ -93,6 +100,7 @@ export class TestContext { globalVariables, actualDb, localProver, + proverPool, orchestrator, blockNumber, [config?.directoryToCleanup ?? ''], @@ -101,7 +109,7 @@ export class TestContext { } async cleanup() { - await this.orchestrator.stop(); + await this.proverPool.stop(); for (const dir of this.directoriesToCleanup.filter(x => x !== '')) { await fs.rm(dir, { recursive: true, force: true }); } diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator.ts b/yarn-project/prover-client/src/orchestrator/orchestrator.ts index 88ae37c05d07..743d1319d32b 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator.ts @@ -31,16 +31,20 @@ import { } from '@aztec/circuits.js'; import { makeTuple } from '@aztec/foundation/array'; import { padArrayEnd } from '@aztec/foundation/collection'; -import { MemoryFifo } from '@aztec/foundation/fifo'; import { createDebugLogger } from '@aztec/foundation/log'; +import { promiseWithResolvers } from '@aztec/foundation/promise'; import { type Tuple } from '@aztec/foundation/serialize'; -import { sleep } from '@aztec/foundation/sleep'; -import { elapsed } from '@aztec/foundation/timer'; +import { Timer } from '@aztec/foundation/timer'; import { type MerkleTreeOperations } from '@aztec/world-state'; import { inspect } from 'util'; -import { type CircuitProver } from '../prover/index.js'; +import { type ProvingQueue } from '../prover-pool/proving-queue.js'; +import { + type ProvingRequest, + type ProvingRequestPublicInputs, + ProvingRequestType, +} from '../prover-pool/proving-request.js'; import { buildBaseRollupInput, createMergeRollupInputs, @@ -67,71 +71,17 @@ const logger = createDebugLogger('aztec:prover:proving-orchestrator'); * The proving implementation is determined by the provided prover. This could be for example a local prover or a remote prover pool. */ -const SLEEP_TIME = 50; -const MAX_CONCURRENT_JOBS = 64; - -enum PROMISE_RESULT { - SLEEP, - OPERATIONS, -} - const KernelTypesWithoutFunctions: Set = new Set([ PublicKernelType.NON_PUBLIC, PublicKernelType.TAIL, ]); -/** - * Enums and structs to communicate the type of work required in each request. - */ -export enum PROVING_JOB_TYPE { - STATE_UPDATE, - BASE_ROLLUP, - MERGE_ROLLUP, - ROOT_ROLLUP, - BASE_PARITY, - ROOT_PARITY, - PUBLIC_KERNEL, - PUBLIC_VM, -} - -export type ProvingJob = { - type: PROVING_JOB_TYPE; - operation: () => Promise; -}; - /** * The orchestrator, managing the flow of recursive proving operations required to build the rollup proof tree. */ export class ProvingOrchestrator { private provingState: ProvingState | undefined = undefined; - private jobQueue: MemoryFifo = new MemoryFifo(); - private jobProcessPromise?: Promise; - private stopped = false; - constructor( - private db: MerkleTreeOperations, - private prover: CircuitProver, - private maxConcurrentJobs = MAX_CONCURRENT_JOBS, - ) {} - - // Constructs and starts a new orchestrator - public static async new(db: MerkleTreeOperations, prover: CircuitProver) { - const orchestrator = new ProvingOrchestrator(db, prover); - await orchestrator.start(); - return Promise.resolve(orchestrator); - } - - // Starts the proving job queue - public start() { - this.jobProcessPromise = this.processJobQueue(); - return Promise.resolve(); - } - - // Stops the proving job queue - public async stop() { - this.stopped = true; - this.jobQueue.cancel(); - await this.jobProcessPromise; - } + constructor(private db: MerkleTreeOperations, private queue: ProvingQueue) {} /** * Starts off a new block @@ -185,26 +135,28 @@ export class ProvingOrchestrator { // Update the local trees to include the new l1 to l2 messages await this.db.appendLeaves(MerkleTreeId.L1_TO_L2_MESSAGE_TREE, l1ToL2MessagesPadded); - let provingState: ProvingState | undefined = undefined; - - const promise = new Promise((resolve, reject) => { - provingState = new ProvingState( - numTxs, - resolve, - reject, - globalVariables, - l1ToL2MessagesPadded, - baseParityInputs.length, - emptyTx, - messageTreeSnapshot, - newL1ToL2MessageTreeRootSiblingPath, - ); - }).catch((reason: string) => ({ status: PROVING_STATUS.FAILURE, reason } as const)); + const { promise: _promise, resolve, reject } = promiseWithResolvers(); + const promise = _promise.catch( + (reason): ProvingResult => ({ + status: PROVING_STATUS.FAILURE, + reason, + }), + ); + + const provingState = new ProvingState( + numTxs, + resolve, + reject, + globalVariables, + l1ToL2MessagesPadded, + baseParityInputs.length, + emptyTx, + messageTreeSnapshot, + newL1ToL2MessageTreeRootSiblingPath, + ); for (let i = 0; i < baseParityInputs.length; i++) { - this.enqueueJob(provingState, PROVING_JOB_TYPE.BASE_PARITY, () => - this.runBaseParityCircuit(provingState, baseParityInputs[i], i), - ); + this.enqueueBaseParityCircuit(provingState, baseParityInputs[i], i); } this.provingState = provingState; @@ -258,6 +210,7 @@ export class ProvingOrchestrator { * Cancel any further proving of the block */ public cancelBlock() { + this.queue.cancelAll(); this.provingState?.cancel(); } @@ -331,16 +284,14 @@ export class ProvingOrchestrator { if (!numPublicKernels) { // no public functions, go straight to the base rollup logger.debug(`Enqueueing base rollup for tx ${txIndex}`); - this.enqueueJob(provingState, PROVING_JOB_TYPE.BASE_ROLLUP, () => - this.runBaseRollup(provingState, BigInt(txIndex), txProvingState), - ); + this.enqueueBaseRollup(provingState, BigInt(txIndex), txProvingState); return; } // Enqueue all of the VM proving requests // Rather than handle the Kernel Tail as a special case here, we will just handle it inside executeVM for (let i = 0; i < numPublicKernels; i++) { logger.debug(`Enqueueing public VM ${i} for tx ${txIndex}`); - this.enqueueJob(provingState, PROVING_JOB_TYPE.PUBLIC_VM, () => this.executeVM(provingState, txIndex, i)); + this.enqueueVM(provingState, txIndex, i); } } @@ -350,26 +301,89 @@ export class ProvingOrchestrator { * @param jobType - The type of job to be queued * @param job - The actual job, returns a promise notifying of the job's completion */ - private enqueueJob(provingState: ProvingState | undefined, jobType: PROVING_JOB_TYPE, job: () => Promise) { + private enqueueJob( + provingState: ProvingState | undefined, + request: T, + callback: (output: ProvingRequestPublicInputs[T['type']], proof: Proof) => void | Promise, + ) { if (!provingState?.verifyState()) { - logger.debug(`Not enqueueing job, proving state invalid`); + logger.debug(`Not enqueuing job type ${ProvingRequestType[request.type]}, state no longer valid`); return; } // We use a 'safeJob'. We don't want promise rejections in the proving pool, we want to capture the error here // and reject the proving job whilst keeping the event loop free of rejections const safeJob = async () => { try { - await job(); + const timer = new Timer(); + const [publicInputs, proof] = await this.queue.prove(request); + const duration = timer.ms(); + + const inputSize = 'toBuffer' in request.inputs ? request.inputs.toBuffer().length : 0; + const outputSize = 'toBuffer' in publicInputs ? publicInputs.toBuffer().length : 0; + const circuitName = this.getCircuitNameFromRequest(request); + const stats: CircuitSimulationStats | undefined = circuitName + ? { + eventName: 'circuit-simulation', + circuitName, + duration, + inputSize, + outputSize, + } + : undefined; + + logger.debug(`Simulated ${ProvingRequestType[request.type]} circuit duration=${duration}ms`, stats); + + if (!provingState?.verifyState()) { + logger.debug(`State no longer valid, discarding result of job type ${ProvingRequestType[request.type]}`); + return; + } + + await callback(publicInputs, proof); } catch (err) { - logger.error(`Error thrown when proving job type ${PROVING_JOB_TYPE[jobType]}: ${err}`); + logger.error(`Error thrown when proving job type ${ProvingRequestType[request.type]}: ${err}`); provingState!.reject(`${err}`); } }; - const provingJob: ProvingJob = { - type: jobType, - operation: safeJob, - }; - this.jobQueue.put(provingJob); + + // let the callstack unwind before adding the job to the queue + setImmediate(safeJob); + } + + private getCircuitNameFromRequest(request: ProvingRequest): CircuitSimulationStats['circuitName'] | null { + switch (request.type) { + case ProvingRequestType.PUBLIC_VM: + return null; + case ProvingRequestType.PUBLIC_KERNEL_NON_TAIL: + switch (request.kernelType) { + case PublicKernelType.SETUP: + return 'public-kernel-setup'; + case PublicKernelType.APP_LOGIC: + return 'public-kernel-app-logic'; + case PublicKernelType.TEARDOWN: + return 'public-kernel-teardown'; + default: + return null; + } + case ProvingRequestType.PUBLIC_KERNEL_TAIL: + switch (request.kernelType) { + case PublicKernelType.TAIL: + return 'public-kernel-tail'; + default: + return null; + } + case ProvingRequestType.BASE_ROLLUP: + return 'base-rollup'; + case ProvingRequestType.MERGE_ROLLUP: + return 'merge-rollup'; + case ProvingRequestType.ROOT_ROLLUP: + return 'root-rollup'; + case ProvingRequestType.BASE_PARITY: + return 'base-parity'; + case ProvingRequestType.ROOT_PARITY: + return 'root-parity'; + default: + return null; + } } // Updates the merkle trees for a transaction. The first enqueued job for a transaction @@ -414,91 +428,77 @@ export class ProvingOrchestrator { // Executes the base rollup circuit and stored the output as intermediate state for the parent merge/root circuit // Executes the next level of merge if all inputs are available - private async runBaseRollup(provingState: ProvingState | undefined, index: bigint, tx: TxProvingState) { + private enqueueBaseRollup(provingState: ProvingState | undefined, index: bigint, tx: TxProvingState) { + if (!provingState?.verifyState()) { + logger.debug('Not running base rollup, state invalid'); + return; + } if ( !tx.baseRollupInputs.kernelData.publicInputs.end.encryptedLogsHash .toBuffer() .equals(tx.processedTx.encryptedLogs.hash()) ) { - throw new Error( + provingState.reject( `Encrypted logs hash mismatch: ${ tx.baseRollupInputs.kernelData.publicInputs.end.encryptedLogsHash } === ${Fr.fromBuffer(tx.processedTx.encryptedLogs.hash())}`, ); + return; } if ( !tx.baseRollupInputs.kernelData.publicInputs.end.unencryptedLogsHash .toBuffer() .equals(tx.processedTx.unencryptedLogs.hash()) ) { - throw new Error( + provingState.reject( `Unencrypted logs hash mismatch: ${ tx.baseRollupInputs.kernelData.publicInputs.end.unencryptedLogsHash } === ${Fr.fromBuffer(tx.processedTx.unencryptedLogs.hash())}`, ); - } - if (!provingState?.verifyState()) { - logger.debug('Not running base rollup, state invalid'); return; } - const [duration, baseRollupOutputs] = await elapsed(async () => { - const [rollupOutput, proof] = await this.prover.getBaseRollupProof(tx.baseRollupInputs); - validatePartialState(rollupOutput.end, tx.treeSnapshots); - return { rollupOutput, proof }; - }); - logger.debug(`Simulated base rollup circuit`, { - eventName: 'circuit-simulation', - circuitName: 'base-rollup', - duration, - inputSize: tx.baseRollupInputs.toBuffer().length, - outputSize: baseRollupOutputs.rollupOutput.toBuffer().length, - } satisfies CircuitSimulationStats); - if (!provingState?.verifyState()) { - logger.debug(`Discarding job as state no longer valid`); - return; - } - const currentLevel = provingState.numMergeLevels + 1n; - logger.info(`Completed base rollup at index ${index}, current level ${currentLevel}`); - this.storeAndExecuteNextMergeLevel(provingState, currentLevel, index, [ - baseRollupOutputs.rollupOutput, - baseRollupOutputs.proof, - ]); + + this.enqueueJob( + provingState, + { + inputs: tx.baseRollupInputs, + type: ProvingRequestType.BASE_ROLLUP, + }, + (publicInputs, proof) => { + validatePartialState(publicInputs.end, tx.treeSnapshots); + const currentLevel = provingState.numMergeLevels + 1n; + this.storeAndExecuteNextMergeLevel(provingState, currentLevel, index, [publicInputs, proof]); + }, + ); } // Executes the merge rollup circuit and stored the output as intermediate state for the parent merge/root circuit // Enqueues the next level of merge if all inputs are available - private async runMergeRollup( - provingState: ProvingState | undefined, + private enqueueMergeRollup( + provingState: ProvingState, level: bigint, index: bigint, mergeInputData: MergeRollupInputData, ) { - if (!provingState?.verifyState()) { - logger.debug('Not running merge rollup, state invalid'); - return; - } - const circuitInputs = createMergeRollupInputs( + const inputs = createMergeRollupInputs( [mergeInputData.inputs[0]!, mergeInputData.proofs[0]!], [mergeInputData.inputs[1]!, mergeInputData.proofs[1]!], ); - const [duration, circuitOutputs] = await elapsed(() => this.prover.getMergeRollupProof(circuitInputs)); - logger.debug(`Simulated merge rollup circuit`, { - eventName: 'circuit-simulation', - circuitName: 'merge-rollup', - duration, - inputSize: circuitInputs.toBuffer().length, - outputSize: circuitOutputs[0].toBuffer().length, - } satisfies CircuitSimulationStats); - if (!provingState?.verifyState()) { - logger.debug(`Discarding job as state no longer valid`); - return; - } - logger.info(`Completed merge rollup at level ${level}, index ${index}`); - this.storeAndExecuteNextMergeLevel(provingState, level, index, circuitOutputs); + + this.enqueueJob( + provingState, + { + type: ProvingRequestType.MERGE_ROLLUP, + inputs, + }, + (publicInputs, proof) => { + this.storeAndExecuteNextMergeLevel(provingState, level, index, [publicInputs, proof]); + }, + ); } // Executes the root rollup circuit - private async runRootRollup(provingState: ProvingState | undefined) { + private async enqueueRootRollup(provingState: ProvingState | undefined) { if (!provingState?.verifyState()) { logger.debug('Not running root rollup, state no longer valid'); return; @@ -506,7 +506,7 @@ export class ProvingOrchestrator { const mergeInputData = provingState.getMergeInputs(0); const rootParityInput = provingState.finalRootParityInput!; - const rootInput = await getRootRollupInput( + const inputs = await getRootRollupInput( mergeInputData.inputs[0]!, mergeInputData.proofs[0]!, mergeInputData.inputs[1]!, @@ -518,90 +518,67 @@ export class ProvingOrchestrator { this.db, ); - // Simulate and get proof for the root circuit - const [rootOutput, rootProof] = await this.prover.getRootRollupProof(rootInput); - - logger.info(`Completed root rollup`); - - provingState.rootRollupPublicInputs = rootOutput; - provingState.finalProof = rootProof; - - const provingResult: ProvingResult = { - status: PROVING_STATUS.SUCCESS, - }; - provingState.resolve(provingResult); + this.enqueueJob( + provingState, + { + type: ProvingRequestType.ROOT_ROLLUP, + inputs, + }, + (publicInputs, proof) => { + provingState.rootRollupPublicInputs = publicInputs; + provingState.finalProof = proof; + + const provingResult: ProvingResult = { + status: PROVING_STATUS.SUCCESS, + }; + provingState.resolve(provingResult); + }, + ); } // Executes the base parity circuit and stores the intermediate state for the root parity circuit // Enqueues the root parity circuit if all inputs are available - private async runBaseParityCircuit(provingState: ProvingState | undefined, inputs: BaseParityInputs, index: number) { - if (!provingState?.verifyState()) { - logger.debug('Not running base parity, state no longer valid'); - return; - } - const [duration, circuitOutputs] = await elapsed(async () => { - const [parityPublicInputs, proof] = await this.prover.getBaseParityProof(inputs); - return new RootParityInput(proof, parityPublicInputs); - }); - logger.debug(`Simulated base parity circuit`, { - eventName: 'circuit-simulation', - circuitName: 'base-parity', - duration, - inputSize: inputs.toBuffer().length, - outputSize: circuitOutputs.toBuffer().length, - } satisfies CircuitSimulationStats); - - if (!provingState?.verifyState()) { - logger.debug(`Discarding job as state no longer valid`); - return; - } - provingState.setRootParityInputs(circuitOutputs, index); - - if (!provingState.areRootParityInputsReady()) { - // not ready to run the root parity circuit yet - return; - } - const rootParityInputs = new RootParityInputs( - provingState.rootParityInput as Tuple, - ); - this.enqueueJob(provingState, PROVING_JOB_TYPE.ROOT_PARITY, () => - this.runRootParityCircuit(provingState, rootParityInputs), + private enqueueBaseParityCircuit(provingState: ProvingState, inputs: BaseParityInputs, index: number) { + this.enqueueJob( + provingState, + { + inputs, + type: ProvingRequestType.BASE_PARITY, + }, + (publicInputs, proof) => { + const rootInput = new RootParityInput(proof, publicInputs); + provingState.setRootParityInputs(rootInput, index); + const rootParityInputs = new RootParityInputs( + provingState.rootParityInput as Tuple, + ); + this.enqueueRootParityCircuit(provingState, rootParityInputs); + }, ); } // Runs the root parity circuit ans stored the outputs // Enqueues the root rollup proof if all inputs are available - private async runRootParityCircuit(provingState: ProvingState | undefined, inputs: RootParityInputs) { - if (!provingState?.verifyState()) { - logger.debug(`Not running root parity circuit as state is no longer valid`); - return; - } - const [duration, circuitOutputs] = await elapsed(async () => { - const [parityPublicInputs, proof] = await this.prover.getRootParityProof(inputs); - return new RootParityInput(proof, parityPublicInputs); - }); - logger.debug(`Simulated root parity circuit`, { - eventName: 'circuit-simulation', - circuitName: 'root-parity', - duration, - inputSize: inputs.toBuffer().length, - outputSize: circuitOutputs.toBuffer().length, - } satisfies CircuitSimulationStats); - - if (!provingState?.verifyState()) { - logger.debug(`Discarding job as state no longer valid`); - return; - } - provingState!.finalRootParityInput = circuitOutputs; - this.checkAndExecuteRootRollup(provingState); + private enqueueRootParityCircuit(provingState: ProvingState | undefined, inputs: RootParityInputs) { + this.enqueueJob( + provingState, + { + type: ProvingRequestType.ROOT_PARITY, + inputs, + }, + async (publicInputs, proof) => { + const rootInput = new RootParityInput(proof, publicInputs); + provingState!.finalRootParityInput = rootInput; + await this.checkAndEnqueueRootRollup(provingState); + }, + ); } - private checkAndExecuteRootRollup(provingState: ProvingState | undefined) { + private async checkAndEnqueueRootRollup(provingState: ProvingState | undefined) { if (!provingState?.isReadyForRootRollup()) { logger.debug('Not ready for root rollup'); return; } - this.enqueueJob(provingState, PROVING_JOB_TYPE.ROOT_ROLLUP, () => this.runRootRollup(provingState)); + await this.enqueueRootRollup(provingState); } /** @@ -625,12 +602,11 @@ export class ProvingOrchestrator { } if (result.mergeLevel === 0n) { - this.checkAndExecuteRootRollup(provingState); + // TODO (alexg) remove this `void` + void this.checkAndEnqueueRootRollup(provingState); } else { // onto the next merge level - this.enqueueJob(provingState, PROVING_JOB_TYPE.MERGE_ROLLUP, () => - this.runMergeRollup(provingState, result.mergeLevel, result.indexWithinMergeLevel, result.mergeInputData), - ); + this.enqueueMergeRollup(provingState, result.mergeLevel, result.indexWithinMergeLevel, result.mergeInputData); } } @@ -641,7 +617,7 @@ export class ProvingOrchestrator { * @param txIndex - The index of the transaction being proven * @param functionIndex - The index of the function/kernel being proven */ - private async executeVM(provingState: ProvingState | undefined, txIndex: number, functionIndex: number) { + private enqueueVM(provingState: ProvingState | undefined, txIndex: number, functionIndex: number) { if (!provingState?.verifyState()) { logger.debug(`Not running VM circuit as state is no longer valid`); return; @@ -653,15 +629,24 @@ export class ProvingOrchestrator { // Prove the VM if this is a kernel that requires one if (!KernelTypesWithoutFunctions.has(publicFunction.publicKernelRequest.type)) { // Just sleep for a small amount of time - await sleep(Math.random() * 10 + 10); - logger.debug(`Proven VM for function index ${functionIndex} of tx index ${txIndex}`); - } - - if (!provingState?.verifyState()) { - logger.debug(`Not continuing after VM circuit as state is no longer valid`); - return; + this.enqueueJob( + provingState, + { + type: ProvingRequestType.PUBLIC_VM, + inputs: {}, + }, + (_1, _2) => { + logger.debug(`Proven VM for function index ${functionIndex} of tx index ${txIndex}`); + this.checkAndEnqueuePublicKernel(provingState, txIndex, functionIndex); + }, + ); + } else { + this.checkAndEnqueuePublicKernel(provingState, txIndex, functionIndex); } + } + private checkAndEnqueuePublicKernel(provingState: ProvingState, txIndex: number, functionIndex: number) { + const txProvingState = provingState.getTxProvingState(txIndex); const kernelRequest = txProvingState.getNextPublicKernelFromVMProof(functionIndex, makeEmptyProof()); if (kernelRequest.code === TX_PROVING_CODE.READY) { if (kernelRequest.function === undefined) { @@ -669,9 +654,7 @@ export class ProvingOrchestrator { throw new Error(`Error occurred, public function request undefined after VM proof completed`); } logger.debug(`Enqueuing kernel from VM for tx ${txIndex}, function ${functionIndex}`); - this.enqueueJob(provingState, PROVING_JOB_TYPE.PUBLIC_KERNEL, () => - this.executePublicKernel(provingState, txIndex, functionIndex), - ); + this.enqueuePublicKernel(provingState, txIndex, functionIndex); } } @@ -682,115 +665,37 @@ export class ProvingOrchestrator { * @param txIndex - The index of the transaction being proven * @param functionIndex - The index of the function/kernel being proven */ - private async executePublicKernel(provingState: ProvingState | undefined, txIndex: number, functionIndex: number) { + private enqueuePublicKernel(provingState: ProvingState | undefined, txIndex: number, functionIndex: number) { if (!provingState?.verifyState()) { logger.debug(`Not running public kernel circuit as state is no longer valid`); return; } const txProvingState = provingState.getTxProvingState(txIndex); - const kernelRequest = txProvingState.getPublicFunctionState(functionIndex).publicKernelRequest; - - // We may need to use the public inputs produced here instead of those coming from the sequencer - const [_, proof] = - kernelRequest.type == PublicKernelType.TAIL - ? await this.prover.getPublicTailProof(kernelRequest) - : await this.prover.getPublicKernelProof(kernelRequest); - - if (!provingState?.verifyState()) { - logger.debug(`Not continuing after public kernel circuit as state is no longer valid`); - return; - } - - logger.debug(`Proven ${PublicKernelType[kernelRequest.type]} at index ${functionIndex} for tx index ${txIndex}`); - - const nextKernelRequest = txProvingState.getNextPublicKernelFromKernelProof(functionIndex, proof); - // What's the status of the next kernel? - if (nextKernelRequest.code === TX_PROVING_CODE.NOT_READY) { - // Must be waiting on a VM proof - return; - } - if (nextKernelRequest.code === TX_PROVING_CODE.COMPLETED) { - // We must have completed all public function proving, we now move to the base rollup - logger.debug(`Public functions completed for tx ${txIndex} enqueueing base rollup`); - this.enqueueJob(provingState, PROVING_JOB_TYPE.BASE_ROLLUP, () => - this.runBaseRollup(provingState, BigInt(txIndex), txProvingState), - ); - return; - } - // There must be another kernel ready to be proven - if (nextKernelRequest.function === undefined) { - // Should not be possible - throw new Error(`Error occurred, public function request undefined after kernel proof completed`); - } - logger.debug(`Enqueuing kernel from kernel for tx ${txIndex}, function ${functionIndex + 1}`); - this.enqueueJob(provingState, PROVING_JOB_TYPE.PUBLIC_KERNEL, () => - this.executePublicKernel(provingState, txIndex, functionIndex + 1), - ); - } - - /** - * Process the job queue - * Works by managing an input queue of proof requests and an active pool of proving 'jobs' - */ - private async processJobQueue() { - // Used for determining the current state of a proving job - const promiseState = (p: Promise) => { - const t = {}; - return Promise.race([p, t]).then( - v => (v === t ? 'pending' : 'fulfilled'), - () => 'rejected', - ); - }; - - // Just a short break between managing the sets of requests and active jobs - const createSleepPromise = () => - sleep(SLEEP_TIME).then(_ => { - return PROMISE_RESULT.SLEEP; - }); - - let sleepPromise = createSleepPromise(); - let promises: Promise[] = []; - while (!this.stopped) { - // first look for more work - if (this.jobQueue.length() && promises.length < this.maxConcurrentJobs) { - // more work could be available - const job = await this.jobQueue.get(); - if (job !== null) { - // a proving job, add it to the pool of outstanding jobs - promises.push(job.operation()); - } - // continue adding more work - continue; + const provingRequest = txProvingState.getPublicFunctionState(functionIndex).provingRequest; + + this.enqueueJob(provingState, provingRequest, (_, proof) => { + logger.debug(`Proven ${PublicKernelType[provingRequest.type]} at index ${functionIndex} for tx index ${txIndex}`); + const nextKernelRequest = txProvingState.getNextPublicKernelFromKernelProof(functionIndex, proof); + // What's the status of the next kernel? + if (nextKernelRequest.code === TX_PROVING_CODE.NOT_READY) { + // Must be waiting on a VM proof + return; } - // no more work to add, here we wait for any outstanding jobs to finish and/or sleep a little - try { - const ops = Promise.race(promises).then(_ => { - return PROMISE_RESULT.OPERATIONS; - }); - const result = await Promise.race([sleepPromise, ops]); - if (result === PROMISE_RESULT.SLEEP) { - // this is the sleep promise - // we simply setup the promise again and go round the loop checking for more work - sleepPromise = createSleepPromise(); - continue; - } - } catch (err) { - // We shouldn't get here as all jobs should be wrapped in a 'safeJob' meaning they don't fail! - logger.error(`Unexpected error in proving orchestrator ${err}`); + if (nextKernelRequest.code === TX_PROVING_CODE.COMPLETED) { + // We must have completed all public function proving, we now move to the base rollup + logger.debug(`Public functions completed for tx ${txIndex} enqueueing base rollup`); + this.enqueueBaseRollup(provingState, BigInt(txIndex), txProvingState); + return; } - - // one or more of the jobs completed, remove them - const pendingPromises = []; - for (const jobPromise of promises) { - const state = await promiseState(jobPromise); - if (state === 'pending') { - pendingPromises.push(jobPromise); - } + // There must be another kernel ready to be proven + if (nextKernelRequest.function === undefined) { + // Should not be possible + throw new Error(`Error occurred, public function request undefined after kernel proof completed`); } - // eslint-disable-next-line @typescript-eslint/no-floating-promises - promises = pendingPromises; - } + + this.enqueuePublicKernel(provingState, txIndex, functionIndex + 1); + }); } } diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_errors.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_errors.test.ts index 42bb5a177d01..8f23bd02b3b8 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_errors.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_errors.test.ts @@ -2,13 +2,9 @@ import { PROVING_STATUS } from '@aztec/circuit-types'; import { Fr } from '@aztec/circuits.js'; import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeBloatedProcessedTx, makeEmptyProcessedTestTx } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-errors'); describe('prover/orchestrator/errors', () => { @@ -53,7 +49,7 @@ describe('prover/orchestrator/errors', () => { const finalisedBlock = await context.orchestrator.finaliseBlock(); expect(finalisedBlock.block.number).toEqual(context.blockNumber); - }, 30_000); + }, 40_000); it('throws if adding a transaction before start', async () => { await expect( diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_failures.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_failures.test.ts index 1b9b0e28614e..d8b7f1e6c690 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_failures.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_failures.test.ts @@ -3,21 +3,21 @@ import { createDebugLogger } from '@aztec/foundation/log'; import { WASMSimulator } from '@aztec/simulator'; import { jest } from '@jest/globals'; -import { type MemDown, default as memdown } from 'memdown'; import { makeEmptyProcessedTestTx } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; +import { CircuitProverAgent } from '../prover-pool/circuit-prover-agent.js'; +import { ProverPool } from '../prover-pool/prover-pool.js'; import { type CircuitProver } from '../prover/index.js'; import { TestCircuitProver } from '../prover/test_circuit_prover.js'; import { ProvingOrchestrator } from './orchestrator.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-failures'); describe('prover/orchestrator/failures', () => { let context: TestContext; let orchestrator: ProvingOrchestrator; + let proverPool: ProverPool; beforeEach(async () => { context = await TestContext.new(logger); @@ -32,11 +32,13 @@ describe('prover/orchestrator/failures', () => { beforeEach(async () => { mockProver = new TestCircuitProver(new WASMSimulator()); - orchestrator = await ProvingOrchestrator.new(context.actualDb, mockProver); + proverPool = new ProverPool(1, i => new CircuitProverAgent(mockProver, 10, `${i}`)); + orchestrator = new ProvingOrchestrator(context.actualDb, proverPool.queue); + await proverPool.start(); }); afterEach(async () => { - await orchestrator.stop(); + await proverPool.stop(); }); it.each([ diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts index bfffd42382fb..898e3aab9bb5 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_lifecycle.test.ts @@ -4,13 +4,9 @@ import { fr } from '@aztec/circuits.js/testing'; import { range } from '@aztec/foundation/array'; import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, makeGlobals } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-lifecycle'); describe('prover/orchestrator/lifecycle', () => { @@ -79,7 +75,7 @@ describe('prover/orchestrator/lifecycle', () => { const finalisedBlock = await context.orchestrator.finaliseBlock(); expect(finalisedBlock.block.number).toEqual(101); - }, 20000); + }, 40000); it('automatically cancels an incomplete block when starting a new one', async () => { const txs1 = await Promise.all([ diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks.test.ts index 7191c868466e..53a6959b710e 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks.test.ts @@ -4,13 +4,9 @@ import { fr } from '@aztec/circuits.js/testing'; import { range } from '@aztec/foundation/array'; import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeBloatedProcessedTx, makeEmptyProcessedTestTx } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-mixed-blocks'); describe('prover/orchestrator/mixed-blocks', () => { diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks_2.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks_2.test.ts index 63a32b713a7d..1233bbe30f72 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks_2.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_mixed_blocks_2.test.ts @@ -7,13 +7,9 @@ import { createDebugLogger } from '@aztec/foundation/log'; import { openTmpStore } from '@aztec/kv-store/utils'; import { type MerkleTreeOperations, MerkleTrees } from '@aztec/world-state'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, updateExpectedTreesFromTxs } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-mixed-blocks-2'); describe('prover/orchestrator/mixed-blocks', () => { diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts index e1c18cf3dacb..6e1054032b11 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_multi_public_functions.test.ts @@ -2,13 +2,9 @@ import { PROVING_STATUS, mockTx } from '@aztec/circuit-types'; import { times } from '@aztec/foundation/collection'; import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeEmptyProcessedTestTx } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-multi-public-functions'); describe('prover/orchestrator/public-functions', () => { diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_multiple_blocks.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_multiple_blocks.test.ts index a84fa8079294..1525aeff7409 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_multiple_blocks.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_multiple_blocks.test.ts @@ -1,13 +1,9 @@ import { PROVING_STATUS } from '@aztec/circuit-types'; import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, makeGlobals } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-multi-blocks'); describe('prover/orchestrator/multi-block', () => { diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_public_functions.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_public_functions.test.ts index 9409ba9d41e4..2f7679b973bb 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_public_functions.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_public_functions.test.ts @@ -1,13 +1,9 @@ import { PROVING_STATUS, mockTx } from '@aztec/circuit-types'; import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeEmptyProcessedTestTx } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-public-functions'); describe('prover/orchestrator/public-functions', () => { diff --git a/yarn-project/prover-client/src/orchestrator/orchestrator_single_blocks.test.ts b/yarn-project/prover-client/src/orchestrator/orchestrator_single_blocks.test.ts index 96f8f64b3df6..959a6105c6a1 100644 --- a/yarn-project/prover-client/src/orchestrator/orchestrator_single_blocks.test.ts +++ b/yarn-project/prover-client/src/orchestrator/orchestrator_single_blocks.test.ts @@ -7,13 +7,9 @@ import { sleep } from '@aztec/foundation/sleep'; import { openTmpStore } from '@aztec/kv-store/utils'; import { type MerkleTreeOperations, MerkleTrees } from '@aztec/world-state'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, updateExpectedTreesFromTxs } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:orchestrator-single-blocks'); describe('prover/orchestrator/blocks', () => { diff --git a/yarn-project/prover-client/src/orchestrator/tx-proving-state.ts b/yarn-project/prover-client/src/orchestrator/tx-proving-state.ts index ab8802d2fccc..73339d00d6d5 100644 --- a/yarn-project/prover-client/src/orchestrator/tx-proving-state.ts +++ b/yarn-project/prover-client/src/orchestrator/tx-proving-state.ts @@ -1,6 +1,8 @@ import { type MerkleTreeId, type ProcessedTx, type PublicKernelRequest, PublicKernelType } from '@aztec/circuit-types'; import { type AppendOnlyTreeSnapshot, type BaseRollupInputs, type Proof } from '@aztec/circuits.js'; +import { type ProvingRequest, ProvingRequestType } from '../prover-pool/proving-request.js'; + export enum TX_PROVING_CODE { NOT_READY, READY, @@ -12,6 +14,7 @@ export type PublicFunction = { previousProofType: PublicKernelType; previousKernelProof: Proof | undefined; publicKernelRequest: PublicKernelRequest; + provingRequest: ProvingRequest; }; // Type encapsulating the instruction to the orchestrator as to what @@ -37,11 +40,24 @@ export class TxProvingState { let previousKernelProof: Proof | undefined = processedTx.proof; let previousProofType = PublicKernelType.NON_PUBLIC; for (const kernelRequest of processedTx.publicKernelRequests) { + const provingRequest: ProvingRequest = + kernelRequest.type === PublicKernelType.TAIL + ? { + type: ProvingRequestType.PUBLIC_KERNEL_TAIL, + kernelType: kernelRequest.type, + inputs: kernelRequest.inputs, + } + : { + type: ProvingRequestType.PUBLIC_KERNEL_NON_TAIL, + kernelType: kernelRequest.type, + inputs: kernelRequest.inputs, + }; const publicFunction: PublicFunction = { vmProof: undefined, previousProofType, previousKernelProof, publicKernelRequest: kernelRequest, + provingRequest, }; this.publicFunctions.push(publicFunction); previousKernelProof = undefined; diff --git a/yarn-project/prover-client/src/prover-pool/circuit-prover-agent.test.ts b/yarn-project/prover-client/src/prover-pool/circuit-prover-agent.test.ts new file mode 100644 index 000000000000..a5e8b224cc9d --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/circuit-prover-agent.test.ts @@ -0,0 +1,85 @@ +import { makeBaseParityInputs, makeParityPublicInputs, makeProof } from '@aztec/circuits.js/testing'; + +import { type MockProxy, mock } from 'jest-mock-extended'; + +import { type CircuitProver } from '../prover/interface.js'; +import { CircuitProverAgent } from './circuit-prover-agent.js'; +import { MemoryProvingQueue } from './memory-proving-queue.js'; +import { type ProvingAgent } from './prover-agent.js'; +import { type ProvingQueue } from './proving-queue.js'; +import { ProvingRequestType } from './proving-request.js'; + +describe('LocalProvingAgent', () => { + let queue: ProvingQueue; + let agent: ProvingAgent; + let prover: MockProxy; + + beforeEach(() => { + prover = mock(); + queue = new MemoryProvingQueue(); + agent = new CircuitProverAgent(prover); + }); + + beforeEach(() => { + agent.start(queue); + }); + + afterEach(async () => { + await agent.stop(); + }); + + it('takes jobs from the queue', async () => { + const publicInputs = makeParityPublicInputs(); + const proof = makeProof(); + prover.getBaseParityProof.mockResolvedValue([publicInputs, proof]); + + const inputs = makeBaseParityInputs(); + const promise = queue.prove({ + type: ProvingRequestType.BASE_PARITY, + inputs, + }); + + await expect(promise).resolves.toEqual([publicInputs, proof]); + expect(prover.getBaseParityProof).toHaveBeenCalledWith(inputs); + }); + + it('reports errors', async () => { + const error = new Error('test error'); + prover.getBaseParityProof.mockRejectedValue(error); + + const inputs = makeBaseParityInputs(); + const promise = queue.prove({ + type: ProvingRequestType.BASE_PARITY, + inputs, + }); + + await expect(promise).rejects.toEqual(error); + expect(prover.getBaseParityProof).toHaveBeenCalledWith(inputs); + }); + + it('continues to process jobs', async () => { + const publicInputs = makeParityPublicInputs(); + const proof = makeProof(); + prover.getBaseParityProof.mockResolvedValue([publicInputs, proof]); + + const inputs = makeBaseParityInputs(); + const promise1 = queue.prove({ + type: ProvingRequestType.BASE_PARITY, + inputs, + }); + + await expect(promise1).resolves.toEqual([publicInputs, proof]); + + const inputs2 = makeBaseParityInputs(); + const promise2 = queue.prove({ + type: ProvingRequestType.BASE_PARITY, + inputs: inputs2, + }); + + await expect(promise2).resolves.toEqual([publicInputs, proof]); + + expect(prover.getBaseParityProof).toHaveBeenCalledTimes(2); + expect(prover.getBaseParityProof).toHaveBeenCalledWith(inputs); + expect(prover.getBaseParityProof).toHaveBeenCalledWith(inputs2); + }); +}); diff --git a/yarn-project/prover-client/src/prover-pool/circuit-prover-agent.ts b/yarn-project/prover-client/src/prover-pool/circuit-prover-agent.ts new file mode 100644 index 000000000000..76553512eb24 --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/circuit-prover-agent.ts @@ -0,0 +1,107 @@ +import { makeEmptyProof } from '@aztec/circuits.js'; +import { createDebugLogger } from '@aztec/foundation/log'; +import { RunningPromise } from '@aztec/foundation/running-promise'; +import { elapsed } from '@aztec/foundation/timer'; + +import { type CircuitProver } from '../prover/interface.js'; +import { type ProvingAgent } from './prover-agent.js'; +import { type ProvingQueueConsumer } from './proving-queue.js'; +import { type ProvingRequest, type ProvingRequestResult, ProvingRequestType } from './proving-request.js'; + +export class CircuitProverAgent implements ProvingAgent { + private runningPromise?: RunningPromise; + + constructor( + /** The prover implementation to defer jobs to */ + private prover: CircuitProver, + /** How long to wait between jobs */ + private intervalMs = 10, + /** A name for this agent (if there are multiple agents running) */ + name = '', + private log = createDebugLogger('aztec:prover-client:prover-pool:agent' + (name ? `:${name}` : '')), + ) {} + + start(queue: ProvingQueueConsumer): void { + if (this.runningPromise) { + throw new Error('Agent is already running'); + } + + this.runningPromise = new RunningPromise(async () => { + const job = await queue.getProvingJob(); + if (!job) { + return; + } + + try { + const [time, result] = await elapsed(() => this.work(job.request)); + await queue.resolveProvingJob(job.id, result); + this.log.info( + `Processed proving job id=${job.id} type=${ProvingRequestType[job.request.type]} duration=${time}ms`, + ); + } catch (err) { + this.log.error( + `Error processing proving job id=${job.id} type=${ProvingRequestType[job.request.type]}: ${err}`, + ); + await queue.rejectProvingJob(job.id, err as Error); + } + }, this.intervalMs); + + this.runningPromise.start(); + } + + async stop(): Promise { + if (!this.runningPromise) { + throw new Error('Agent is not running'); + } + + await this.runningPromise.stop(); + this.runningPromise = undefined; + } + + private work(request: ProvingRequest): Promise> { + const { type, inputs } = request; + switch (type) { + case ProvingRequestType.PUBLIC_VM: { + return Promise.resolve([{}, makeEmptyProof()] as const); + } + + case ProvingRequestType.PUBLIC_KERNEL_NON_TAIL: { + return this.prover.getPublicKernelProof({ + type: request.kernelType, + inputs, + }); + } + + case ProvingRequestType.PUBLIC_KERNEL_TAIL: { + return this.prover.getPublicTailProof({ + type: request.kernelType, + inputs, + }); + } + + case ProvingRequestType.BASE_ROLLUP: { + return this.prover.getBaseRollupProof(inputs); + } + + case ProvingRequestType.MERGE_ROLLUP: { + return this.prover.getMergeRollupProof(inputs); + } + + case ProvingRequestType.ROOT_ROLLUP: { + return this.prover.getRootRollupProof(inputs); + } + + case ProvingRequestType.BASE_PARITY: { + return this.prover.getBaseParityProof(inputs); + } + + case ProvingRequestType.ROOT_PARITY: { + return this.prover.getRootParityProof(inputs); + } + + default: { + return Promise.reject(new Error(`Invalid proof request type: ${type}`)); + } + } + } +} diff --git a/yarn-project/prover-client/src/prover-pool/memory-proving-queue.test.ts b/yarn-project/prover-client/src/prover-pool/memory-proving-queue.test.ts new file mode 100644 index 000000000000..cea156acfc65 --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/memory-proving-queue.test.ts @@ -0,0 +1,70 @@ +import { + makeBaseParityInputs, + makeBaseRollupInputs, + makeParityPublicInputs, + makeProof, +} from '@aztec/circuits.js/testing'; + +import { MemoryProvingQueue } from './memory-proving-queue.js'; +import { type ProvingQueue } from './proving-queue.js'; +import { ProvingRequestType } from './proving-request.js'; + +describe('MemoryProvingQueue', () => { + let queue: ProvingQueue; + + beforeEach(() => { + queue = new MemoryProvingQueue(); + }); + + it('returns jobs in order', async () => { + void queue.prove({ + type: ProvingRequestType.BASE_PARITY, + inputs: makeBaseParityInputs(), + }); + + void queue.prove({ + type: ProvingRequestType.BASE_ROLLUP, + inputs: makeBaseRollupInputs(), + }); + + const job1 = await queue.getProvingJob(); + expect(job1?.request.type).toEqual(ProvingRequestType.BASE_PARITY); + + const job2 = await queue.getProvingJob(); + expect(job2?.request.type).toEqual(ProvingRequestType.BASE_ROLLUP); + }); + + it('returns null when no jobs are available', async () => { + await expect(queue.getProvingJob({ timeoutSec: 0 })).resolves.toBeNull(); + }); + + it('notifies of completion', async () => { + const inputs = makeBaseParityInputs(); + const promise = queue.prove({ + inputs, + type: ProvingRequestType.BASE_PARITY, + }); + + const job = await queue.getProvingJob(); + expect(job?.request.inputs).toEqual(inputs); + + const publicInputs = makeParityPublicInputs(); + const proof = makeProof(); + await queue.resolveProvingJob(job!.id, [publicInputs, proof]); + await expect(promise).resolves.toEqual([publicInputs, proof]); + }); + + it('notifies of errors', async () => { + const inputs = makeBaseParityInputs(); + const promise = queue.prove({ + inputs, + type: ProvingRequestType.BASE_PARITY, + }); + const job = await queue.getProvingJob(); + expect(job?.request.inputs).toEqual(inputs); + + const error = new Error('test error'); + await queue.rejectProvingJob(job!.id, error); + await expect(promise).rejects.toEqual(error); + }); +}); diff --git a/yarn-project/prover-client/src/prover-pool/memory-proving-queue.ts b/yarn-project/prover-client/src/prover-pool/memory-proving-queue.ts new file mode 100644 index 000000000000..155b548ff2e1 --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/memory-proving-queue.ts @@ -0,0 +1,86 @@ +import { TimeoutError } from '@aztec/foundation/error'; +import { MemoryFifo } from '@aztec/foundation/fifo'; +import { createDebugLogger } from '@aztec/foundation/log'; +import { type PromiseWithResolvers, promiseWithResolvers } from '@aztec/foundation/promise'; + +import { type ProvingJob, type ProvingQueue } from './proving-queue.js'; +import { type ProvingRequest, type ProvingRequestResult, ProvingRequestType } from './proving-request.js'; + +type ProvingJobWithResolvers = { + id: string; + request: T; +} & PromiseWithResolvers>; + +export class MemoryProvingQueue implements ProvingQueue { + private jobId = 0; + private log = createDebugLogger('aztec:prover-client:prover-pool:queue'); + private queue = new MemoryFifo(); + private jobsInProgress = new Map(); + + async getProvingJob({ timeoutSec = 1 } = {}): Promise | null> { + try { + const job = await this.queue.get(timeoutSec); + if (!job) { + return null; + } + + this.jobsInProgress.set(job.id, job); + return { + id: job.id, + request: job.request, + }; + } catch (err) { + if (err instanceof TimeoutError) { + return null; + } + + throw err; + } + } + + resolveProvingJob(jobId: string, result: ProvingRequestResult): Promise { + const job = this.jobsInProgress.get(jobId); + if (!job) { + return Promise.reject(new Error('Job not found')); + } + + this.jobsInProgress.delete(jobId); + job.resolve(result); + return Promise.resolve(); + } + + rejectProvingJob(jobId: string, err: any): Promise { + const job = this.jobsInProgress.get(jobId); + if (!job) { + return Promise.reject(new Error('Job not found')); + } + + this.jobsInProgress.delete(jobId); + job.reject(err); + return Promise.resolve(); + } + + prove(request: T): Promise> { + const { promise, resolve, reject } = promiseWithResolvers>(); + const item: ProvingJobWithResolvers = { + id: String(this.jobId++), + request, + promise, + resolve, + reject, + }; + + this.log.info(`Adding ${ProvingRequestType[request.type]} proving job to queue`); + // TODO (alexg) remove the `any` + if (!this.queue.put(item as any)) { + throw new Error(); + } + + return promise; + } + + cancelAll(): void { + this.queue.cancel(); + this.queue = new MemoryFifo(); + } +} diff --git a/yarn-project/prover-client/src/prover-pool/prover-agent.ts b/yarn-project/prover-client/src/prover-pool/prover-agent.ts new file mode 100644 index 000000000000..6d408e3a0743 --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/prover-agent.ts @@ -0,0 +1,15 @@ +import { type ProvingQueueConsumer } from './proving-queue.js'; + +/** An agent that reads proving jobs from the queue, creates the proof and submits back the result */ +export interface ProvingAgent { + /** + * Starts the agent to read proving jobs from the queue. + * @param queue - The queue to read proving jobs from. + */ + start(queue: ProvingQueueConsumer): void; + + /** + * Stops the agent. Does nothing if the agent is not running. + */ + stop(): Promise; +} diff --git a/yarn-project/prover-client/src/prover-pool/prover-pool.ts b/yarn-project/prover-client/src/prover-pool/prover-pool.ts new file mode 100644 index 000000000000..defeed61ce7e --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/prover-pool.ts @@ -0,0 +1,47 @@ +import { MemoryProvingQueue } from './memory-proving-queue.js'; +import { type ProvingAgent } from './prover-agent.js'; +import { type ProvingQueue } from './proving-queue.js'; + +/** + * Utility class that spawns N prover agents all connected to the same queue + */ +export class ProverPool { + private agents: ProvingAgent[] = []; + private running = false; + + constructor( + private size: number, + private agentFactory: (i: number) => ProvingAgent | Promise, + public readonly queue: ProvingQueue = new MemoryProvingQueue(), + ) {} + + async start(): Promise { + if (this.running) { + throw new Error('Prover pool is already running'); + } + + // lock the pool state here since creating agents is async + this.running = true; + + // handle start, stop, start cycles by only creating as many agents as were requested + for (let i = this.agents.length; i < this.size; i++) { + this.agents.push(await this.agentFactory(i)); + } + + for (const agent of this.agents) { + agent.start(this.queue); + } + } + + async stop(): Promise { + if (!this.running) { + throw new Error('Prover pool is not running'); + } + + for (const agent of this.agents) { + await agent.stop(); + } + + this.running = false; + } +} diff --git a/yarn-project/prover-client/src/prover-pool/proving-queue.ts b/yarn-project/prover-client/src/prover-pool/proving-queue.ts new file mode 100644 index 000000000000..3ab8b0153453 --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/proving-queue.ts @@ -0,0 +1,23 @@ +import type { ProvingRequest, ProvingRequestResult, ProvingRequestType } from './proving-request.js'; + +export type GetJobOptions = { + timeoutSec?: number; +}; + +export type ProvingJob = { + id: string; + request: T; +}; + +export interface ProvingRequestProducer { + prove(request: T): Promise>; + cancelAll(): void; +} + +export interface ProvingQueueConsumer { + getProvingJob(options?: GetJobOptions): Promise | null>; + resolveProvingJob(jobId: string, result: ProvingRequestResult): Promise; + rejectProvingJob(jobId: string, reason: Error): Promise; +} + +export interface ProvingQueue extends ProvingQueueConsumer, ProvingRequestProducer {} diff --git a/yarn-project/prover-client/src/prover-pool/proving-request.ts b/yarn-project/prover-client/src/prover-pool/proving-request.ts new file mode 100644 index 000000000000..ef98e7e18cd2 --- /dev/null +++ b/yarn-project/prover-client/src/prover-pool/proving-request.ts @@ -0,0 +1,81 @@ +import { type PublicKernelNonTailRequest, type PublicKernelTailRequest } from '@aztec/circuit-types'; +import { + type BaseOrMergeRollupPublicInputs, + type BaseParityInputs, + type BaseRollupInputs, + type KernelCircuitPublicInputs, + type MergeRollupInputs, + type ParityPublicInputs, + type Proof, + type PublicKernelCircuitPublicInputs, + type RootParityInputs, + type RootRollupInputs, + type RootRollupPublicInputs, +} from '@aztec/circuits.js'; + +export enum ProvingRequestType { + PUBLIC_VM, + + PUBLIC_KERNEL_NON_TAIL, + PUBLIC_KERNEL_TAIL, + + BASE_ROLLUP, + MERGE_ROLLUP, + ROOT_ROLLUP, + + BASE_PARITY, + ROOT_PARITY, +} + +export type ProvingRequest = + | { + type: ProvingRequestType.PUBLIC_VM; + // prefer object over unknown so that we can run "in" checks, e.g. `'toBuffer' in request.inputs` + inputs: object; + } + | { + type: ProvingRequestType.PUBLIC_KERNEL_NON_TAIL; + kernelType: PublicKernelNonTailRequest['type']; + inputs: PublicKernelNonTailRequest['inputs']; + } + | { + type: ProvingRequestType.PUBLIC_KERNEL_TAIL; + kernelType: PublicKernelTailRequest['type']; + inputs: PublicKernelTailRequest['inputs']; + } + | { + type: ProvingRequestType.BASE_PARITY; + inputs: BaseParityInputs; + } + | { + type: ProvingRequestType.ROOT_PARITY; + inputs: RootParityInputs; + } + | { + type: ProvingRequestType.BASE_ROLLUP; + inputs: BaseRollupInputs; + } + | { + type: ProvingRequestType.MERGE_ROLLUP; + inputs: MergeRollupInputs; + } + | { + type: ProvingRequestType.ROOT_ROLLUP; + inputs: RootRollupInputs; + }; + +export type ProvingRequestPublicInputs = { + [ProvingRequestType.PUBLIC_VM]: object; + + [ProvingRequestType.PUBLIC_KERNEL_NON_TAIL]: PublicKernelCircuitPublicInputs; + [ProvingRequestType.PUBLIC_KERNEL_TAIL]: KernelCircuitPublicInputs; + + [ProvingRequestType.BASE_ROLLUP]: BaseOrMergeRollupPublicInputs; + [ProvingRequestType.MERGE_ROLLUP]: BaseOrMergeRollupPublicInputs; + [ProvingRequestType.ROOT_ROLLUP]: RootRollupPublicInputs; + + [ProvingRequestType.BASE_PARITY]: ParityPublicInputs; + [ProvingRequestType.ROOT_PARITY]: ParityPublicInputs; +}; + +export type ProvingRequestResult = [ProvingRequestPublicInputs[T], Proof]; diff --git a/yarn-project/prover-client/src/prover/bb_prover_base_rollup.test.ts b/yarn-project/prover-client/src/prover/bb_prover_base_rollup.test.ts index bd010508498c..4b1e0337337c 100644 --- a/yarn-project/prover-client/src/prover/bb_prover_base_rollup.test.ts +++ b/yarn-project/prover-client/src/prover/bb_prover_base_rollup.test.ts @@ -1,14 +1,10 @@ import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { makeBloatedProcessedTx } from '../mocks/fixtures.js'; import { TestContext } from '../mocks/test_context.js'; import { buildBaseRollupInput } from '../orchestrator/block-building-helpers.js'; import { BBNativeRollupProver, type BBProverConfig } from './bb_prover.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:bb-prover-base-rollup'); describe('prover/bb_prover/base-rollup', () => { @@ -19,7 +15,7 @@ describe('prover/bb_prover/base-rollup', () => { bbConfig.circuitFilter = ['BaseRollupArtifact']; return BBNativeRollupProver.new(bbConfig); }; - context = await TestContext.new(logger, buildProver); + context = await TestContext.new(logger, 1, buildProver); }, 60_000); afterAll(async () => { diff --git a/yarn-project/prover-client/src/prover/bb_prover_full_rollup.test.ts b/yarn-project/prover-client/src/prover/bb_prover_full_rollup.test.ts index 3e77bdb45f8d..5d8fbc7b8ddc 100644 --- a/yarn-project/prover-client/src/prover/bb_prover_full_rollup.test.ts +++ b/yarn-project/prover-client/src/prover/bb_prover_full_rollup.test.ts @@ -3,20 +3,16 @@ import { Fr, Header } from '@aztec/circuits.js'; import { times } from '@aztec/foundation/collection'; import { createDebugLogger } from '@aztec/foundation/log'; -import { type MemDown, default as memdown } from 'memdown'; - import { TestContext } from '../mocks/test_context.js'; import { BBNativeRollupProver } from './bb_prover.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:bb-prover-full-rollup'); describe('prover/bb_prover/full-rollup', () => { let context: TestContext; beforeAll(async () => { - context = await TestContext.new(logger, BBNativeRollupProver.new); + context = await TestContext.new(logger, 1, BBNativeRollupProver.new); }, 60_000); afterAll(async () => { @@ -56,7 +52,5 @@ describe('prover/bb_prover/full-rollup', () => { const blockResult = await context.orchestrator.finaliseBlock(); await expect(context.prover.verifyProof('RootRollupArtifact', blockResult.proof)).resolves.not.toThrow(); - - await context.orchestrator.stop(); }, 600_000); }); diff --git a/yarn-project/prover-client/src/prover/bb_prover_parity.test.ts b/yarn-project/prover-client/src/prover/bb_prover_parity.test.ts index 58fafabdf3d2..8f112fd99882 100644 --- a/yarn-project/prover-client/src/prover/bb_prover_parity.test.ts +++ b/yarn-project/prover-client/src/prover/bb_prover_parity.test.ts @@ -10,13 +10,9 @@ import { makeTuple } from '@aztec/foundation/array'; import { createDebugLogger } from '@aztec/foundation/log'; import { type Tuple } from '@aztec/foundation/serialize'; -import { type MemDown, default as memdown } from 'memdown'; - import { TestContext } from '../mocks/test_context.js'; import { BBNativeRollupProver, type BBProverConfig } from './bb_prover.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:bb-prover-parity'); describe('prover/bb_prover/parity', () => { @@ -27,7 +23,7 @@ describe('prover/bb_prover/parity', () => { bbConfig.circuitFilter = ['BaseParityArtifact', 'RootParityArtifact']; return BBNativeRollupProver.new(bbConfig); }; - context = await TestContext.new(logger, buildProver); + context = await TestContext.new(logger, 1, buildProver); }, 60_000); afterAll(async () => { diff --git a/yarn-project/prover-client/src/prover/bb_prover_public_kernel.test.ts b/yarn-project/prover-client/src/prover/bb_prover_public_kernel.test.ts index d25675c41e94..17bb3e55a422 100644 --- a/yarn-project/prover-client/src/prover/bb_prover_public_kernel.test.ts +++ b/yarn-project/prover-client/src/prover/bb_prover_public_kernel.test.ts @@ -3,13 +3,9 @@ import { type Proof, makeEmptyProof } from '@aztec/circuits.js'; import { createDebugLogger } from '@aztec/foundation/log'; import { type ServerProtocolArtifact } from '@aztec/noir-protocol-circuits-types'; -import { type MemDown, default as memdown } from 'memdown'; - import { TestContext } from '../mocks/test_context.js'; import { BBNativeRollupProver, type BBProverConfig } from './bb_prover.js'; -export const createMemDown = () => (memdown as any)() as MemDown; - const logger = createDebugLogger('aztec:bb-prover-public-kernel'); describe('prover/bb_prover/public-kernel', () => { @@ -25,7 +21,7 @@ describe('prover/bb_prover/public-kernel', () => { ]; return BBNativeRollupProver.new(bbConfig); }; - context = await TestContext.new(logger, buildProver); + context = await TestContext.new(logger, 1, buildProver); }, 60_000); afterAll(async () => { diff --git a/yarn-project/prover-client/src/tx-prover/tx-prover.ts b/yarn-project/prover-client/src/tx-prover/tx-prover.ts index daa7259f65ca..14f36024a39c 100644 --- a/yarn-project/prover-client/src/tx-prover/tx-prover.ts +++ b/yarn-project/prover-client/src/tx-prover/tx-prover.ts @@ -7,6 +7,8 @@ import { type WorldStateSynchronizer } from '@aztec/world-state'; import { type ProverConfig } from '../config.js'; import { type VerificationKeys, getVerificationKeys } from '../mocks/verification_keys.js'; import { ProvingOrchestrator } from '../orchestrator/orchestrator.js'; +import { CircuitProverAgent } from '../prover-pool/circuit-prover-agent.js'; +import { ProverPool } from '../prover-pool/prover-pool.js'; import { TestCircuitProver } from '../prover/test_circuit_prover.js'; /** @@ -14,29 +16,35 @@ import { TestCircuitProver } from '../prover/test_circuit_prover.js'; */ export class TxProver implements ProverClient { private orchestrator: ProvingOrchestrator; + private proverPool: ProverPool; + constructor( private worldStateSynchronizer: WorldStateSynchronizer, simulationProvider: SimulationProvider, protected vks: VerificationKeys, + agentCount = 4, + agentPollIntervalMS = 10, ) { - this.orchestrator = new ProvingOrchestrator( - worldStateSynchronizer.getLatest(), - new TestCircuitProver(simulationProvider), + this.proverPool = new ProverPool( + agentCount, + i => new CircuitProverAgent(new TestCircuitProver(simulationProvider), agentPollIntervalMS, `${i}`), ); + + this.orchestrator = new ProvingOrchestrator(worldStateSynchronizer.getLatest(), this.proverPool.queue); } /** * Starts the prover instance */ - public start() { - return this.orchestrator.start(); + public async start() { + await this.proverPool.start(); } /** * Stops the prover instance */ public async stop() { - await this.orchestrator.stop(); + await this.proverPool.stop(); } /**