diff --git a/yarn-project/p2p/src/msg_validators/tx_validator/allowed_public_setup.ts b/yarn-project/p2p/src/msg_validators/tx_validator/allowed_public_setup.ts index 6f536c75f09d..6e78567c039c 100644 --- a/yarn-project/p2p/src/msg_validators/tx_validator/allowed_public_setup.ts +++ b/yarn-project/p2p/src/msg_validators/tx_validator/allowed_public_setup.ts @@ -1,9 +1,24 @@ import { TokenContractArtifact } from '@aztec/noir-contracts.js/Token'; import { ProtocolContractAddress } from '@aztec/protocol-contracts'; -import { FunctionSelector } from '@aztec/stdlib/abi'; +import { AuthRegistryArtifact } from '@aztec/protocol-contracts/auth-registry'; +import { FeeJuiceArtifact } from '@aztec/protocol-contracts/fee-juice'; +import { FunctionSelector, countArgumentsSize } from '@aztec/stdlib/abi'; +import type { ContractArtifact, FunctionAbi } from '@aztec/stdlib/abi'; import { getContractClassFromArtifact } from '@aztec/stdlib/contract'; import type { AllowedElement } from '@aztec/stdlib/interfaces/server'; +/** Returns the expected calldata length for a function: 1 (selector) + arguments size. */ +function getCalldataLength(artifact: ContractArtifact, functionName: string): number { + const allFunctions: FunctionAbi[] = (artifact.functions as FunctionAbi[]).concat( + artifact.nonDispatchPublicFunctions || [], + ); + const fn = allFunctions.find(f => f.name === functionName); + if (!fn) { + throw new Error(`Unknown function ${functionName} in artifact ${artifact.name}`); + } + return 1 + countArgumentsSize(fn); +} + let defaultAllowedSetupFunctions: AllowedElement[] | undefined; /** Returns the default list of functions allowed to run in the setup phase of a transaction. */ @@ -22,6 +37,7 @@ export async function getDefaultAllowedSetupFunctions(): Promise { }); }); + describe('calldataLength validation', () => { + const expectedLength = 4; // 1 selector + 3 args + let calldataContract: AztecAddress; + let calldataSelector: FunctionSelector; + let calldataClassId: Fr; + + beforeEach(() => { + calldataContract = makeAztecAddress(70); + calldataSelector = makeSelector(70); + calldataClassId = Fr.random(); + + txValidator = new PhasesTxValidator( + contractDataSource, + [ + { + address: calldataContract, + selector: calldataSelector, + calldataLength: expectedLength, + }, + { + classId: calldataClassId, + selector: calldataSelector, + calldataLength: expectedLength, + }, + ], + timestamp, + ); + }); + + it('allows address entry with correct calldata length', async () => { + const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 }); + await patchNonRevertibleFn(tx, 0, { + address: calldataContract, + selector: calldataSelector, + args: [Fr.random(), Fr.random(), Fr.random()], + }); + + await expectValid(tx); + }); + + it('rejects address entry with too short calldata', async () => { + const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 }); + await patchNonRevertibleFn(tx, 0, { + address: calldataContract, + selector: calldataSelector, + args: [Fr.random()], + }); + + await expectInvalid(tx, TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH); + }); + + it('rejects address entry with too long calldata', async () => { + const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 }); + await patchNonRevertibleFn(tx, 0, { + address: calldataContract, + selector: calldataSelector, + args: [Fr.random(), Fr.random(), Fr.random(), Fr.random(), Fr.random()], + }); + + await expectInvalid(tx, TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH); + }); + + it('rejects class entry with wrong calldata length', async () => { + const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 }); + const address = await patchNonRevertibleFn(tx, 0, { + selector: calldataSelector, + args: [Fr.random()], + }); + + contractDataSource.getContract.mockImplementationOnce((contractAddress, atTimestamp) => { + if (timestamp !== atTimestamp) { + throw new Error('Unexpected timestamp'); + } + if (address.equals(contractAddress)) { + return Promise.resolve({ + currentContractClassId: calldataClassId, + originalContractClassId: Fr.random(), + } as any); + } + return Promise.resolve(undefined); + }); + + await expectInvalid(tx, TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH); + }); + + it('allows any calldata length when calldataLength is not set', async () => { + txValidator = new PhasesTxValidator( + contractDataSource, + [ + { + address: calldataContract, + selector: calldataSelector, + }, + ], + timestamp, + ); + + const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 }); + await patchNonRevertibleFn(tx, 0, { + address: calldataContract, + selector: calldataSelector, + args: [Fr.random(), Fr.random(), Fr.random(), Fr.random(), Fr.random(), Fr.random()], + }); + + await expectValid(tx); + }); + }); + describe('rejectNullMsgSender validation', () => { const nullMsgSender = AztecAddress.fromBigInt(NULL_MSG_SENDER_CONTRACT_ADDRESS); let rejectNullContract: AztecAddress; diff --git a/yarn-project/p2p/src/msg_validators/tx_validator/phases_validator.ts b/yarn-project/p2p/src/msg_validators/tx_validator/phases_validator.ts index 5a3fcf018b43..9de8f7ef19e2 100644 --- a/yarn-project/p2p/src/msg_validators/tx_validator/phases_validator.ts +++ b/yarn-project/p2p/src/msg_validators/tx_validator/phases_validator.ts @@ -11,6 +11,7 @@ import { TX_ERROR_SETUP_FUNCTION_UNKNOWN_CONTRACT, TX_ERROR_SETUP_NULL_MSG_SENDER, TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER, + TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH, Tx, TxExecutionPhase, type TxValidationResult, @@ -88,6 +89,9 @@ export class PhasesTxValidator implements TxValidator { for (const entry of allowList) { if ('address' in entry) { if (contractAddress.equals(entry.address) && entry.selector.equals(functionSelector)) { + if (entry.calldataLength !== undefined && publicCall.calldata.length !== entry.calldataLength) { + return TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH; + } if (entry.onlySelf && !publicCall.request.msgSender.equals(contractAddress)) { return TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER; } @@ -118,6 +122,9 @@ export class PhasesTxValidator implements TxValidator { } if (contractClassId.value === entry.classId.toString() && entry.selector.equals(functionSelector)) { + if (entry.calldataLength !== undefined && publicCall.calldata.length !== entry.calldataLength) { + return TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH; + } if (entry.onlySelf && !publicCall.request.msgSender.equals(contractAddress)) { return TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER; } diff --git a/yarn-project/stdlib/src/interfaces/allowed_element.ts b/yarn-project/stdlib/src/interfaces/allowed_element.ts index 4c7d351f63ed..efc1a6b467bf 100644 --- a/yarn-project/stdlib/src/interfaces/allowed_element.ts +++ b/yarn-project/stdlib/src/interfaces/allowed_element.ts @@ -11,12 +11,14 @@ type AllowedInstanceFunction = { selector: FunctionSelector; onlySelf?: boolean; rejectNullMsgSender?: boolean; + calldataLength?: number; }; type AllowedClassFunction = { classId: Fr; selector: FunctionSelector; onlySelf?: boolean; rejectNullMsgSender?: boolean; + calldataLength?: number; }; export type AllowedElement = AllowedInstanceFunction | AllowedClassFunction; @@ -28,12 +30,14 @@ export const AllowedElementSchema = zodFor()( selector: schemas.FunctionSelector, onlySelf: z.boolean().optional(), rejectNullMsgSender: z.boolean().optional(), + calldataLength: z.number().optional(), }), z.object({ classId: schemas.Fr, selector: schemas.FunctionSelector, onlySelf: z.boolean().optional(), rejectNullMsgSender: z.boolean().optional(), + calldataLength: z.number().optional(), }), ]), ); diff --git a/yarn-project/stdlib/src/tx/validator/error_texts.ts b/yarn-project/stdlib/src/tx/validator/error_texts.ts index c9214299d46b..6a8326f032a8 100644 --- a/yarn-project/stdlib/src/tx/validator/error_texts.ts +++ b/yarn-project/stdlib/src/tx/validator/error_texts.ts @@ -9,6 +9,7 @@ export const TX_ERROR_SETUP_FUNCTION_NOT_ALLOWED = 'Setup function not on allow export const TX_ERROR_SETUP_FUNCTION_UNKNOWN_CONTRACT = 'Setup function targets unknown contract'; export const TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER = 'Setup only_self function called with incorrect msg_sender'; export const TX_ERROR_SETUP_NULL_MSG_SENDER = 'Setup function called with null msg sender'; +export const TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH = 'Setup function called with wrong calldata length'; // Nullifiers export const TX_ERROR_DUPLICATE_NULLIFIER_IN_TX = 'Duplicate nullifier in tx';