-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathntt.c
292 lines (262 loc) · 9.56 KB
/
ntt.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
/**
@file ntt.c
*/
#include "ntt.h"
#include <stdio.h>
#include "defines.h"
#include "fft.h"
#include "fileops.h"
#include "parameters.h"
#include "polymodarith.h"
#include "uintmodarith.h"
#include "util_print.h"
#if defined(SE_NTT_OTF) || defined(SE_NTT_ONE_SHOT)
// Forward declaration of the root lookup helper. It is only compiled for the compute-type NTT
// variants ("on-the-fly" and "one-shot"), which derive their roots from a single first power
// instead of loading a precomputed table. The definition is further down in this file.
ZZ get_ntt_root(size_t n, ZZ q); // defined below
#endif
/**
Initializes the NTT roots buffer according to the NTT variant selected at compile time:

- SE_NTT_OTF:      nothing is precomputed; 'parms' and 'ntt_roots' are ignored.
- SE_NTT_ONE_SHOT: computes successive powers of the root returned by get_ntt_root and stores
                   the i-th power at index bitrev(i, logn).
- SE_NTT_FAST:     loads the roots with load_ntt_fast_roots ('ntt_roots' is reinterpreted as
                   MUMO *).
- SE_NTT_REG:      loads the roots with load_ntt_roots.

If SE_REVERSE_CT_GEN_ENABLED is defined and parms->skip_ntt_load is set, the function returns
without touching 'ntt_roots'.

@param[in]  parms      Parameters. Must be non-null with curr_modulus set, unless SE_NTT_OTF
@param[out] ntt_roots  Buffer that receives the NTT roots. Ignored if SE_NTT_OTF is defined
*/
void ntt_roots_initialize(const Parms *parms, ZZ *ntt_roots)
{
#ifdef SE_REVERSE_CT_GEN_ENABLED
    // -- Validate 'parms' before it is dereferenced for the skip flag
    se_assert(parms);
    SE_UNUSED(ntt_roots);
    if (parms->skip_ntt_load) { return; }
#endif

#ifdef SE_NTT_OTF
    // -- Roots are generated inside the transform itself, so there is nothing to set up
    SE_UNUSED(parms);
    SE_UNUSED(ntt_roots);
#else
    se_assert(parms && parms->curr_modulus && ntt_roots);

#ifdef SE_NTT_ONE_SHOT
    size_t n     = parms->coeff_count;
    size_t logn  = parms->logn;
    Modulus *mod = parms->curr_modulus;
    ZZ root      = get_ntt_root(n, mod->value);
    ZZ power     = root;

    ntt_roots[0] = 1;  // Not necessary but set anyway
    for (size_t i = 1; i < n; i++)
    {
        // -- Store root^i at the bit-reversed position of i
        ntt_roots[bitrev(i, logn)] = power;
        power                      = mul_mod(power, root, mod);
    }
#elif defined(SE_NTT_FAST)
    load_ntt_fast_roots(parms, (MUMO *)ntt_roots);
#elif defined(SE_NTT_REG)
    load_ntt_roots(parms, ntt_roots);
#else
    // -- No NTT variant was selected at compile time
    se_assert(0);
#endif
#endif  // SE_NTT_OTF
}
#ifdef SE_NTT_FAST
/**
Performs a "fast" (a.k.a. "lazy") negacyclic in-place NTT using the Harvey butterfly. Only used if
"SE_NTT_FAST" is defined. See SEAL_Embedded paper for a more detailed description. "Lazy"-ness
refers to the fact that values are not fully reduced modulo q inside the transform; the final
reduction is performed by the caller (see ntt_inpl).

The result is returned in scrambled order.

@param[in]     parms           Parameters set by ckks_setup
@param[in]     ntt_fast_roots  NTT roots set by ntt_roots_initialize
@param[in,out] vec             Input/output polynomial of n ZZ elements
*/
void ntt_lazy_inpl(const Parms *parms, const MUMO *ntt_fast_roots, ZZ *vec)
{
    // -- curr_modulus is dereferenced below, so check it along with the other pointers
    se_assert(parms && parms->curr_modulus && ntt_fast_roots && vec);

    size_t n     = parms->coeff_count;
    size_t logn  = parms->logn;
    Modulus *mod = parms->curr_modulus;
    ZZ two_q     = mod->value << 1;

    // -- Return the NTT in scrambled order
    size_t h  = 1;      // number of groups in the current round
    size_t tt = n / 2;  // distance between the two elements of a pair

    for (size_t i = 0; i < logn; i++, h *= 2, tt /= 2)  // Rounds
    {
        for (size_t j = 0, kstart = 0; j < h; j++, kstart += 2 * tt)  // Groups
        {
            // -- One root per group, read from index h + j
            const MUMO *s = &(ntt_fast_roots[h + j]);

            // -- The Harvey butterfly. Values written back to vec[k], vec[k+tt] stay
            //    in [0, 4q) (per the original author's bound).
            for (size_t k = kstart; k < (kstart + tt); k++)  // Pairs
            {
                ZZ val1 = vec[k];
                ZZ val2 = vec[k + tt];

                // -- Branch-free conditional subtraction of 2q: the mask is all ones when
                //    val1 >= 2q and zero otherwise (assumes ZZ is unsigned and ZZsign is
                //    its signed counterpart)
                ZZ u = val1 - (two_q & (ZZ)(-(ZZsign)(val1 >= two_q)));
                ZZ v = mul_mod_mumo_lazy(val2, s, mod);

                // -- We know these will not generate carries/overflows
                vec[k]      = u + v;
                vec[k + tt] = u + two_q - v;
            }
        }
    }
}
#else
/**
Performs a negacyclic in-place NTT in which every butterfly goes through mul_mod / add_mod /
sub_mod (as opposed to the lazy arithmetic of ntt_lazy_inpl). Only used if SE_NTT_FAST is not
defined.

If SE_NTT_REG or SE_NTT_ONE_SHOT is defined, will use regular NTT computation with the roots in
'ntt_roots'.
Else, (SE_NTT_OTF is defined), will use truly "on-the-fly" NTT computation: each root is
recomputed from the first power returned by get_ntt_root. In this last case, 'ntt_roots' may be
null (and will be ignored).

The result is returned in scrambled order.

@param[in]     parms      Parameters set by ckks_setup
@param[in]     ntt_roots  NTT roots set by ntt_roots_initialize. Ignored if SE_NTT_OTF is defined.
@param[in,out] vec        Input/output polynomial of n ZZ elements
*/
void ntt_non_lazy_inpl(const Parms *parms, const ZZ *ntt_roots, ZZ *vec)
{
    se_assert(parms && parms->curr_modulus && vec);

    size_t n     = parms->coeff_count;
    size_t logn  = parms->logn;
    Modulus *mod = parms->curr_modulus;

#ifdef SE_NTT_OTF
    SE_UNUSED(ntt_roots);
    ZZ root = get_ntt_root(n, mod->value);
#else
    // -- The roots table is required here; check it once instead of once per group
    se_assert(ntt_roots);
#endif

    // -- Return the NTT in scrambled order
    size_t h  = 1;      // number of groups in the current round
    size_t tt = n / 2;  // distance between the two elements of a pair

    for (size_t i = 0; i < logn; i++, h *= 2, tt /= 2)  // rounds
    {
        for (size_t j = 0, kstart = 0; j < h; j++, kstart += 2 * tt)  // groups
        {
#ifdef SE_NTT_OTF
            // -- Recompute the root for index h + j from the first power
            ZZ power = (ZZ)(h + j);
            ZZ s     = exponentiate_uint_mod_bitrev(root, power, logn, mod);
#else
            ZZ s = ntt_roots[h + j];
#endif
            for (size_t k = kstart; k < (kstart + tt); k++)  // pairs
            {
                ZZ u        = vec[k];
                ZZ v        = mul_mod(vec[k + tt], s, mod);
                vec[k]      = add_mod(u, v, mod);  // vec[k]    = u + v (mod q)
                vec[k + tt] = sub_mod(u, v, mod);  // vec[k+tt] = u - v (mod q)
            }
        }
    }
}
#endif
/**
Performs a negacyclic in-place NTT on 'vec', using the implementation selected at compile time.

If SE_NTT_FAST is defined, runs ntt_lazy_inpl ('ntt_roots' is reinterpreted as an array of MUMO)
and then reduces every coefficient modulo q. Otherwise forwards to ntt_non_lazy_inpl.

@param[in]     parms      Parameters set by ckks_setup
@param[in]     ntt_roots  NTT roots set by ntt_roots_initialize. Required if SE_NTT_FAST is defined
@param[in,out] vec        Input/output polynomial of n ZZ elements
*/
void ntt_inpl(const Parms *parms, const ZZ *ntt_roots, ZZ *vec)
{
    se_assert(parms && parms->curr_modulus && vec);

#ifdef SE_NTT_FAST
    se_assert(ntt_roots);

    // -- Keep the const qualifier when reinterpreting the roots buffer
    ntt_lazy_inpl(parms, (const MUMO *)ntt_roots, vec);

    // -- Finally, we might need to reduce coefficients modulo q, but we know each
    //    coefficient is in the range [0, 4q). Since word size is controlled, this
    //    should be fast: at most one subtraction of 2q followed by one of q.
    ZZ q     = parms->curr_modulus->value;
    ZZ two_q = q << 1;
    for (size_t i = 0; i < parms->coeff_count; i++)
    {
        if (vec[i] >= two_q) { vec[i] -= two_q; }
        if (vec[i] >= q) { vec[i] -= q; }
    }
#else
    ntt_non_lazy_inpl(parms, ntt_roots, vec);
#endif
}
#if defined(SE_NTT_OTF) || defined(SE_NTT_ONE_SHOT)
/**
Reports a missing table entry for get_ntt_root and terminates the program.

@param[in] msg  Complete error line to print (including the trailing newline)
@param[in] q    Modulus value that has no known root
*/
static void ntt_root_missing(const char *msg, ZZ q)
{
    printf("%s", msg);
    print_zz("Modulus value", q);
    exit(1);
}

