-
Notifications
You must be signed in to change notification settings - Fork 598
feat(avm-simulator): msm blackbox #7048
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
130 changes: 130 additions & 0 deletions
130
yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| import { Fq, Fr } from '@aztec/circuits.js'; | ||
| import { Grumpkin } from '@aztec/circuits.js/barretenberg'; | ||
|
|
||
| import { type AvmContext } from '../avm_context.js'; | ||
| import { Field, type MemoryValue, Uint8, Uint32 } from '../avm_memory_types.js'; | ||
| import { initContext } from '../fixtures/index.js'; | ||
| import { MultiScalarMul } from './multi_scalar_mul.js'; | ||
|
|
||
| describe('MultiScalarMul Opcode', () => { | ||
| let context: AvmContext; | ||
|
|
||
| beforeEach(async () => { | ||
| context = initContext(); | ||
| }); | ||
| it('Should (de)serialize correctly', () => { | ||
| const buf = Buffer.from([ | ||
| MultiScalarMul.opcode, // opcode | ||
| 7, // indirect | ||
| ...Buffer.from('12345678', 'hex'), // pointsOffset | ||
| ...Buffer.from('23456789', 'hex'), // scalars Offset | ||
| ...Buffer.from('3456789a', 'hex'), // outputOffset | ||
| ...Buffer.from('456789ab', 'hex'), // pointsLengthOffset | ||
| ]); | ||
| const inst = new MultiScalarMul( | ||
| /*indirect=*/ 7, | ||
| /*pointsOffset=*/ 0x12345678, | ||
| /*scalarsOffset=*/ 0x23456789, | ||
| /*outputOffset=*/ 0x3456789a, | ||
| /*pointsLengthOffset=*/ 0x456789ab, | ||
| ); | ||
|
|
||
| expect(MultiScalarMul.deserialize(buf)).toEqual(inst); | ||
| expect(inst.serialize()).toEqual(buf); | ||
| }); | ||
|
|
||
| it('Should perform msm correctly - direct', async () => { | ||
| const indirect = 0; | ||
| const grumpkin = new Grumpkin(); | ||
| // We need to ensure points are actually on curve, so we just use the generator | ||
| // In future we could use a random point, for now we create an array of [G, 2G, 3G] | ||
| const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); | ||
|
|
||
| // Pick some big scalars to test the edge cases | ||
| const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; | ||
| const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory | ||
| const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory | ||
| // Transform the points and scalars into the format that we will write to memory | ||
| // We just store the x and y coordinates here, and handle the infinities when we write to memory | ||
| const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); | ||
| // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] | ||
| const storedPoints: MemoryValue[] = points | ||
| .map(p => p.toFieldsWithInf()) | ||
| .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); | ||
| const pointsOffset = 0; | ||
| context.machineState.memory.setSlice(pointsOffset, storedPoints); | ||
| // Store scalars | ||
| const scalarsOffset = pointsOffset + pointsReadLength; | ||
| context.machineState.memory.setSlice(scalarsOffset, storedScalars); | ||
| // Store length of points to read | ||
| const pointsLengthOffset = scalarsOffset + scalarsLength; | ||
| context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); | ||
| const outputOffset = pointsLengthOffset + 1; | ||
|
|
||
| await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); | ||
|
|
||
| const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); | ||
|
|
||
| // We write it out explicitly here | ||
| let expectedResult = grumpkin.mul(points[0], scalars[0]); | ||
| expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); | ||
| expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); | ||
|
|
||
| expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); | ||
| }); | ||
|
|
||
| it('Should perform msm correctly - indirect', async () => { | ||
| const indirect = 7; | ||
| const grumpkin = new Grumpkin(); | ||
| // We need to ensure points are actually on curve, so we just use the generator | ||
| // In future we could use a random point, for now we create an array of [G, 2G, 3G] | ||
| const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); | ||
|
|
||
| // Pick some big scalars to test the edge cases | ||
| const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; | ||
| const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory | ||
| const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory | ||
| // Transform the points and scalars into the format that we will write to memory | ||
| // We just store the x and y coordinates here, and handle the infinities when we write to memory | ||
| const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); | ||
| // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] | ||
| const storedPoints: MemoryValue[] = points | ||
| .map(p => p.toFieldsWithInf()) | ||
| .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); | ||
| const pointsOffset = 0; | ||
| context.machineState.memory.setSlice(pointsOffset, storedPoints); | ||
| // Store scalars | ||
| const scalarsOffset = pointsOffset + pointsReadLength; | ||
| context.machineState.memory.setSlice(scalarsOffset, storedScalars); | ||
| // Store length of points to read | ||
| const pointsLengthOffset = scalarsOffset + scalarsLength; | ||
| context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); | ||
| const outputOffset = pointsLengthOffset + 1; | ||
|
|
||
| // Set up the indirect pointers | ||
| const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */ | ||
| const scalarsIndirectOffset = pointsIndirectOffset + 1; | ||
| const outputIndirectOffset = scalarsIndirectOffset + 1; | ||
|
|
||
| context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset)); | ||
| context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset)); | ||
| context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset)); | ||
|
|
||
| await new MultiScalarMul( | ||
| indirect, | ||
| pointsIndirectOffset, | ||
| scalarsIndirectOffset, | ||
| outputIndirectOffset, | ||
| pointsLengthOffset, | ||
| ).execute(context); | ||
|
|
||
| const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); | ||
|
|
||
| // We write it out explicitly here | ||
| let expectedResult = grumpkin.mul(points[0], scalars[0]); | ||
| expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); | ||
| expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); | ||
|
|
||
| expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); | ||
| }); | ||
| }); |
114 changes: 114 additions & 0 deletions
114
yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| import { Fq, Point } from '@aztec/circuits.js'; | ||
| import { Grumpkin } from '@aztec/circuits.js/barretenberg'; | ||
|
|
||
| import { strict as assert } from 'assert'; | ||
|
|
||
| import { type AvmContext } from '../avm_context.js'; | ||
| import { Field, TypeTag } from '../avm_memory_types.js'; | ||
| import { InstructionExecutionError } from '../errors.js'; | ||
| import { Opcode, OperandType } from '../serialization/instruction_serialization.js'; | ||
| import { Addressing } from './addressing_mode.js'; | ||
| import { Instruction } from './instruction.js'; | ||
|
|
||
| export class MultiScalarMul extends Instruction { | ||
| static type: string = 'MultiScalarMul'; | ||
| static readonly opcode: Opcode = Opcode.MSM; | ||
|
|
||
| // Informs (de)serialization. See Instruction.deserialize. | ||
| static readonly wireFormat: OperandType[] = [ | ||
| OperandType.UINT8 /* opcode */, | ||
| OperandType.UINT8 /* indirect */, | ||
| OperandType.UINT32 /* points vector offset */, | ||
| OperandType.UINT32 /* scalars vector offset */, | ||
| OperandType.UINT32 /* output offset (fixed triplet) */, | ||
| OperandType.UINT32 /* points length offset */, | ||
| ]; | ||
|
|
||
| constructor( | ||
| private indirect: number, | ||
| private pointsOffset: number, | ||
| private scalarsOffset: number, | ||
| private outputOffset: number, | ||
| private pointsLengthOffset: number, | ||
| ) { | ||
| super(); | ||
| } | ||
|
|
||
| public async execute(context: AvmContext): Promise<void> { | ||
| const memory = context.machineState.memory.track(this.type); | ||
| // Resolve indirects | ||
| const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve( | ||
| [this.pointsOffset, this.scalarsOffset, this.outputOffset], | ||
| memory, | ||
| ); | ||
|
|
||
| // Length of the points vector should be U32 | ||
| memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); | ||
| // Get the size of the unrolled (x, y , inf) points vector | ||
IlyasRidhuan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); | ||
| assert(pointsReadLength % 3 === 0, 'Points vector offset should be a multiple of 3'); | ||
| // Divide by 3 since each point is represented as a triplet to get the number of points | ||
| const numPoints = pointsReadLength / 3; | ||
| // The tag for each triplet will be (Field, Field, Uint8) | ||
| for (let i = 0; i < numPoints; i++) { | ||
| const offset = pointsOffset + i * 3; | ||
| // Check (Field, Field) | ||
| memory.checkTagsRange(TypeTag.FIELD, offset, 2); | ||
| // Check Uint8 (inf flag) | ||
| memory.checkTag(TypeTag.UINT8, offset + 2); | ||
| } | ||
| // Get the unrolled (x, y, inf) representing the points | ||
| const pointsVector = memory.getSlice(pointsOffset, pointsReadLength); | ||
|
|
||
| // The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition | ||
| const scalarReadLength = numPoints * 2; | ||
| // Consume gas prior to performing work | ||
| const memoryOperations = { | ||
| reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, | ||
| writes: 3 /* output triplet */, | ||
| indirect: this.indirect, | ||
| }; | ||
| context.machineState.consumeGas(this.gasCost(memoryOperations)); | ||
| // Get the unrolled scalar (lo & hi) representing the scalars | ||
| const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); | ||
| memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); | ||
|
|
||
| // Now we need to reconstruct the points and scalars into something we can operate on. | ||
| const grumpkinPoints: Point[] = []; | ||
| for (let i = 0; i < numPoints; i++) { | ||
| const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr()); | ||
| // Include this later when we have a standard for representing infinity | ||
| // const isInf = pointsVector[i + 2].toBoolean(); | ||
|
|
||
| if (!p.isOnGrumpkin()) { | ||
| throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`); | ||
| } | ||
| grumpkinPoints.push(p); | ||
| } | ||
| // The scalars are read from memory as Fr elements, which are limbs of Fq elements | ||
| // So we need to reconstruct them before performing the scalar multiplications | ||
| const scalarFqVector: Fq[] = []; | ||
| for (let i = 0; i < numPoints; i++) { | ||
| const scalarLo = scalarsVector[2 * i].toFr(); | ||
| const scalarHi = scalarsVector[2 * i + 1].toFr(); | ||
| const fqScalar = Fq.fromHighLow(scalarHi, scalarLo); | ||
| scalarFqVector.push(fqScalar); | ||
| } | ||
| // TODO: Is there an efficient MSM implementation in ts that we can replace this by? | ||
| const grumpkin = new Grumpkin(); | ||
IlyasRidhuan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // Zip the points and scalars into pairs | ||
| const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]); | ||
| // Fold the points and scalars into a single point | ||
| // We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts | ||
| const outputPoint = rest.reduce( | ||
| (acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), | ||
| grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), | ||
| ); | ||
| const output = outputPoint.toFieldsWithInf().map(f => new Field(f)); | ||
|
|
||
| memory.setSlice(outputOffset, output); | ||
|
|
||
| memory.assert(memoryOperations); | ||
| context.machineState.incrementPc(); | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.