Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
import { TokenContractArtifact } from '@aztec/noir-contracts.js/Token';
import { ProtocolContractAddress } from '@aztec/protocol-contracts';
import { FunctionSelector } from '@aztec/stdlib/abi';
import { AuthRegistryArtifact } from '@aztec/protocol-contracts/auth-registry';
import { FeeJuiceArtifact } from '@aztec/protocol-contracts/fee-juice';
import { FunctionSelector, countArgumentsSize } from '@aztec/stdlib/abi';
import type { ContractArtifact, FunctionAbi } from '@aztec/stdlib/abi';
import { getContractClassFromArtifact } from '@aztec/stdlib/contract';
import type { AllowedElement } from '@aztec/stdlib/interfaces/server';

/** Returns the expected calldata length for a function: 1 (selector) + arguments size. */
function getCalldataLength(artifact: ContractArtifact, functionName: string): number {
const allFunctions: FunctionAbi[] = (artifact.functions as FunctionAbi[]).concat(
artifact.nonDispatchPublicFunctions || [],
);
const fn = allFunctions.find(f => f.name === functionName);
if (!fn) {
throw new Error(`Unknown function ${functionName} in artifact ${artifact.name}`);
}
return 1 + countArgumentsSize(fn);
}

let defaultAllowedSetupFunctions: AllowedElement[] | undefined;

/** Returns the default list of functions allowed to run in the setup phase of a transaction. */
Expand All @@ -22,31 +37,36 @@ export async function getDefaultAllowedSetupFunctions(): Promise<AllowedElement[
{
address: ProtocolContractAddress.AuthRegistry,
selector: setAuthorizedInternalSelector,
calldataLength: getCalldataLength(AuthRegistryArtifact, '_set_authorized'),
onlySelf: true,
rejectNullMsgSender: true,
},
// AuthRegistry: needed for authwit support via public path (PublicFeePaymentMethod calls set_authorized directly)
{
address: ProtocolContractAddress.AuthRegistry,
selector: setAuthorizedSelector,
calldataLength: getCalldataLength(AuthRegistryArtifact, 'set_authorized'),
rejectNullMsgSender: true,
},
// FeeJuice: needed for claiming on the same tx as a spend (claim_and_end_setup enqueues this)
{
address: ProtocolContractAddress.FeeJuice,
selector: increaseBalanceSelector,
calldataLength: getCalldataLength(FeeJuiceArtifact, '_increase_public_balance'),
onlySelf: true,
},
// Token: needed for private transfers via FPC (transfer_to_public enqueues this)
{
classId: tokenClassId,
selector: increaseBalanceSelector,
calldataLength: getCalldataLength(TokenContractArtifact, '_increase_public_balance'),
onlySelf: true,
},
// Token: needed for public transfers via FPC (fee_entrypoint_public enqueues this)
{
classId: tokenClassId,
selector: transferInPublicSelector,
calldataLength: getCalldataLength(TokenContractArtifact, 'transfer_in_public'),
},
];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
TX_ERROR_SETUP_FUNCTION_UNKNOWN_CONTRACT,
TX_ERROR_SETUP_NULL_MSG_SENDER,
TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER,
TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH,
type Tx,
} from '@aztec/stdlib/tx';

Expand Down Expand Up @@ -317,6 +318,114 @@ describe('PhasesTxValidator', () => {
});
});

describe('calldataLength validation', () => {
const expectedLength = 4; // 1 selector + 3 args
let calldataContract: AztecAddress;
let calldataSelector: FunctionSelector;
let calldataClassId: Fr;

beforeEach(() => {
calldataContract = makeAztecAddress(70);
calldataSelector = makeSelector(70);
calldataClassId = Fr.random();

txValidator = new PhasesTxValidator(
contractDataSource,
[
{
address: calldataContract,
selector: calldataSelector,
calldataLength: expectedLength,
},
{
classId: calldataClassId,
selector: calldataSelector,
calldataLength: expectedLength,
},
],
timestamp,
);
});

it('allows address entry with correct calldata length', async () => {
const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 });
await patchNonRevertibleFn(tx, 0, {
address: calldataContract,
selector: calldataSelector,
args: [Fr.random(), Fr.random(), Fr.random()],
});

await expectValid(tx);
});

it('rejects address entry with too short calldata', async () => {
const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 });
await patchNonRevertibleFn(tx, 0, {
address: calldataContract,
selector: calldataSelector,
args: [Fr.random()],
});

await expectInvalid(tx, TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH);
});

it('rejects address entry with too long calldata', async () => {
const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 });
await patchNonRevertibleFn(tx, 0, {
address: calldataContract,
selector: calldataSelector,
args: [Fr.random(), Fr.random(), Fr.random(), Fr.random(), Fr.random()],
});

