Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_builder_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"

namespace bb::eccvm {

static constexpr size_t NUM_SCALAR_BITS = 128;
static constexpr size_t WNAF_SLICE_BITS = 4;
static constexpr size_t NUM_WNAF_SLICES = (NUM_SCALAR_BITS + WNAF_SLICE_BITS - 1) / WNAF_SLICE_BITS;
static constexpr uint64_t WNAF_MASK = static_cast<uint64_t>((1ULL << WNAF_SLICE_BITS) - 1ULL);
static constexpr size_t POINT_TABLE_SIZE = 1ULL << (WNAF_SLICE_BITS);
static constexpr size_t WNAF_SLICES_PER_ROW = 4;
static constexpr size_t NUM_SCALAR_BITS = 128; // The length of scalars handled by the ECCVVM
static constexpr size_t NUM_WNAF_DIGIT_BITS = 4; // Scalars are decompose into base 16 in wNAF form
static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = NUM_SCALAR_BITS / NUM_WNAF_DIGIT_BITS; // 32
static constexpr uint64_t WNAF_MASK = static_cast<uint64_t>((1ULL << NUM_WNAF_DIGIT_BITS) - 1ULL);
static constexpr size_t POINT_TABLE_SIZE = 1ULL << (NUM_WNAF_DIGIT_BITS);
static constexpr size_t WNAF_DIGITS_PER_ROW = 4;
static constexpr size_t ADDITIONS_PER_ROW = 4;

template <typename CycleGroup> struct VMOperation {
Expand Down Expand Up @@ -39,7 +38,7 @@ template <typename CycleGroup> struct ScalarMul {
uint32_t pc;
uint256_t scalar;
typename CycleGroup::affine_element base_point;
std::array<int, NUM_WNAF_SLICES> wnaf_slices;
std::array<int, NUM_WNAF_DIGITS_PER_SCALAR> wnaf_digits;
bool wnaf_skew;
// size bumped by 1 to record base_point.dbl()
std::array<typename CycleGroup::affine_element, POINT_TABLE_SIZE + 1> precomputed_table;
Expand Down
51 changes: 25 additions & 26 deletions barretenberg/cpp/src/barretenberg/eccvm/eccvm_circuit_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ class ECCVMCircuitBuilder {
using AffineElement = typename CycleGroup::affine_element;

static constexpr size_t NUM_SCALAR_BITS = bb::eccvm::NUM_SCALAR_BITS;
static constexpr size_t WNAF_SLICE_BITS = bb::eccvm::WNAF_SLICE_BITS;
static constexpr size_t NUM_WNAF_SLICES = bb::eccvm::NUM_WNAF_SLICES;
static constexpr size_t NUM_WNAF_DIGIT_BITS = bb::eccvm::NUM_WNAF_DIGIT_BITS;
static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR;
static constexpr uint64_t WNAF_MASK = bb::eccvm::WNAF_MASK;
static constexpr size_t POINT_TABLE_SIZE = bb::eccvm::POINT_TABLE_SIZE;
static constexpr size_t WNAF_SLICES_PER_ROW = bb::eccvm::WNAF_SLICES_PER_ROW;
static constexpr size_t WNAF_DIGITS_PER_ROW = bb::eccvm::WNAF_DIGITS_PER_ROW;
static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW;

using MSM = bb::eccvm::MSM<CycleGroup>;
Expand All @@ -50,7 +50,8 @@ class ECCVMCircuitBuilder {
/**
* For input point [P], return { -15[P], -13[P], ..., -[P], [P], ..., 13[P], 15[P] }
*/
const auto compute_precomputed_table = [](const AffineElement& base_point) {
const auto compute_precomputed_table =
[](const AffineElement& base_point) -> std::array<AffineElement, POINT_TABLE_SIZE + 1> {
const auto d2 = Element(base_point).dbl();
std::array<Element, POINT_TABLE_SIZE + 1> table;
table[POINT_TABLE_SIZE] = d2; // need this for later
Expand All @@ -69,10 +70,10 @@ class ECCVMCircuitBuilder {
}
return result;
};
const auto compute_wnaf_slices = [](uint256_t scalar) {
std::array<int, NUM_WNAF_SLICES> output;
const auto compute_wnaf_digits = [](uint256_t scalar) -> std::array<int, NUM_WNAF_DIGITS_PER_SCALAR> {
std::array<int, NUM_WNAF_DIGITS_PER_SCALAR> output;
int previous_slice = 0;
for (size_t i = 0; i < NUM_WNAF_SLICES; ++i) {
for (size_t i = 0; i < NUM_WNAF_DIGITS_PER_SCALAR; ++i) {
// slice the scalar into 4-bit chunks, starting with the least significant bits
uint64_t raw_slice = static_cast<uint64_t>(scalar) & WNAF_MASK;

Expand All @@ -86,19 +87,19 @@ class ECCVMCircuitBuilder {
} else if (is_even) {
// for other slices, if it's even, we add 1 to the slice value
// and subtract 16 from the previous slice to preserve the total scalar sum
static constexpr int borrow_constant = static_cast<int>(1ULL << WNAF_SLICE_BITS);
static constexpr int borrow_constant = static_cast<int>(1ULL << NUM_WNAF_DIGIT_BITS);
previous_slice -= borrow_constant;
wnaf_slice += 1;
}

if (i > 0) {
const size_t idx = i - 1;
output[NUM_WNAF_SLICES - idx - 1] = previous_slice;
output[NUM_WNAF_DIGITS_PER_SCALAR - idx - 1] = previous_slice;
}
previous_slice = wnaf_slice;

// downshift raw_slice by 4 bits
scalar = scalar >> WNAF_SLICE_BITS;
scalar = scalar >> NUM_WNAF_DIGIT_BITS;
}

ASSERT(scalar == 0);
Expand All @@ -108,8 +109,6 @@ class ECCVMCircuitBuilder {
return output;
};

// a vector of MSMs = a vector of a vector of scalar muls
// each mul
size_t msm_count = 0;
size_t active_mul_count = 0;
std::vector<size_t> msm_opqueue_index;
Expand All @@ -118,6 +117,7 @@ class ECCVMCircuitBuilder {

const auto& raw_ops = op_queue->get_raw_ops();
size_t op_idx = 0;
// populate opqueue and mul indices
for (const auto& op : raw_ops) {
if (op.mul) {
if (op.z1 != 0 || op.z2 != 0) {
Expand All @@ -142,39 +142,38 @@ class ECCVMCircuitBuilder {
msm_sizes.push_back(active_mul_count);
msm_count++;
}
std::vector<MSM> msms_test(msm_count);
std::vector<MSM> result(msm_count);
for (size_t i = 0; i < msm_count; ++i) {
auto& msm = msms_test[i];
auto& msm = result[i];
msm.resize(msm_sizes[i]);
}

run_loop_in_parallel(msm_opqueue_index.size(), [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
const size_t opqueue_index = msm_opqueue_index[i];
const auto& op = raw_ops[opqueue_index];
const auto& op = raw_ops[msm_opqueue_index[i]];
auto [msm_index, mul_index] = msm_mul_index[i];
if (op.z1 != 0) {
ASSERT(msms_test.size() > msm_index);
ASSERT(msms_test[msm_index].size() > mul_index);
msms_test[msm_index][mul_index] = (ScalarMul{
ASSERT(result.size() > msm_index);
ASSERT(result[msm_index].size() > mul_index);
result[msm_index][mul_index] = (ScalarMul{
.pc = 0,
.scalar = op.z1,
.base_point = op.base_point,
.wnaf_slices = compute_wnaf_slices(op.z1),
.wnaf_digits = compute_wnaf_digits(op.z1),
.wnaf_skew = (op.z1 & 1) == 0,
.precomputed_table = compute_precomputed_table(op.base_point),
});
mul_index++;
}
if (op.z2 != 0) {
ASSERT(msms_test.size() > msm_index);
ASSERT(msms_test[msm_index].size() > mul_index);
ASSERT(result.size() > msm_index);
ASSERT(result[msm_index].size() > mul_index);
auto endo_point = AffineElement{ op.base_point.x * FF::cube_root_of_unity(), -op.base_point.y };
msms_test[msm_index][mul_index] = (ScalarMul{
result[msm_index][mul_index] = (ScalarMul{
.pc = 0,
.scalar = op.z2,
.base_point = endo_point,
.wnaf_slices = compute_wnaf_slices(op.z2),
.wnaf_digits = compute_wnaf_digits(op.z2),
.wnaf_skew = (op.z2 & 1) == 0,
.precomputed_table = compute_precomputed_table(endo_point),
});
Expand All @@ -191,15 +190,15 @@ class ECCVMCircuitBuilder {
// sumcheck relations that involve pc (if we did the other way around, starting at 1 and ending at num_muls,
// we create a discontinuity in pc values between the last transcript row and the following empty row)
uint32_t pc = num_muls;
for (auto& msm : msms_test) {
for (auto& msm : result) {
for (auto& mul : msm) {
mul.pc = pc;
pc--;
}
}

ASSERT(pc == 0);
return msms_test;
return result;
}

static std::vector<ScalarMul> get_flattened_scalar_muls(const std::vector<MSM>& msms)
Expand Down
Loading