Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ TEST(crypto_merkle_tree, test_update_memberships)
std::vector<field_ct> new_roots_ct;

for (size_t i = 0; i < old_indices.size(); i++) {
auto idx_vec = field_ct(witness_ct(&builder, uint256_t(old_indices[i]))).decompose_into_bits(depth);
auto idx_vec = field_ct(witness_ct(&builder, uint256_t(old_indices[i]))).decompose_into_bits();
old_indices_ct.push_back(idx_vec);
old_values_ct.push_back(witness_ct(&builder, old_values[i]));
old_hash_paths_ct.push_back(create_witness_hash_path(builder, old_hash_paths[i]));
Expand Down
132 changes: 74 additions & 58 deletions barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,23 @@ field_t<Builder>::field_t(Builder* parent_context)
template <typename Builder>
field_t<Builder>::field_t(const witness_t<Builder>& value)
: context(value.context)
{
additive_constant = 0;
multiplicative_constant = 1;
witness_index = value.witness_index;
}
, additive_constant(0)
, multiplicative_constant(1)
, witness_index(value.witness_index)
{}

template <typename Builder>
field_t<Builder>::field_t(Builder* parent_context, const bb::fr& value)
: context(parent_context)
{
additive_constant = value;
multiplicative_constant = bb::fr::one();
witness_index = IS_CONSTANT;
}
, additive_constant(value)
, multiplicative_constant(bb::fr::one())
, witness_index(IS_CONSTANT)
{}

template <typename Builder> field_t<Builder>::field_t(const bool_t<Builder>& other)
template <typename Builder>
field_t<Builder>::field_t(const bool_t<Builder>& other)
: context(other.context)
{
context = (other.context == nullptr) ? nullptr : other.context;
if (other.witness_index == IS_CONSTANT) {
additive_constant = (other.witness_bool ^ other.witness_inverted) ? bb::fr::one() : bb::fr::zero();
multiplicative_constant = bb::fr::one();
Expand Down Expand Up @@ -417,10 +416,9 @@ template <typename Builder> field_t<Builder> field_t<Builder>::pow(const size_t
*/
template <typename Builder> field_t<Builder> field_t<Builder>::madd(const field_t& to_mul, const field_t& to_add) const
{
Builder* ctx = (context == nullptr) ? (to_mul.context == nullptr ? to_add.context : to_mul.context) : context;
Builder* ctx = first_non_null<Builder>(context, to_mul.context, to_add.context);

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here and below, use first_non_null to avoid nested ternary operators


if ((to_mul.witness_index == IS_CONSTANT) && (to_add.witness_index == IS_CONSTANT) &&
(witness_index == IS_CONSTANT)) {
if (to_mul.is_constant() && to_add.is_constant() && this->is_constant()) {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here and below, replace foo.witness_index == IS_CONSTANT with foo.is_constant() that performs the same check, but seems more readable/expressive

return ((*this) * to_mul + to_add);
}

Expand Down Expand Up @@ -473,32 +471,41 @@ template <typename Builder> field_t<Builder> field_t<Builder>::madd(const field_
return result;
}

/**
* @brief Returns (this + a + b)
*
* @details Use custom big_mul_gate to save gates.
*
* @tparam Builder
* @param add_a
* @param add_b
* @return field_t<Builder>
*/
template <typename Builder> field_t<Builder> field_t<Builder>::add_two(const field_t& add_a, const field_t& add_b) const
{
Builder* ctx = (context == nullptr) ? (add_a.context == nullptr ? add_b.context : add_a.context) : context;
Builder* ctx = first_non_null<Builder>(context, add_a.context, add_b.context);

if ((add_a.witness_index == IS_CONSTANT) && (add_b.witness_index == IS_CONSTANT) &&
(witness_index == IS_CONSTANT)) {
if ((add_a.is_constant()) && (add_b.is_constant()) && (this->is_constant())) {
return ((*this) + add_a + add_b).normalize();
}
bb::fr q_1 = multiplicative_constant;
bb::fr q_2 = add_a.multiplicative_constant;
bb::fr q_3 = add_b.multiplicative_constant;
bb::fr q_c = additive_constant + add_a.additive_constant + add_b.additive_constant;

bb::fr a = witness_index == IS_CONSTANT ? bb::fr(0) : ctx->get_variable(witness_index);
bb::fr b = add_a.witness_index == IS_CONSTANT ? bb::fr(0) : ctx->get_variable(add_a.witness_index);
bb::fr c = add_b.witness_index == IS_CONSTANT ? bb::fr(0) : ctx->get_variable(add_b.witness_index);
bb::fr a = this->is_constant() ? bb::fr(0) : ctx->get_variable(witness_index);
bb::fr b = add_a.is_constant() ? bb::fr(0) : ctx->get_variable(add_a.witness_index);
bb::fr c = add_b.is_constant() ? bb::fr(0) : ctx->get_variable(add_b.witness_index);

bb::fr out = a * q_1 + b * q_2 + c * q_3 + q_c;

field_t<Builder> result(ctx);
result.witness_index = ctx->add_variable(out);

ctx->create_big_mul_gate({
.a = witness_index == IS_CONSTANT ? ctx->zero_idx : witness_index,
.b = add_a.witness_index == IS_CONSTANT ? ctx->zero_idx : add_a.witness_index,
.c = add_b.witness_index == IS_CONSTANT ? ctx->zero_idx : add_b.witness_index,
.a = this->is_constant() ? ctx->zero_idx : witness_index,
.b = add_a.is_constant() ? ctx->zero_idx : add_a.witness_index,
.c = add_b.is_constant() ? ctx->zero_idx : add_b.witness_index,
.d = result.witness_index,
.mul_scaling = bb::fr(0),
.a_scaling = q_1,
Expand All @@ -512,12 +519,12 @@ template <typename Builder> field_t<Builder> field_t<Builder>::add_two(const fie
}

/**
* @brief Return an new element, where the in-circuit witness contains the actual represented value (multiplicative
* @brief Return a new element, where the in-circuit witness contains the actual represented value (multiplicative
* constant is 1 and additive_constant is 0)
*
* @details If the element is a constant or it is already normalized, just return the element itself
*
*@todo We need to add a mechanism into the circuit builders for caching normalized variants for fields and bigfields.
* @todo We need to add a mechanism into the circuit builders for caching normalized variants for fields and bigfields.
*It should make the circuits smaller. https://github.com/AztecProtocol/barretenberg/issues/1052
*
* @tparam Builder
Expand Down Expand Up @@ -703,11 +710,10 @@ template <typename Builder> bb::fr field_t<Builder>::get_value() const
if (witness_index != IS_CONSTANT) {
ASSERT(context != nullptr);
return (multiplicative_constant * context->get_variable(witness_index)) + additive_constant;
} else {
ASSERT(this->multiplicative_constant == bb::fr::one());
// A constant field_t's value is tracked wholly by its additive_constant member.
return additive_constant;
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed redundant else

ASSERT(this->multiplicative_constant == bb::fr::one());
// A constant field_t's value is tracked wholly by its additive_constant member.
return additive_constant;
}

template <typename Builder> bool_t<Builder> field_t<Builder>::operator==(const field_t& other) const
Expand Down Expand Up @@ -763,22 +769,34 @@ field_t<Builder> field_t<Builder>::conditional_negate(const bool_t<Builder>& pre
return multiplicand.madd(*this, *this);
}

// if predicate == true then return lhs, else return rhs
/**
* @brief If predicate == true then return lhs, else return rhs

* @details Conditional assign x = (predicate) ? lhs : rhs can be expressed arithmetically as follows
* x = predciate * lhs + (1 - predicate) * rhs
* which is equivalent to
* x = (lhs - rhs) * predicate + rhs = (lhs - rhs)*madd(predicate, rhs)
* where take advantage of `madd()` to create less gates.
*
* @return field_t<Builder>
*/
template <typename Builder>
field_t<Builder> field_t<Builder>::conditional_assign(const bool_t<Builder>& predicate,
const field_t& lhs,
const field_t& rhs)
{
// If the predicate is constant, the conditional assignment can be done out of circuit
if (predicate.is_constant()) {
auto result = field_t(predicate.get_value() ? lhs : rhs);
result.set_origin_tag(OriginTag(predicate.get_origin_tag(), lhs.get_origin_tag(), rhs.get_origin_tag()));
return result;
}
// if lhs and rhs are the same witness, just return it!
// If lhs and rhs are the same witness, just return it
if (lhs.get_witness_index() == rhs.get_witness_index() && (lhs.additive_constant == rhs.additive_constant) &&
(lhs.multiplicative_constant == rhs.multiplicative_constant)) {
return lhs;
}

return (lhs - rhs).madd(predicate, rhs);
}

Expand Down Expand Up @@ -925,15 +943,15 @@ field_t<Builder> field_t<Builder>::select_from_three_bit_table(const std::array<
return R6;
}

/**
* @brief Constrain a + b + c + d to be equal to 0
*/
template <typename Builder>
void field_t<Builder>::evaluate_linear_identity(const field_t& a, const field_t& b, const field_t& c, const field_t& d)
{
Builder* ctx = a.context == nullptr
? (b.context == nullptr ? (c.context == nullptr ? d.context : c.context) : b.context)
: a.context;
Builder* ctx = first_non_null(a.context, b.context, c.context, d.context);

if (a.witness_index == IS_CONSTANT && b.witness_index == IS_CONSTANT && c.witness_index == IS_CONSTANT &&
d.witness_index == IS_CONSTANT) {
if (a.is_constant() && b.is_constant() && c.is_constant() && d.is_constant()) {
ASSERT(a.get_value() + b.get_value() + c.get_value() + d.get_value() == 0);
return;
}
Expand All @@ -946,10 +964,10 @@ void field_t<Builder>::evaluate_linear_identity(const field_t& a, const field_t&
bb::fr q_c = a.additive_constant + b.additive_constant + c.additive_constant + d.additive_constant;

ctx->create_big_add_gate({
a.witness_index == IS_CONSTANT ? ctx->zero_idx : a.witness_index,
b.witness_index == IS_CONSTANT ? ctx->zero_idx : b.witness_index,
c.witness_index == IS_CONSTANT ? ctx->zero_idx : c.witness_index,
d.witness_index == IS_CONSTANT ? ctx->zero_idx : d.witness_index,
a.is_constant() ? ctx->zero_idx : a.witness_index,
b.is_constant() ? ctx->zero_idx : b.witness_index,
c.is_constant() ? ctx->zero_idx : c.witness_index,
d.is_constant() ? ctx->zero_idx : d.witness_index,
q_1,
q_2,
q_3,
Expand All @@ -964,12 +982,9 @@ void field_t<Builder>::evaluate_polynomial_identity(const field_t& a,
const field_t& c,
const field_t& d)
{
Builder* ctx = a.context == nullptr
? (b.context == nullptr ? (c.context == nullptr ? d.context : c.context) : b.context)
: a.context;
Builder* ctx = first_non_null(a.context, b.context, c.context, d.context);

if (a.witness_index == IS_CONSTANT && b.witness_index == IS_CONSTANT && c.witness_index == IS_CONSTANT &&
d.witness_index == IS_CONSTANT) {
if (a.is_constant() && b.is_constant() && c.is_constant() && d.is_constant()) {
ASSERT((a.get_value() * b.get_value() + c.get_value() + d.get_value()).is_zero());
return;
}
Expand All @@ -983,10 +998,10 @@ void field_t<Builder>::evaluate_polynomial_identity(const field_t& a,
bb::fr q_c = a.additive_constant * b.additive_constant + c.additive_constant + d.additive_constant;

ctx->create_big_mul_gate({
a.witness_index == IS_CONSTANT ? ctx->zero_idx : a.witness_index,
b.witness_index == IS_CONSTANT ? ctx->zero_idx : b.witness_index,
c.witness_index == IS_CONSTANT ? ctx->zero_idx : c.witness_index,
d.witness_index == IS_CONSTANT ? ctx->zero_idx : d.witness_index,
a.is_constant() ? ctx->zero_idx : a.witness_index,
b.is_constant() ? ctx->zero_idx : b.witness_index,
c.is_constant() ? ctx->zero_idx : c.witness_index,
d.is_constant() ? ctx->zero_idx : d.witness_index,
q_m,
q_1,
q_2,
Expand All @@ -1001,7 +1016,7 @@ void field_t<Builder>::evaluate_polynomial_identity(const field_t& a,
*/
template <typename Builder> field_t<Builder> field_t<Builder>::accumulate(const std::vector<field_t>& input)
{
if (input.size() == 0) {
if (input.empty()) {
return field_t<Builder>(nullptr, 0);
}
if (input.size() == 1) {
Expand Down Expand Up @@ -1043,7 +1058,7 @@ template <typename Builder> field_t<Builder> field_t<Builder>::accumulate(const
}
ctx = (element.get_context() ? element.get_context() : ctx);
}
if (accumulator.size() == 0) {
if (accumulator.empty()) {
return constant_term;
} else if (accumulator.size() != input.size()) {
accumulator[0] += constant_term;
Expand Down Expand Up @@ -1138,7 +1153,7 @@ std::array<field_t<Builder>, 3> field_t<Builder>::slice(const uint8_t msb, const
}

/**
* @brief Build a circuit allowing a user to prove that they have deomposed `this` into bits.
* @brief Build a circuit allowing a user to prove that they have decomposed `this` into bits.
*
* @details A bit vector `result` is extracted and used to to construct a sum `sum` using the normal binary expansion.
* Along the way, we extract a value `shifted_high_limb` which is equal to `sum_hi` in the natural decomposition
Expand All @@ -1162,16 +1177,16 @@ std::array<field_t<Builder>, 3> field_t<Builder>::slice(const uint8_t msb, const
*/
template <typename Builder>
std::vector<bool_t<Builder>> field_t<Builder>::decompose_into_bits(
const size_t num_bits, const std::function<witness_t<Builder>(Builder*, uint64_t, uint256_t)> get_bit) const
const std::function<witness_t<Builder>(Builder*, uint64_t, uint256_t)> get_bit) const

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the num_bits parameter wasn't handled correctly and the only non-test appearance used the default value 256, so I removed the argument.
moreover, this method seems like a good candidate to be deleted

{
ASSERT(num_bits <= 256);
static constexpr size_t num_bits = 256;
ASSERT(num_bits == 256);
static constexpr size_t midpoint = num_bits / 2 - 1;
std::vector<bool_t<Builder>> result(num_bits);

const uint256_t val_u256 = static_cast<uint256_t>(get_value());
field_t<Builder> sum(context, 0);
field_t<Builder> shifted_high_limb(context, 0); // will equal high 128 bits, left shifted by 128 bits
// TODO: Guido will make a PR that will fix an error here; hard-coded 127 is incorrect when 128 < num_bits < 256.
// Extract bit vector and show that it has the same value as `this`.
for (size_t i = 0; i < num_bits; ++i) {
bool_t<Builder> bit = get_bit(context, num_bits - 1 - i, val_u256);
bit.set_origin_tag(tag);
Expand All @@ -1180,8 +1195,9 @@ std::vector<bool_t<Builder>> field_t<Builder>::decompose_into_bits(
field_t<Builder> scaling_factor(context, scaling_factor_value);

sum = sum + (scaling_factor * bit);
if (i == 127)
if (i == midpoint) {
shifted_high_limb = sum;
}
}

this->assert_equal(sum); // `this` and `sum` are both normalized here.
Expand Down
38 changes: 21 additions & 17 deletions barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@

namespace bb::stdlib {

// Recursive helper to determine first non-null ptr to avoid sequential ternary choices.
template <typename T> T* first_non_null(T* ptr)
{
return ptr;
}

template <typename T, typename... Ts> T* first_non_null(T* first, Ts*... rest)
{
return first ? first : first_non_null(rest...);
}

template <typename Builder> class bool_t;
template <typename Builder> class field_t {
public:
Expand Down Expand Up @@ -184,13 +195,6 @@ template <typename Builder> class field_t {

field_t invert() const { return (field_t(1) / field_t(*this)).normalize(); }

static field_t coset_generator(const size_t generator_idx)
{
return field_t(bb::fr::coset_generator(generator_idx));

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two methods are not needed, since our verifiers don't use FFT

}

static field_t external_coset_generator() { return field_t(bb::fr::external_coset_generator()); }

field_t operator-() const
{
field_t result(*this);
Expand Down Expand Up @@ -290,7 +294,7 @@ template <typename Builder> class field_t {
uint32_t set_public() const { return context->set_public_input(normalize().witness_index); }

/**
* Create a witness form a constant. This way the value of the witness is fixed and public (public, because the
* Create a witness from a constant. This way the value of the witness is fixed and public (public, because the
* value becomes hard-coded as an element of the q_c selector vector).
*/
void convert_constant_to_fixed_witness(Builder* ctx)
Expand Down Expand Up @@ -328,26 +332,25 @@ template <typename Builder> class field_t {
* @brief Get the index of a normalized version of this element
*
* @details Most of the time when using field elements in other parts of stdlib we want to use this API instead of
* get_witness index. The reason is it will prevent some soundess vulnerabilities
* get_witness index. The reason is it will prevent some soundness vulnerabilities
*
* @return uint32_t
*/
uint32_t get_normalized_witness_index() const { return normalize().witness_index; }

std::vector<bool_t<Builder>> decompose_into_bits(
size_t num_bits = 256,
std::function<witness_t<Builder>(Builder* ctx, uint64_t, uint256_t)> get_bit =
[](Builder* ctx, uint64_t j, const uint256_t& val) {
return witness_t<Builder>(ctx, val.get_bit(j));
}) const;

/**
* @brief Return (a < b) as bool circuit type.
* This method *assumes* that both a and b are < 2^{input_bits} - 1
* This method *assumes* that both a and b are < 2^{num_bits} - 1
* i.e. it is not checked here, we assume this has been done previously
*
* @tparam Builder
* @tparam input_bits
* @tparam num_bits
* @param a
* @param b
* @return bool_t<Builder>
Expand All @@ -361,11 +364,12 @@ template <typename Builder> class field_t {
return uint256_t(a.get_value()) < uint256_t(b.get_value());
}

// a < b
// both a and b are < K where K = 2^{input_bits} - 1
// if a < b, this implies b - a - 1 < K
// if a >= b, this implies b - a + K - 1 < K
// i.e. (b - a - 1) * q + (b - a + K - 1) * (1 - q) = r < K
// Let q = (a < b)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slightly expanded the explanation here

// Assume both a and b are < K where K = 2^{num_bits} - 1
// if q == 1, then 0 < b - a - 1 < K
// if q == 0, then 0 < b - a + K - 1 < K
// i.e. for any bool value of q:
// (b - a - 1) * q + (b - a + K - 1) * (1 - q) = r < K
// q.(b - a - b + a) + b - a + K - 1 - (K - 1).q - q = r
// b - a + (K - 1) - (K).q = r
uint256_t range_constant = (uint256_t(1) << num_bits);
Expand Down
Loading