await expectInvalid(tx, TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH);
});

it('rejects class entry with wrong calldata length', async () => {
const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 });
const address = await patchNonRevertibleFn(tx, 0, {
selector: calldataSelector,
args: [Fr.random()],
});

contractDataSource.getContract.mockImplementationOnce((contractAddress, atTimestamp) => {
if (timestamp !== atTimestamp) {
throw new Error('Unexpected timestamp');
}
if (address.equals(contractAddress)) {
return Promise.resolve({
currentContractClassId: calldataClassId,
originalContractClassId: Fr.random(),
} as any);
}
return Promise.resolve(undefined);
});

await expectInvalid(tx, TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH);
});

it('allows any calldata length when calldataLength is not set', async () => {
txValidator = new PhasesTxValidator(
contractDataSource,
[
{
address: calldataContract,
selector: calldataSelector,
},
],
timestamp,
);

const tx = await mockTx(1, { numberOfNonRevertiblePublicCallRequests: 1 });
await patchNonRevertibleFn(tx, 0, {
address: calldataContract,
selector: calldataSelector,
args: [Fr.random(), Fr.random(), Fr.random(), Fr.random(), Fr.random(), Fr.random()],
});

await expectValid(tx);
});
});

describe('rejectNullMsgSender validation', () => {
const nullMsgSender = AztecAddress.fromBigInt(NULL_MSG_SENDER_CONTRACT_ADDRESS);
let rejectNullContract: AztecAddress;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
TX_ERROR_SETUP_FUNCTION_UNKNOWN_CONTRACT,
TX_ERROR_SETUP_NULL_MSG_SENDER,
TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER,
TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH,
Tx,
TxExecutionPhase,
type TxValidationResult,
Expand Down Expand Up @@ -88,6 +89,9 @@ export class PhasesTxValidator implements TxValidator<Tx> {
for (const entry of allowList) {
if ('address' in entry) {
if (contractAddress.equals(entry.address) && entry.selector.equals(functionSelector)) {
if (entry.calldataLength !== undefined && publicCall.calldata.length !== entry.calldataLength) {
return TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH;
}
if (entry.onlySelf && !publicCall.request.msgSender.equals(contractAddress)) {
return TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER;
}
Expand Down Expand Up @@ -118,6 +122,9 @@ export class PhasesTxValidator implements TxValidator<Tx> {
}

if (contractClassId.value === entry.classId.toString() && entry.selector.equals(functionSelector)) {
if (entry.calldataLength !== undefined && publicCall.calldata.length !== entry.calldataLength) {
return TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH;
}
if (entry.onlySelf && !publicCall.request.msgSender.equals(contractAddress)) {
return TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER;
}
Expand Down
4 changes: 4 additions & 0 deletions yarn-project/stdlib/src/interfaces/allowed_element.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ type AllowedInstanceFunction = {
selector: FunctionSelector;
onlySelf?: boolean;
rejectNullMsgSender?: boolean;
calldataLength?: number;
};
type AllowedClassFunction = {
classId: Fr;
selector: FunctionSelector;
onlySelf?: boolean;
rejectNullMsgSender?: boolean;
calldataLength?: number;
};

export type AllowedElement = AllowedInstanceFunction | AllowedClassFunction;
Expand All @@ -28,12 +30,14 @@ export const AllowedElementSchema = zodFor<AllowedElement>()(
selector: schemas.FunctionSelector,
onlySelf: z.boolean().optional(),
rejectNullMsgSender: z.boolean().optional(),
calldataLength: z.number().optional(),
}),
z.object({
classId: schemas.Fr,
selector: schemas.FunctionSelector,
onlySelf: z.boolean().optional(),
rejectNullMsgSender: z.boolean().optional(),
calldataLength: z.number().optional(),
}),
]),
);
1 change: 1 addition & 0 deletions yarn-project/stdlib/src/tx/validator/error_texts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export const TX_ERROR_SETUP_FUNCTION_NOT_ALLOWED = 'Setup function not on allow
export const TX_ERROR_SETUP_FUNCTION_UNKNOWN_CONTRACT = 'Setup function targets unknown contract';
export const TX_ERROR_SETUP_ONLY_SELF_WRONG_SENDER = 'Setup only_self function called with incorrect msg_sender';
export const TX_ERROR_SETUP_NULL_MSG_SENDER = 'Setup function called with null msg sender';
export const TX_ERROR_SETUP_WRONG_CALLDATA_LENGTH = 'Setup function called with wrong calldata length';

// Nullifiers
export const TX_ERROR_DUPLICATE_NULLIFIER_IN_TX = 'Duplicate nullifier in tx';
Expand Down
Loading