Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion halo2-ecc/src/ecc/ecdsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ where
u2.limbs().to_vec(),
base_chip.limb_bits,
var_window_bits,
true, // we can call it with scalar_is_safe = true because of the u2_small check below
);

// check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey
Expand All @@ -77,7 +78,6 @@ where
let x1 = scalar_chip.enforce_less_than(ctx, sum.x);
let equal_check = big_is_equal::assign(base_chip.gate(), ctx, x1.0, r);

// TODO: maybe the big_less_than is optional?
let u1_small = big_less_than::assign(
base_chip.range(),
ctx,
Expand Down
34 changes: 23 additions & 11 deletions halo2-ecc/src/ecc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ where
// y_3 = lambda (x - x_3) - y (mod p)
/// # Assumptions
/// * `P.y != 0`
/// * `P` is not the point at infinity
/// * `P` is not the point at infinity (undefined behavior otherwise)
pub fn ec_double<F: PrimeField, FC: FieldChip<F>>(
chip: &FC,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -463,14 +463,15 @@ where
StrictEcPoint::new(x, y)
}

// computes [scalar] * P on short Weierstrass curve `y^2 = x^3 + b`
// - `scalar` is represented as a reference array of `AssignedValue`s
// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
// assumes:
/// Computes `[scalar] * P` on short Weierstrass curve `y^2 = x^3 + b`
/// - `scalar` is represented as a reference array of `AssignedValue`s
/// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F`
///
/// # Assumptions
/// * `P` is not the point at infinity
/// * `scalar` is less than the order of `P`
/// * `scalar > 0`
/// * If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`)
/// * `scalar_i < 2^{max_bits} for all i`
/// * `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F`
pub fn scalar_multiply<F: PrimeField, FC>(
Expand All @@ -480,6 +481,7 @@ pub fn scalar_multiply<F: PrimeField, FC>(
scalar: Vec<AssignedValue<F>>,
max_bits: usize,
window_bits: usize,
scalar_is_safe: bool,
) -> EcPoint<F, FC::FieldPoint>
where
FC: FieldChip<F> + Selectable<F, FC::FieldPoint>,
Expand Down Expand Up @@ -530,7 +532,7 @@ where
let double = ec_double(chip, ctx, &P);
cached_points.push(double);
} else {
let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false);
let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe);
cached_points.push(new_point);
}
}
Expand All @@ -555,7 +557,7 @@ where
&rounded_bits
[rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx],
);
let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, false);
let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe);
let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]);

curr_point =
Expand Down Expand Up @@ -688,7 +690,7 @@ where
ctx,
&rand_start_vec[idx],
&rand_start_vec[idx + window_bits],
false,
true, // not necessary if we assume (2^w - 1) * A != +- A, but put in for safety
);
let point = into_strict_point(chip, ctx, point.clone());
let neg_mult_rand_start = into_strict_point(chip, ctx, neg_mult_rand_start);
Expand Down Expand Up @@ -1002,15 +1004,25 @@ where
ec_select(self.field_chip, ctx, P, Q, condition)
}

/// See [`scalar_multiply`] for more details.
pub fn scalar_mult(
&self,
ctx: &mut Context<F>,
P: EcPoint<F, FC::FieldPoint>,
scalar: Vec<AssignedValue<F>>,
max_bits: usize,
window_bits: usize,
scalar_is_safe: bool,
) -> EcPoint<F, FC::FieldPoint> {
scalar_multiply::<F, FC>(self.field_chip, ctx, P, scalar, max_bits, window_bits)
scalar_multiply::<F, FC>(
self.field_chip,
ctx,
P,
scalar,
max_bits,
window_bits,
scalar_is_safe,
)
}

// default for most purposes
Expand Down
6 changes: 2 additions & 4 deletions halo2-ecc/src/ecc/pippenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ where
/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
/// * `points` are all on the curve or the point at infinity
/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point)
/// * `2^max_scalar_bits != +-1 mod modulus::<F>()` where `max_scalar_bits = max_scalar_bits_per_cell * scalars[0].len()`
/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point
pub fn multi_exp_par<F: PrimeField, FC, C>(
chip: &FC,
Expand Down Expand Up @@ -337,7 +336,7 @@ where
// let any_point = (2^num_rounds - 1) * any_base
// TODO: can we remove all these random point operations somehow?
let mut any_point = ec_double(chip, ctx, any_points.last().unwrap());
any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], false);
any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true);

// compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point
// (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1)
Expand All @@ -351,8 +350,7 @@ where
}

any_sum = ec_double(chip, ctx, any_sum);
// assume 2^scalar_bits != +-1 mod modulus::<F>()
any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false);
any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true);

ec_sub_strict(chip, ctx, sum, any_sum)
}