Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
namespace bb::stdlib {

template <typename Builder>
void validate_split_in_field(const field_t<Builder>& lo,
const field_t<Builder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus)
void validate_split_in_field_unsafe(const field_t<Builder>& lo,
const field_t<Builder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus)
{
const size_t hi_bits = static_cast<size_t>(field_modulus.get_msb()) + 1 - lo_bits;

Expand Down Expand Up @@ -79,7 +79,9 @@ std::pair<field_t<Builder>, field_t<Builder>> split_unique(const field_t<Builder
hi.set_origin_tag(field.get_origin_tag());

// Component 2: Field validation against bn254 scalar field modulus
validate_split_in_field(lo, hi, lo_bits, native::modulus);
// Note: We use _unsafe variant because Component 3 applies range constraints (unless explicitly skipped). When
// range constraints are skipped, caller must ensure they are applied elsewhere.
validate_split_in_field_unsafe(lo, hi, lo_bits, native::modulus);

// Component 3: Range constraints (unless skipped)
if (!skip_range_constraints) {
Expand All @@ -106,15 +108,15 @@ template std::pair<field_t<bb::UltraCircuitBuilder>, field_t<bb::UltraCircuitBui
template std::pair<field_t<bb::MegaCircuitBuilder>, field_t<bb::MegaCircuitBuilder>> split_unique(
const field_t<bb::MegaCircuitBuilder>& field, const size_t lo_bits, const bool skip_range_constraints);

// Explicit instantiations for validate_split_in_field
template void validate_split_in_field(const field_t<bb::UltraCircuitBuilder>& lo,
const field_t<bb::UltraCircuitBuilder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus);
template void validate_split_in_field(const field_t<bb::MegaCircuitBuilder>& lo,
const field_t<bb::MegaCircuitBuilder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus);
// Explicit instantiations for validate_split_in_field_unsafe
template void validate_split_in_field_unsafe(const field_t<bb::UltraCircuitBuilder>& lo,
const field_t<bb::UltraCircuitBuilder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus);
template void validate_split_in_field_unsafe(const field_t<bb::MegaCircuitBuilder>& lo,
const field_t<bb::MegaCircuitBuilder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus);

// Explicit instantiations for mark_witness_as_used
template void mark_witness_as_used(const field_t<bb::UltraCircuitBuilder>& field);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,25 @@ std::pair<field_t<Builder>, field_t<Builder>> split_unique(const field_t<Builder
const bool skip_range_constraints = false);

/**
* @brief Validates that lo + hi * 2^lo_bits < field_modulus
* @details Can be used in conjunction with range constraints on lo and hi to establish a unique decomposition of a
* field element.
* @brief Validates that lo + hi * 2^lo_bits < field_modulus (assuming range constraints on lo and hi)
* @details Uses a borrow-subtraction algorithm to check the inequality. Can be used in conjunction with range
* constraints on lo and hi to establish a unique decomposition of a field element.
*
* @warning: This function only checks the borrow arithmetic; it does NOT apply the following range constraints which
* are necessary to establish the above inequality in the integer sense:
* - lo < 2^lo_bits
* - hi < 2^hi_bits (where hi_bits = field_modulus.get_msb() + 1 - lo_bits)
*
* @param lo The low limb
* @param hi The high limb
* @param lo_bits The bit position at which the split occurred
* @param field_modulus The field modulus to validate against
*/
template <typename Builder>
void validate_split_in_field(const field_t<Builder>& lo,
const field_t<Builder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus);
void validate_split_in_field_unsafe(const field_t<Builder>& lo,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, love the _unsafe usage here.

const field_t<Builder>& hi,
const size_t lo_bits,
const uint256_t& field_modulus);

/**
* @brief Mark a field_t witness as used (for UltraBuilder only).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,8 @@ typename cycle_group<Builder>::batch_mul_internal_output cycle_group<Builder>::_
std::array<MultiTableId, 2> table_id = table::get_lookup_table_ids_for_point(point);
multitable_ids.push_back(table_id[0]);
multitable_ids.push_back(table_id[1]);
scalar_limbs.push_back(scalar.lo);
scalar_limbs.push_back(scalar.hi);
scalar_limbs.push_back(scalar.lo());
scalar_limbs.push_back(scalar.hi());
}

// Look up the multiples of each slice of each lo/hi scalar limb in the corresponding plookup table.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,42 @@

namespace bb::stdlib {

/**
* @brief Private constructor that skips field validation (for internal use only)
* @details This constructor is used internally in contexts where validation has already been performed externally
* or where it is not required at all (e.g., 256-bit bitstrings).
*
* @tparam Builder
* @param lo Low LO_BITS of the scalar
* @param hi High HI_BITS of the scalar
* @param flag SkipValidation::FLAG explicitly indicates that validation should be skipped
*/
template <typename Builder>
cycle_scalar<Builder>::cycle_scalar(const field_t& _lo, const field_t& _hi, bool skip_validation)
: lo(_lo)
, hi(_hi)
cycle_scalar<Builder>::cycle_scalar(const field_t& lo, const field_t& hi, [[maybe_unused]] SkipValidation flag)
: _lo(lo)
, _hi(hi)
{}

/**
* @brief Construct a cycle_scalar from lo and hi field elements
* @details Standard public constructor. Validates that (lo + hi * 2^LO_BITS) is less than the Grumpkin scalar field
* modulus. Use this constructor when creating cycle_scalars from arbitrary field elements that may not have been
* previously validated.
*
* @warning The validation performed by this constructor is only sound if the resulting cycle_scalar is used in a
* scalar multiplication operation (batch_mul), which provides the necessary range constraints on lo and hi. See
* validate_scalar_is_in_field() documentation for details.
*
* @tparam Builder
* @param lo Low LO_BITS of the scalar
* @param hi High HI_BITS of the scalar
*/
template <typename Builder>
cycle_scalar<Builder>::cycle_scalar(const field_t& lo, const field_t& hi)
: _lo(lo)
, _hi(hi)
{
// Unless explicitly skipped, validate the scalar is in the Grumpkin scalar field
if (!skip_validation) {
validate_scalar_is_in_field();
}
validate_scalar_is_in_field();
}

/**
Expand All @@ -33,8 +60,8 @@ template <typename Builder> cycle_scalar<Builder>::cycle_scalar(const ScalarFiel
{
const uint256_t value(in);
const auto [lo_v, hi_v] = decompose_into_lo_hi_u256(value);
lo = lo_v;
hi = hi_v;
_lo = lo_v;
_hi = hi_v;
}

/**
Expand Down Expand Up @@ -81,7 +108,7 @@ cycle_scalar<Builder> cycle_scalar<Builder>::from_u256_witness(Builder* context,
const uint256_t hi_v = bitstring.slice(LO_BITS, num_bits);
auto lo = field_t::from_witness(context, typename field_t::native(lo_v));
auto hi = field_t::from_witness(context, typename field_t::native(hi_v));
cycle_scalar result{ lo, hi, /*skip_validation=*/true };
cycle_scalar result{ lo, hi, SkipValidation::FLAG };
result._num_bits = num_bits;
return result;
}
Expand All @@ -102,7 +129,7 @@ template <typename Builder> cycle_scalar<Builder> cycle_scalar<Builder>::create_
// Note: split_unique validates the value is less than bn254::fr::modulus
auto [lo, hi] = split_unique(in, LO_BITS, /*skip_range_constraints=*/true);
// Note: we skip validation here since it is redundant with `split_unique`
return cycle_scalar{ lo, hi, /*skip_validation=*/true };
return cycle_scalar{ lo, hi, SkipValidation::FLAG };
}

/**
Expand Down Expand Up @@ -147,10 +174,10 @@ template <typename Builder> cycle_scalar<Builder>::cycle_scalar(BigScalarField&
const uint256_t value((scalar.get_value() % uint512_t(ScalarField::modulus)).lo);
const auto [value_lo, value_hi] = decompose_into_lo_hi_u256(value);

lo = value_lo;
hi = value_hi;
lo.set_origin_tag(scalar.get_origin_tag());
hi.set_origin_tag(scalar.get_origin_tag());
_lo = value_lo;
_hi = value_hi;
_lo.set_origin_tag(scalar.get_origin_tag());
_hi.set_origin_tag(scalar.get_origin_tag());
return;
}

Expand Down Expand Up @@ -187,7 +214,7 @@ template <typename Builder> cycle_scalar<Builder>::cycle_scalar(BigScalarField&
BB_ASSERT_GT(NUM_LIMB_BITS * 2, LO_BITS);
BB_ASSERT_LT(NUM_LIMB_BITS, LO_BITS);

// Step 3: limb1 contributes to both *this.lo and *this.hi. Compute the values of the two limb1 slices
// Step 3: limb1 contributes to both *this._lo and *this._hi. Compute the values of the two limb1 slices
const size_t lo_bits_in_limb_1 = LO_BITS - NUM_LIMB_BITS;
const auto limb1_max_bits = static_cast<size_t>(limb1_max.get_msb() + 1);
auto [limb1_lo, limb1_hi] = limb1.no_wrap_split_at(lo_bits_in_limb_1, limb1_max_bits);
Expand All @@ -196,41 +223,50 @@ template <typename Builder> cycle_scalar<Builder>::cycle_scalar(BigScalarField&
limb1_lo.set_origin_tag(limb1.get_origin_tag());
limb1_hi.set_origin_tag(limb1.get_origin_tag());

// Step 4: Construct *this.lo out of limb0 and limb1_lo
lo = limb0 + (limb1_lo * BigScalarField::shift_1);
// Step 4: Construct *this._lo out of limb0 and limb1_lo
_lo = limb0 + (limb1_lo * BigScalarField::shift_1);

// Step 5: Construct *this.hi out of limb1_hi, limb2 and limb3
// Step 5: Construct *this._hi out of limb1_hi, limb2 and limb3
const uint256_t limb_2_shift = uint256_t(1) << ((2 * NUM_LIMB_BITS) - LO_BITS);
const uint256_t limb_3_shift = uint256_t(1) << ((3 * NUM_LIMB_BITS) - LO_BITS);
hi = limb1_hi.add_two(limb2 * limb_2_shift, limb3 * limb_3_shift);
_hi = limb1_hi.add_two(limb2 * limb_2_shift, limb3 * limb_3_shift);

// Manually propagate the origin tag of the scalar to the lo/hi limbs
lo.set_origin_tag(scalar.get_origin_tag());
hi.set_origin_tag(scalar.get_origin_tag());
_lo.set_origin_tag(scalar.get_origin_tag());
_hi.set_origin_tag(scalar.get_origin_tag());

validate_scalar_is_in_field();
};

template <typename Builder> bool cycle_scalar<Builder>::is_constant() const
{
return (lo.is_constant() && hi.is_constant());
return (_lo.is_constant() && _hi.is_constant());
}

/**
* @brief Validates that the scalar (lo + hi * 2^LO_BITS) is less than the Grumpkin scalar field modulus
* @details Delegates to `validate_split_in_field`
* @details Delegates to `validate_split_in_field_unsafe`, which uses a borrow-subtraction algorithm to check the
* inequality.
*
* @warning This validation assumes range constraints on the lo and hi limbs. Specifically:
* - lo < 2^LO_BITS (128 bits)
* - hi < 2^HI_BITS (126 bits)
*
* By design, these range constraints are not applied by this function. Instead, they are implicitly enforced when
* the cycle_scalar is used in scalar multiplication via batch_mul.
*
* @tparam Builder
*/
template <typename Builder> void cycle_scalar<Builder>::validate_scalar_is_in_field() const
{
validate_split_in_field(lo, hi, LO_BITS, ScalarField::modulus);
// Using _unsafe variant: range constraints are deferred to batch_mul's decompose_into_default_range
validate_split_in_field_unsafe(_lo, _hi, LO_BITS, ScalarField::modulus);
}

template <typename Builder> typename cycle_scalar<Builder>::ScalarField cycle_scalar<Builder>::get_value() const
{
uint256_t lo_v(lo.get_value());
uint256_t hi_v(hi.get_value());
uint256_t lo_v(_lo.get_value());
uint256_t hi_v(_hi.get_value());
return ScalarField(lo_v + (hi_v << LO_BITS));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ template <typename Builder> class cycle_group;
* @note The reason for not using `bigfield` to represent cycle scalars is that `bigfield` is inefficient in this
* context. All required range checks for `cycle_scalar` can be obtained for free from the `batch_mul` algorithm, making
* the range checks performed by `bigfield` largely redundant.
*
* @warning: The field validation performed by cycle_scalar constructors assumes that the lo/hi limbs will
* be range-constrained during scalar multiplication. The validation is ONLY sound when the cycle_scalar is used in a
* batch_mul operation (which applies range constraints as part of the MSM algorithm).
*/
template <typename Builder> class cycle_scalar {
public:
Expand All @@ -38,10 +42,11 @@ template <typename Builder> class cycle_scalar {
static constexpr size_t LO_BITS = field_t::native::Params::MAX_BITS_PER_ENDOMORPHISM_SCALAR;
static constexpr size_t HI_BITS = NUM_BITS - LO_BITS;

field_t lo; // LO_BITS of the scalar
field_t hi; // Remaining HI_BITS of the scalar
enum class SkipValidation { FLAG };

private:
field_t _lo; // LO_BITS of the scalar
field_t _hi; // Remaining HI_BITS of the scalar
size_t _num_bits = NUM_BITS;

/**
Expand All @@ -55,58 +60,60 @@ template <typename Builder> class cycle_scalar {
return { value.slice(0, LO_BITS), value.slice(LO_BITS, NUM_BITS) };
}

cycle_scalar(const field_t& lo, const field_t& hi, SkipValidation flag);

/**
* @brief Validates that the scalar (lo + hi * 2^LO_BITS) is less than the Grumpkin scalar field modulus
*/
void validate_scalar_is_in_field() const;

public:
// AUDITTODO: this is used only in the fuzzer.
cycle_scalar(const ScalarField& _in = 0);
cycle_scalar(const field_t& _lo, const field_t& _hi, bool skip_validation = false);
// AUDITTODO: this is used only in the fuzzer. Its not inherently problematic, but perhaps the fuzzer should use a
// production entrypoint.
cycle_scalar(const ScalarField& in = 0);
cycle_scalar(const field_t& lo, const field_t& hi);
static cycle_scalar from_witness(Builder* context, const ScalarField& value);
static cycle_scalar from_u256_witness(Builder* context, const uint256_t& bitstring);
static cycle_scalar create_from_bn254_scalar(const field_t& _in);
static cycle_scalar create_from_bn254_scalar(const field_t& in);
explicit cycle_scalar(BigScalarField& scalar);

[[nodiscard]] bool is_constant() const;
ScalarField get_value() const;
Builder* get_context() const { return lo.get_context() != nullptr ? lo.get_context() : hi.get_context(); }
Builder* get_context() const { return _lo.get_context() != nullptr ? _lo.get_context() : _hi.get_context(); }
[[nodiscard]] size_t num_bits() const { return _num_bits; }

/**
* @brief Validates that the scalar (lo + hi * 2^LO_BITS) is less than the Grumpkin scalar field modulus
*/
void validate_scalar_is_in_field() const;
const field_t& lo() const { return _lo; }
const field_t& hi() const { return _hi; }

/**
* @brief Get the origin tag of the cycle_scalar (a merge of the lo and hi tags)
*
* @return OriginTag
*/
OriginTag get_origin_tag() const { return OriginTag(lo.get_origin_tag(), hi.get_origin_tag()); }
OriginTag get_origin_tag() const { return OriginTag(_lo.get_origin_tag(), _hi.get_origin_tag()); }
/**
* @brief Set the origin tag of lo and hi members of cycle scalar
*
* @param tag
*/
void set_origin_tag(const OriginTag& tag) const
void set_origin_tag(const OriginTag& tag)
{
lo.set_origin_tag(tag);
hi.set_origin_tag(tag);
_lo.set_origin_tag(tag);
_hi.set_origin_tag(tag);
}
/**
* @brief Set the free witness flag for the cycle scalar's tags
*/
void set_free_witness_tag()
{
lo.set_free_witness_tag();
hi.set_free_witness_tag();
_lo.set_free_witness_tag();
_hi.set_free_witness_tag();
}
/**
* @brief Unset the free witness flag for the cycle scalar's tags
*/
void unset_free_witness_tag()
{
lo.unset_free_witness_tag();
hi.unset_free_witness_tag();
_lo.unset_free_witness_tag();
_hi.unset_free_witness_tag();
}
};

Expand Down
Loading
Loading