/**
Helper function to return root for certain modulus prime values if SE_NTT_OTF or SE_NTT_ONE_SHOT is
used. Implemented as a table lookup.

If the (n, q) pair is not in the table, an error is printed and the program exits with status 1.
This holds for every supported n, so an unknown modulus can never silently yield a wrong root.

@param[in] n  Transform size (i.e. polynomial ring degree)
@param[in] q  Modulus value
@returns      First power of the NTT root for the given n and q
*/
ZZ get_ntt_root(size_t n, ZZ q)
{
    /**
    For custom primes:
      Add cases for custom primes at the indicated locations.
      Note that this is only required if:
         1) parameters (specifically, primes) desired are of custom (i.e., non-default) type
         2) NTT type is of compute ("on-the-fly" or "one-shot") type
      If any of the above conditions are false, no modifications to this file are needed.
      Note: wolframalpha is a good tool to use to calculate constants.
    */
    ZZ root = 0;  // i.e. w = first power of NTT root
    switch (n)
    {
        case 1024:
            switch (q)
            {
                // -- Add cases for custom primes here
                case 134012929: root = 142143; break;  // 27 bit
                default:
                    ntt_root_missing("Error! Need first power of root for ntt, n = 1K\n", q);
                    break;
            }
            break;
        case 2048:
            switch (q)
            {
                // -- Add cases for custom primes here
                case 134012929: root = 85250; break;  // 27 bit
                default:
                    ntt_root_missing("Error! Need first power of root for ntt, n = 2K\n", q);
                    break;
            }
            break;
        case 4096:
            switch (q)
            {
                // -- Add cases for custom primes here
                case 134012929: root = 7470; break;     // 27 bit
                case 134111233: root = 3856; break;     // 27 bit
                case 134176769: root = 24149; break;    // 27 bit
                case 1053818881: root = 503422; break;  // 30 bit
                case 1054015489: root = 16768; break;   // 30 bit
                case 1054212097: root = 7305; break;    // 30 bit
                default:
                    ntt_root_missing("Error! Need first power of root for ntt, n = 4K\n", q);
                    break;
            }
            break;
        case 8192:
            switch (q)
            {
                // -- Add cases for custom primes here
                case 1053818881: root = 374229; break;
                case 1054015489: root = 123363; break;
                case 1054212097: root = 79941; break;
                case 1055260673: root = 38869; break;
                case 1056178177: root = 162146; break;
                case 1056440321: root = 81884; break;
                default:
                    ntt_root_missing("Error! Need first power of root for ntt, n = 8K\n", q);
                    break;
            }
            break;
        case 16384:  // TODO: ADD A FLAG TO TURN THESE OFF?
            switch (q)
            {
                // -- Add cases for custom primes here
                case 1053818881: root = 13040; break;
                case 1054015489: root = 507; break;
                case 1054212097: root = 1595; break;
                case 1055260673: root = 68507; break;
                case 1056178177: root = 3073; break;
                case 1056440321: root = 6854; break;
                case 1058209793: root = 44467; break;
                case 1060175873: root = 16117; break;
                case 1060700161: root = 27607; break;
                case 1060765697: root = 222391; break;
                case 1061093377: root = 105471; break;
                case 1062469633: root = 310222; break;
                case 1062535169: root = 2005; break;
                default:
                    ntt_root_missing("Error! Need first power of root for ntt, n = 16K\n", q);
                    break;
            }
            break;
        default:
            // -- Unsupported transform size
            ntt_root_missing("Error! Need first power of root for ntt\n", q);
            break;
    }
    return root;
}
#endif