Merged
@@ -11,6 +11,7 @@
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"

#include "./cycle_group.hpp"
#include "barretenberg/numeric/general/general.hpp"
#include "barretenberg/stdlib/primitives/plookup/plookup.hpp"
#include "barretenberg/stdlib_circuit_builders/plookup_tables/fixed_base/fixed_base.hpp"
#include "barretenberg/stdlib_circuit_builders/plookup_tables/types.hpp"
@@ -765,7 +766,9 @@ typename cycle_group<Builder>::batch_mul_internal_output cycle_group<Builder>::_
const std::span<AffineElement const> offset_generators,
const bool unconditional_add)
{
BB_ASSERT_EQ(scalars.size(), base_points.size());
BB_ASSERT_EQ(!scalars.empty(), true, "Empty scalars provided to variable base batch mul!");
BB_ASSERT_EQ(scalars.size(), base_points.size(), "Points/scalars size mismatch in variable base batch mul!");
const size_t num_points = scalars.size();

Builder* context = nullptr;
for (auto& scalar : scalars) {
@@ -781,21 +784,16 @@ typename cycle_group<Builder>::batch_mul_internal_output cycle_group<Builder>::_
}
}

size_t num_bits = 0;
size_t num_bits = scalars[0].num_bits();
for (auto& s : scalars) {
num_bits = std::max(num_bits, s.num_bits());
BB_ASSERT_EQ(s.num_bits(), num_bits, "Scalars of different bit-lengths not supported!");
}
size_t num_rounds = (num_bits + TABLE_BITS - 1) / TABLE_BITS;

const size_t num_points = scalars.size();
size_t num_rounds = numeric::ceil_div(num_bits, TABLE_BITS);
Comment (Contributor): oops, I've done ceil_div by hand in several tests


