Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ enum class AvmError : uint32_t {
OUT_OF_GAS,
STATIC_CALL_ALTERATION,
FAILED_BYTECODE_RETRIEVAL,
MSM_POINTS_LEN_INVALID,
MSM_POINT_NOT_ON_CURVE,
};

} // namespace bb::avm_trace
8 changes: 8 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ std::string to_name(AvmError error)
return "TAG CHECKING ERROR";
case AvmError::ADDR_RES_TAG_ERROR:
return "ADDRESS RESOLUTION TAG ERROR";
case AvmError::MEM_SLICE_OUT_OF_RANGE:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case was missing as well as STATIC_CALL_ALTERATION

return "MEMORY SLICE OUT OF RANGE";
case AvmError::REL_ADDR_OUT_OF_RANGE:
return "RELATIVE ADDRESS IS OUT OF RANGE";
case AvmError::DIV_ZERO:
Expand All @@ -135,8 +137,14 @@ std::string to_name(AvmError error)
return "SIDE EFFECT LIMIT REACHED";
case AvmError::OUT_OF_GAS:
return "OUT OF GAS";
case AvmError::STATIC_CALL_ALTERATION:
return "STATIC CALL ALTERATION";
case AvmError::FAILED_BYTECODE_RETRIEVAL:
return "FAILED BYTECODE RETRIEVAL";
case AvmError::MSM_POINTS_LEN_INVALID:
return "MSM POINTS LEN INVALID";
case AvmError::MSM_POINT_NOT_ON_CURVE:
return "MSM POINT NOT ON CURVE";
default:
throw std::runtime_error("Invalid error type");
break;
Expand Down
9 changes: 9 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4780,6 +4780,11 @@ AvmError AvmTraceBuilder::op_variable_msm(uint8_t indirect,

const FF points_length = is_ok(error) ? unconstrained_read_from_memory(resolved_point_length_offset) : 0;

// Unconstrained check that points_length must be a multiple of 3.
if (is_ok(error) && static_cast<uint32_t>(points_length) % 3 != 0) {
error = AvmError::MSM_POINTS_LEN_INVALID;
}

if (is_ok(error) && !check_slice_mem_range(resolved_points_offset, static_cast<uint32_t>(points_length))) {
error = AvmError::MEM_SLICE_OUT_OF_RANGE;
}
Expand Down Expand Up @@ -4863,6 +4868,10 @@ AvmError AvmTraceBuilder::op_variable_msm(uint8_t indirect,
points.emplace_back(grumpkin::g1::affine_element::infinity());
} else {
points.emplace_back(x, y);
// Unconstrained check that this point lies on the Grumpkin curve.
if (!points.back().on_curve()) {
return AvmError::MSM_POINT_NOT_ON_CURVE;
}
}
}
// Reconstruct Grumpkin scalars
Expand Down
22 changes: 21 additions & 1 deletion yarn-project/simulator/src/avm/errors.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { type FailingFunction, type NoirCallStack } from '@aztec/circuit-types';
import { type AztecAddress, type Fr } from '@aztec/circuits.js';
import { type AztecAddress, type Fr, type Point } from '@aztec/circuits.js';

import { ExecutionError } from '../common/errors.js';
import { type AvmContext } from './avm_context.js';
Expand Down Expand Up @@ -128,6 +128,26 @@ export class OutOfGasError extends AvmExecutionError {
}
}

/**
* Error is thrown when the supplied points length is not a multiple of 3. Specific for MSM opcode.
*/
export class MSMPointsLengthError extends AvmExecutionError {
constructor(pointsReadLength: number) {
super(`Points vector length should be a multiple of 3, was ${pointsReadLength}`);
this.name = 'MSMPointsLengthError';
}
}

/**
* Error is thrown when one of the supplied points does not lie on the Grumpkin curve. Specific for MSM opcode.
*/
export class MSMPointNotOnCurveError extends AvmExecutionError {
constructor(point: Point) {
super(`Point ${point.toString()} is not on the curve.`);
this.name = 'MSMPointNotOnCurveError';
}
}

/**
* Error is thrown when a static call attempts to alter some state
*/
Expand Down
53 changes: 52 additions & 1 deletion yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { Fq, Fr } from '@aztec/circuits.js';
import { Fq, Fr, Point } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, type MemoryValue, Uint1, Uint32 } from '../avm_memory_types.js';
import { MSMPointNotOnCurveError, MSMPointsLengthError } from '../errors.js';
import { initContext } from '../fixtures/index.js';
import { MultiScalarMul } from './multi_scalar_mul.js';

Expand Down Expand Up @@ -127,4 +128,54 @@ describe('MultiScalarMul Opcode', () => {

expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});

it('Should throw an error if points length is not a multiple of 3', async () => {
const indirect = 0;

// No need to set up points nor scalars as it is expected to fail before any processing of them.
const pointsReadLength = 17; // Not multiple of 3
const pointsOffset = 0;
const scalarsOffset = 20;
const pointsLengthOffset = 100;
const outputOffset = 120;

context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));

await expect(
new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context),
).rejects.toThrow(MSMPointsLengthError);
});

it('Should throw an error if a point is not on Grumpkin curve', async () => {
const indirect = 0;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, NOT_ON_CURVE]
const points = Array.from({ length: 2 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));
points.push(new Point(new Fr(13), new Fr(14), false));

const scalars = [new Fq(5n), new Fq(3n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.lo), new Field(s.hi)]);
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
const storedPoints: MemoryValue[] = points
.map(p => p.toFields())
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint1(inf.toNumber())]);
const pointsOffset = 0;
context.machineState.memory.setSlice(pointsOffset, storedPoints);
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

await expect(
new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context),
).rejects.toThrow(MSMPointNotOnCurveError);
});
});
6 changes: 3 additions & 3 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, TypeTag, Uint1 } from '../avm_memory_types.js';
import { InstructionExecutionError } from '../errors.js';
import { MSMPointNotOnCurveError, MSMPointsLengthError } from '../errors.js';
import { Opcode, OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { Instruction } from './instruction.js';
Expand Down Expand Up @@ -44,7 +44,7 @@ export class MultiScalarMul extends Instruction {
// Get the size of the unrolled (x, y , inf) points vector
const pointsReadLength = memory.get(pointsLengthOffset).toNumber();
if (pointsReadLength % 3 !== 0) {
throw new InstructionExecutionError(`Points vector offset should be a multiple of 3, was ${pointsReadLength}`);
throw new MSMPointsLengthError(pointsReadLength);
}

// Get the unrolled (x, y, inf) representing the points
Expand Down Expand Up @@ -76,7 +76,7 @@ export class MultiScalarMul extends Instruction {
const isInf = pointsVector[3 * i + 2].toNumber() === 1;
const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr(), isInf);
if (!p.isOnGrumpkin()) {
throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`);
throw new MSMPointNotOnCurveError(p);
}
grumpkinPoints.push(p);
}
Expand Down