Skip to content

Commit a88a92b

Browse files
Faster Montgomery multiplication (#974)
This adds an interleaved Montgomery multiplication and further optimizes Boxed Montgomery multiplications by removing allocations. For BoxedMontyForm at 4096 bits this is around a 40% improvement overall. Benchmarked on a 2021 MacBook Pro with custom benchmarks for different limb counts. Signed-off-by: Andrew Whitehead <[email protected]>
1 parent 6619298 commit a88a92b

File tree

7 files changed

+386
-225
lines changed

7 files changed

+386
-225
lines changed

benches/monty.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -133,19 +133,6 @@ fn bench_montgomery_ops<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
133133
)
134134
});
135135

136-
group.bench_function("Bernstein-Yang invert, U256", |b| {
137-
b.iter_batched(
138-
|| {
139-
MontyForm::new(
140-
&U256::random_mod(&mut rng, params.modulus().as_nz_ref()),
141-
params,
142-
)
143-
},
144-
|x| black_box(x).invert(),
145-
BatchSize::SmallInput,
146-
)
147-
});
148-
149136
group.bench_function("multiplication, U256*U256", |b| {
150137
b.iter_batched(
151138
|| {
@@ -164,6 +151,19 @@ fn bench_montgomery_ops<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
164151
)
165152
});
166153

154+
group.bench_function("square, U256", |b| {
155+
b.iter_batched(
156+
|| {
157+
MontyForm::new(
158+
&U256::random_mod(&mut rng, params.modulus().as_nz_ref()),
159+
params,
160+
)
161+
},
162+
|x| x.square(),
163+
BatchSize::SmallInput,
164+
)
165+
});
166+
167167
group.bench_function("modpow, U256^U256", |b| {
168168
b.iter_batched(
169169
|| {

src/modular.rs

Lines changed: 161 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ pub trait Retrieve {
5555
#[cfg(test)]
5656
mod tests {
5757
use crate::{
58-
NonZero, U64, U256, Uint, const_monty_params,
58+
NonZero, U64, U128, U256, Uint, const_monty_params,
5959
modular::{
6060
const_monty_form::{ConstMontyForm, ConstMontyParams},
61+
mul::{mul_montgomery_form, square_montgomery_form},
6162
reduction::montgomery_reduction,
6263
},
6364
};
@@ -84,8 +85,10 @@ mod tests {
8485
);
8586
}
8687

88+
const_monty_params!(Modulus128, U128, "000000087b57be17f0ecdbf18a227bd9");
89+
8790
const_monty_params!(
88-
Modulus2,
91+
Modulus256,
8992
U256,
9093
"ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
9194
);
@@ -94,10 +97,10 @@ mod tests {
9497
fn test_reducing_one() {
9598
// Divide the value R by R, which should equal 1
9699
assert_eq!(
97-
montgomery_reduction::<{ Modulus2::LIMBS }>(
98-
&(Modulus2::PARAMS.one, Uint::ZERO),
99-
&Modulus2::PARAMS.modulus,
100-
Modulus2::PARAMS.mod_neg_inv()
100+
montgomery_reduction::<{ Modulus256::LIMBS }>(
101+
&(Modulus256::PARAMS.one, Uint::ZERO),
102+
&Modulus256::PARAMS.modulus,
103+
Modulus256::PARAMS.mod_neg_inv()
101104
),
102105
Uint::ONE
103106
);
@@ -107,26 +110,26 @@ mod tests {
107110
fn test_reducing_r2() {
108111
// Divide the value R^2 by R, which should equal R
109112
assert_eq!(
110-
montgomery_reduction::<{ Modulus2::LIMBS }>(
111-
&(Modulus2::PARAMS.r2, Uint::ZERO),
112-
&Modulus2::PARAMS.modulus,
113-
Modulus2::PARAMS.mod_neg_inv()
113+
montgomery_reduction::<{ Modulus256::LIMBS }>(
114+
&(Modulus256::PARAMS.r2, Uint::ZERO),
115+
&Modulus256::PARAMS.modulus,
116+
Modulus256::PARAMS.mod_neg_inv()
114117
),
115-
Modulus2::PARAMS.one
118+
Modulus256::PARAMS.one
116119
);
117120
}
118121

119122
#[test]
120123
fn test_reducing_r2_wide() {
121124
// Divide the value ONE^2 by R, which should equal ONE
122-
let (lo, hi) = Modulus2::PARAMS.one.square().split();
125+
let (lo, hi) = Modulus256::PARAMS.one.square().split();
123126
assert_eq!(
124-
montgomery_reduction::<{ Modulus2::LIMBS }>(
127+
montgomery_reduction::<{ Modulus256::LIMBS }>(
125128
&(lo, hi),
126-
&Modulus2::PARAMS.modulus,
127-
Modulus2::PARAMS.mod_neg_inv()
129+
&Modulus256::PARAMS.modulus,
130+
Modulus256::PARAMS.mod_neg_inv()
128131
),
129-
Modulus2::PARAMS.one
132+
Modulus256::PARAMS.one
130133
);
131134
}
132135

@@ -135,12 +138,12 @@ mod tests {
135138
// Reducing xR should return x
136139
let x =
137140
U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");
138-
let product = x.widening_mul(&Modulus2::PARAMS.one);
141+
let product = x.widening_mul(&Modulus256::PARAMS.one);
139142
assert_eq!(
140-
montgomery_reduction::<{ Modulus2::LIMBS }>(
143+
montgomery_reduction::<{ Modulus256::LIMBS }>(
141144
&product,
142-
&Modulus2::PARAMS.modulus,
143-
Modulus2::PARAMS.mod_neg_inv()
145+
&Modulus256::PARAMS.modulus,
146+
Modulus256::PARAMS.mod_neg_inv()
144147
),
145148
x
146149
);
@@ -151,21 +154,152 @@ mod tests {
151154
// Reducing xR^2 should return xR
152155
let x =
153156
U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");
154-
let product = x.widening_mul(&Modulus2::PARAMS.r2);
157+
let product = x.widening_mul(&Modulus256::PARAMS.r2);
155158

156159
// Computing xR mod modulus without Montgomery reduction
157-
let (lo, hi) = x.widening_mul(&Modulus2::PARAMS.one);
160+
let (lo, hi) = x.widening_mul(&Modulus256::PARAMS.one);
158161
let c = lo.concat(&hi);
159162
let red =
160-
c.rem_vartime(&NonZero::new(Modulus2::PARAMS.modulus.0.concat(&U256::ZERO)).unwrap());
163+
c.rem_vartime(&NonZero::new(Modulus256::PARAMS.modulus.0.concat(&U256::ZERO)).unwrap());
161164
let (lo, hi) = red.split();
162165
assert_eq!(hi, Uint::ZERO);
163166

164167
assert_eq!(
165-
montgomery_reduction::<{ Modulus2::LIMBS }>(
168+
montgomery_reduction::<{ Modulus256::LIMBS }>(
166169
&product,
167-
&Modulus2::PARAMS.modulus,
168-
Modulus2::PARAMS.mod_neg_inv()
170+
&Modulus256::PARAMS.modulus,
171+
Modulus256::PARAMS.mod_neg_inv()
172+
),
173+
lo
174+
);
175+
}
176+
177+
#[test]
178+
fn monty_mul_one_r() {
179+
// Multiply 1 by R and divide by R, which should equal 1
180+
assert_eq!(
181+
mul_montgomery_form::<{ Modulus128::LIMBS }>(
182+
&Uint::ONE,
183+
&Modulus128::PARAMS.one,
184+
&Modulus128::PARAMS.modulus,
185+
Modulus128::PARAMS.mod_neg_inv()
186+
),
187+
Uint::ONE
188+
);
189+
assert_eq!(
190+
mul_montgomery_form::<{ Modulus256::LIMBS }>(
191+
&Uint::ONE,
192+
&Modulus256::PARAMS.one,
193+
&Modulus256::PARAMS.modulus,
194+
Modulus256::PARAMS.mod_neg_inv()
195+
),
196+
Uint::ONE
197+
);
198+
}
199+
200+
#[test]
201+
fn monty_mul_r_r() {
202+
// Multiply R by R and divide by R, which should equal R
203+
assert_eq!(
204+
mul_montgomery_form::<{ Modulus128::LIMBS }>(
205+
&Modulus128::PARAMS.one,
206+
&Modulus128::PARAMS.one,
207+
&Modulus128::PARAMS.modulus,
208+
Modulus128::PARAMS.mod_neg_inv()
209+
),
210+
Modulus128::PARAMS.one
211+
);
212+
assert_eq!(
213+
mul_montgomery_form::<{ Modulus256::LIMBS }>(
214+
&Modulus256::PARAMS.one,
215+
&Modulus256::PARAMS.one,
216+
&Modulus256::PARAMS.modulus,
217+
Modulus256::PARAMS.mod_neg_inv()
218+
),
219+
Modulus256::PARAMS.one
220+
);
221+
}
222+
223+
#[test]
224+
fn monty_square_r() {
225+
// Square R and divide by R, which should equal R
226+
assert_eq!(
227+
square_montgomery_form::<{ Modulus128::LIMBS }>(
228+
&Modulus128::PARAMS.one,
229+
&Modulus128::PARAMS.modulus,
230+
Modulus128::PARAMS.mod_neg_inv()
231+
),
232+
Modulus128::PARAMS.one
233+
);
234+
assert_eq!(
235+
square_montgomery_form::<{ Modulus256::LIMBS }>(
236+
&Modulus256::PARAMS.one,
237+
&Modulus256::PARAMS.modulus,
238+
Modulus256::PARAMS.mod_neg_inv()
239+
),
240+
Modulus256::PARAMS.one
241+
);
242+
}
243+
244+
#[test]
245+
fn monty_mul_r2() {
246+
// Multiply 1 by R2 and divide by R, which should equal R
247+
assert_eq!(
248+
mul_montgomery_form::<{ Modulus128::LIMBS }>(
249+
&Uint::ONE,
250+
&Modulus128::PARAMS.r2,
251+
&Modulus128::PARAMS.modulus,
252+
Modulus128::PARAMS.mod_neg_inv()
253+
),
254+
Modulus128::PARAMS.one
255+
);
256+
assert_eq!(
257+
mul_montgomery_form::<{ Modulus256::LIMBS }>(
258+
&Uint::ONE,
259+
&Modulus256::PARAMS.r2,
260+
&Modulus256::PARAMS.modulus,
261+
Modulus256::PARAMS.mod_neg_inv()
262+
),
263+
Modulus256::PARAMS.one
264+
);
265+
}
266+
267+
#[test]
268+
fn monty_mul_xr() {
269+
// Reducing xR should return x
270+
let x =
271+
U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");
272+
assert_eq!(
273+
mul_montgomery_form::<{ Modulus256::LIMBS }>(
274+
&x,
275+
&Modulus256::PARAMS.one,
276+
&Modulus256::PARAMS.modulus,
277+
Modulus256::PARAMS.mod_neg_inv()
278+
),
279+
x
280+
);
281+
}
282+
283+
#[test]
284+
fn monty_mul_xr2() {
285+
let x =
286+
U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");
287+
288+
// Computing xR mod modulus without Montgomery reduction
289+
let (lo, hi) = x.widening_mul(&Modulus256::PARAMS.one);
290+
let c = lo.concat(&hi);
291+
let red =
292+
c.rem_vartime(&NonZero::new(Modulus256::PARAMS.modulus.0.concat(&U256::ZERO)).unwrap());
293+
let (lo, hi) = red.split();
294+
assert_eq!(hi, Uint::ZERO);
295+
296+
// Reducing xR^2 should return xR
297+
assert_eq!(
298+
mul_montgomery_form::<{ Modulus256::LIMBS }>(
299+
&x,
300+
&Modulus256::PARAMS.r2,
301+
&Modulus256::PARAMS.modulus,
302+
Modulus256::PARAMS.mod_neg_inv()
169303
),
170304
lo
171305
);
@@ -175,7 +309,7 @@ mod tests {
175309
fn test_new_retrieve() {
176310
let x =
177311
U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");
178-
let x_mod = ConstMontyForm::<Modulus2, { Modulus2::LIMBS }>::new(&x);
312+
let x_mod = ConstMontyForm::<Modulus256, { Modulus256::LIMBS }>::new(&x);
179313

180314
// Confirm that when creating a Modular and retrieving the value, that it equals the original
181315
assert_eq!(x, x_mod.retrieve());

0 commit comments

Comments
 (0)