std::vector<straus_scalar_slice> scalar_slices;
scalar_slices.reserve(num_points);
for (size_t i = 0; i < num_points; ++i) {
scalar_slices.emplace_back(straus_scalar_slice(context, scalars[i], TABLE_BITS));
Comment (ledwards2225, author, Sep 9, 2025): The purpose of emplace_back is to pass the constructor arguments directly so that the object can be constructed in place rather than constructed and then moved. It's possible this would be optimized to be equivalent anyway, but at the very least it's poor code. I also fixed cases using emplace_back on already-constructed objects, which should use push_back. Maybe a bit pedantic, but I prefer to have correctness on this issue.

// AUDITTODO: temporary safety check. See test MixedLengthScalarsIsNotSupported
BB_ASSERT_EQ(
scalar_slices[i].slices_native.size() == num_rounds, true, "Scalars of different sizes not supported!");
for (const auto& scalar : scalars) {
scalar_slices.emplace_back(context, scalar, TABLE_BITS);
}

/**
@@ -805,31 +803,24 @@ typename cycle_group<Builder>::batch_mul_internal_output cycle_group<Builder>::_
* generation times
*/
std::vector<Element> operation_transcript;
std::vector<std::vector<Element>> native_straus_tables;
Element offset_generator_accumulator = offset_generators[0];
{
// Construct native straus lookup table for each point
std::vector<std::vector<Element>> native_straus_tables;
for (size_t i = 0; i < num_points; ++i) {
Comment (ledwards2225, author, Sep 9, 2025): It turns out that this code was computing the native Straus lookup tables twice. This may have been intentional for the following somewhat convoluted reason: to avoid edge cases in the stdlib version of the tables, we replace the point at infinity with a generator point (one), then perform the additions required to compute the Straus table values. To create "hints" for these in-circuit computations (i.e. the underlying native witness values), we would need to handle the point at infinity the same way in the native table construction. Fine. BUT we also need the genuine native tables (i.e. where infinity is treated as infinity) to compute the native intermediate accumulator values in the Straus algorithm. So, if we want to be able to provide hints to the in-circuit version, we would need to construct the native tables twice, in the two ways just described. This is needlessly complicated, so what I do instead is construct the native tables once (correctly treating infinity as infinity) and tell the in-circuit version that no hint exists if the point is the point at infinity. There is really no efficiency loss here, since performing an on-the-fly addition with the point at infinity is basically free.

Comment (Contributor): nice simplification, the original logic seems quite convoluted indeed

std::vector<Element> native_straus_table;
native_straus_table.emplace_back(offset_generators[i + 1]);
size_t table_size = 1ULL << TABLE_BITS;
for (size_t j = 1; j < table_size; ++j) {
native_straus_table.emplace_back(native_straus_table[j - 1] + base_points[i].get_value());
}
native_straus_tables.emplace_back(native_straus_table);
}
for (size_t i = 0; i < num_points; ++i) {
auto table_transcript = straus_lookup_table::compute_straus_lookup_table_hints(
std::vector<Element> table_transcript = straus_lookup_table::compute_straus_lookup_table_hints(
base_points[i].get_value(), offset_generators[i + 1], TABLE_BITS);
std::copy(table_transcript.begin() + 1, table_transcript.end(), std::back_inserter(operation_transcript));
native_straus_tables.emplace_back(std::move(table_transcript));
}
Element accumulator = offset_generators[0];

// Perform the Straus algorithm natively to generate the witness values (hints) for all intermediate points
Element accumulator = offset_generators[0];
for (size_t i = 0; i < num_rounds; ++i) {
if (i != 0) {
for (size_t j = 0; j < TABLE_BITS; ++j) {
// offset_generator_accumulator is a regular Element, so dbl() won't add constraints
accumulator = accumulator.dbl();
operation_transcript.emplace_back(accumulator);
operation_transcript.push_back(accumulator);
offset_generator_accumulator = offset_generator_accumulator.dbl();
}
}
@@ -839,29 +830,30 @@ typename cycle_group<Builder>::batch_mul_internal_output cycle_group<Builder>::_

accumulator += point;

operation_transcript.emplace_back(accumulator);
operation_transcript.push_back(accumulator);
offset_generator_accumulator = offset_generator_accumulator + Element(offset_generators[j + 1]);
}
}
}

// Normalize the computed witness points and convert into AffineElement type
// Normalize the computed witness points and convert them into AffineElements
Element::batch_normalize(&operation_transcript[0], operation_transcript.size());

std::vector<AffineElement> operation_hints;
operation_hints.reserve(operation_transcript.size());
for (auto& element : operation_transcript) {
operation_hints.emplace_back(AffineElement(element.x, element.y));
for (const Element& element : operation_transcript) {
operation_hints.emplace_back(element.x, element.y);
}

// Construct an in-circuit straus lookup table for each point
std::vector<straus_lookup_table> point_tables;
const size_t hints_per_table = (1ULL << TABLE_BITS) - 1;
OriginTag tag{};
for (size_t i = 0; i < num_points; ++i) {
std::span<AffineElement> table_hints(&operation_hints[i * hints_per_table], hints_per_table);
// Merge tags
tag = OriginTag(tag, scalars[i].get_origin_tag(), base_points[i].get_origin_tag());
point_tables.emplace_back(straus_lookup_table(context, base_points[i], offset_generators[i + 1], TABLE_BITS));
Comment (ledwards2225, author): Note a bug here: we were not even using the hints constructed a few lines above when we go to construct the straus_lookup_table! Corrected on the right.
std::span<AffineElement> table_hints(&operation_hints[i * hints_per_table], hints_per_table);
point_tables.emplace_back(context, base_points[i], offset_generators[i + 1], TABLE_BITS, table_hints);
}

AffineElement* hint_ptr = &operation_hints[num_points * hints_per_table];
@@ -874,39 +866,34 @@ typename cycle_group<Builder>::batch_mul_internal_output cycle_group<Builder>::_
std::vector<cycle_group> points_to_add;
for (size_t i = 0; i < num_rounds; ++i) {
for (size_t j = 0; j < num_points; ++j) {
const std::optional<field_t> scalar_slice = scalar_slices[j].read(num_rounds - i - 1);
Comment (ledwards2225, author): Removing optional logic which was only needed to support scalars of different sizes (which was not actually used/supported).
// if we are doing a batch mul over scalars of different bit-lengths, we may not have any scalar bits for a
// given round and a given scalar
if (scalar_slice.has_value()) {
const cycle_group point = point_tables[j].read(scalar_slice.value());
points_to_add.emplace_back(point);
}
const field_t scalar_slice = scalar_slices[j].read(num_rounds - i - 1);
const cycle_group point = point_tables[j].read(scalar_slice);
points_to_add.push_back(point);
}
}

// Perform the Straus algorithm in-circuit, using the previously computed native hints
std::vector<std::tuple<field_t, field_t>> x_coordinate_checks;
size_t point_counter = 0;
for (size_t i = 0; i < num_rounds; ++i) {
// perform once-per-round doublings (except for first round)
if (i != 0) {
for (size_t j = 0; j < TABLE_BITS; ++j) {
accumulator = accumulator.dbl(*hint_ptr);
hint_ptr++;
}
}

// perform each round's additions
for (size_t j = 0; j < num_points; ++j) {
const std::optional<field_t> scalar_slice = scalar_slices[j].read(num_rounds - i - 1);
Comment (ledwards2225, author): Again, just removing the no-longer-needed optional logic here.
// if we are doing a batch mul over scalars of different bit-lengths, we may not have a bit slice
// for a given round and a given scalar
BB_ASSERT_EQ(scalar_slice.value().get_value(), scalar_slices[j].slices_native[num_rounds - i - 1]);
if (scalar_slice.has_value()) {
const auto& point = points_to_add[point_counter++];
if (!unconditional_add) {
x_coordinate_checks.push_back({ accumulator.x, point.x });
}
accumulator = accumulator.unconditional_add(point, *hint_ptr);
hint_ptr++;
field_t scalar_slice = scalar_slices[j].read(num_rounds - i - 1);

BB_ASSERT_EQ(scalar_slice.get_value(), scalar_slices[j].slices_native[num_rounds - i - 1]);
const auto& point = points_to_add[point_counter++];
if (!unconditional_add) {
x_coordinate_checks.emplace_back(accumulator.x, point.x);
}
accumulator = accumulator.unconditional_add(point, *hint_ptr);
hint_ptr++;
}
}

@@ -915,7 +902,7 @@ typename cycle_group<Builder>::batch_mul_internal_output cycle_group<Builder>::_
// because `assert_is_not_zero` witness generation needs a modular inversion (expensive)
field_t coordinate_check_product = 1;
for (auto& [x1, x2] : x_coordinate_checks) {
auto x_diff = x2 - x1;
const field_t x_diff = x2 - x1;
coordinate_check_product *= x_diff;
}
coordinate_check_product.assert_is_not_zero("_variable_base_batch_mul_internal x-coordinate collision");