diff --git a/src/apps/monero/xmr/bulletproof.py b/src/apps/monero/xmr/bulletproof.py index 0a09a11df..cf7858830 100644 --- a/src/apps/monero/xmr/bulletproof.py +++ b/src/apps/monero/xmr/bulletproof.py @@ -44,7 +44,9 @@ # -# Rct keys operation +# Rct keys operations +# tmp_x are global working registers to minimize memory allocations / heap fragmentation. +# Caution has to be exercised when using the registers and operations using the registers # tmp_bf_1 = bytearray(32) @@ -73,6 +75,14 @@ def _ensure_dst_key(dst=None): return dst +def alloc_keys(num=1): + return (_ensure_dst_key() for _ in range(num)) + + +def alloc_scalars(num=1): + return (crypto.new_scalar() for _ in range(num)) + + def copy_key(dst, src): for i in range(32): dst[i] = src[i] @@ -388,9 +398,11 @@ def copy(self, dst=None): dst = KeyV(src=self) return dst - def resize(self, nsize, chop=False): + def resize(self, nsize, chop=False, realloc=False): if self.size == nsize: return self + elif self.size > nsize and realloc: + self.d = bytearray(self.d[: nsize * 32]) elif self.size > nsize and not chop: self.d = self.d[: nsize * 32] else: @@ -471,7 +483,11 @@ def resize(self, nsize, chop=False): class KeyVPowers(KeyVBase): - def __init__(self, size, x): + """ + Vector of x^i. Allows only sequential access (no jumping). Resets on [0,1] access. + """ + + def __init__(self, size, x, **kwargs): super().__init__(size) self.x = x self.cur = bytearray(32) @@ -492,6 +508,31 @@ def __getitem__(self, item): IndexError("Only linear scan allowed") +class KeyVZtwo(KeyVBase): + """ + Ztwo vector - see vector_z_two_i + """ + + def __init__(self, N, logN, M, zpow, twoN, raw=False): + super().__init__(N * M) + self.N = N + self.logN = logN + self.M = M + self.zpow = zpow + self.twoN = twoN + self.raw = raw + self.sc = crypto.new_scalar() + self.cur = bytearray(32) if not raw else None + + def __getitem__(self, item): + vector_z_two_i(self.logN, self.zpow, self.twoN, self.idxize(item), self.sc) + if self.raw: + return self.sc + + crypto.encodeint_into(self.sc, self.cur) + return self.cur + + def _ensure_dst_keyvect(dst=None, size=None): if dst is None: dst = KeyV(elems=size) @@ -519,7 +560,9 @@ def vector_exponent_custom(A, B, a, b, dst=None): return dst -def vector_powers(x, n, dst=None): +def vector_powers(x, n, dst=None, dynamic=False, **kwargs): + if dynamic: + return KeyVPowers(n, x, **kwargs) dst = _ensure_dst_keyvect(dst, n) if n == 0: return dst @@ -577,6 +620,84 @@ def hadamard2(a, b, dst=None): return dst +def hadamard_fold(v, a, b): + """ + Folds a curvepoint array using a two way scaled Hadamard product + + ln = len(v); h = ln // 2 + v[i] = a * v[i] + b * v[h + i] + + :param v: + :param a: + :param b: + :return: + """ + h = len(v) // 2 + crypto.decodeint_into_noreduce(tmp_sc_1, a) + crypto.decodeint_into_noreduce(tmp_sc_2, b) + for i in range(h): + crypto.decodepoint_into(tmp_pt_1, v[i]) + crypto.decodepoint_into(tmp_pt_2, v[h + i]) + crypto.add_keys3_into(tmp_pt_3, tmp_sc_1, tmp_pt_1, tmp_sc_2, tmp_pt_2) + crypto.encodepoint_into(tmp_pt_3, v[i]) + v.resize(h, realloc=True) + return v + + +def scalar_fold(v, a, b): + """ + ln = len(v); h = ln // 2 + v[i] = v[i] * a + v[h+i] * b) + + :param v: + :param a: + :param b: + :return: + """ + h = len(v) // 2 + crypto.decodeint_into_noreduce(tmp_sc_1, a) + crypto.decodeint_into_noreduce(tmp_sc_2, b) + + for i in range(h): + crypto.decodeint_into_noreduce(tmp_sc_3, v[i]) + crypto.decodeint_into_noreduce(tmp_sc_4, v[i + h]) + crypto.sc_mul_into(tmp_sc_3, tmp_sc_3, tmp_sc_1) + crypto.sc_mul_into(tmp_sc_4, tmp_sc_4, tmp_sc_2) + crypto.sc_add_into(tmp_sc_3, tmp_sc_3, tmp_sc_4) + crypto.encodeint_into(tmp_sc_3, v[i]) + v.resize(h, realloc=True) + return v + + +def cross_inner_product(l0, r0, l1, r1): + """ + t1_1 = l0 . r1, t1_2 = l1 . r0 + t1 = t1_1 + t1_2, t2 = l1 . r1 + """ + sc_t1_1, sc_t1_2, sc_t2 = alloc_scalars(3) + cl0, cr0, cl1, cr1 = alloc_scalars(4) + + for i in range(len(l0)): + crypto.decodeint_into_noreduce(cl0, l0[i]) + crypto.decodeint_into_noreduce(cr0, r0[i]) + crypto.decodeint_into_noreduce(cl1, l1[i]) + crypto.decodeint_into_noreduce(cr1, r1[i]) + + crypto.sc_muladd_into(sc_t1_1, cl0, cr1, sc_t1_1) + crypto.sc_muladd_into(sc_t1_2, cl1, cr0, sc_t1_2) + crypto.sc_muladd_into(sc_t2, cl1, cr1, sc_t2) + + crypto.sc_add_into(sc_t1_1, sc_t1_1, sc_t1_2) + return crypto.encodeint(sc_t1_1), crypto.encodeint(sc_t2) + + +def vector_gen(dst, size, op): + dst = _ensure_dst_keyvect(dst, size) + for i in range(size): + op(i, dst[i]) + return dst + + def vector_add(a, b, dst=None): dst = _ensure_dst_keyvect(dst, len(a)) for i in range(len(a)): @@ -620,6 +741,35 @@ def vector_sum(a, dst=None): return dst +def vector_z_two_i(logN, zpow, twoN, i, dst_sc=None): + """ + 0...N|N+1...2N|2N+1...3N|.... + zt[i] = z^b 2^c, where + b = 2 + blockNumber. BlockNumber is idx of N block + c = i % N = i - N * blockNumber + """ + j = i >> logN + crypto.decodeint_into_noreduce(tmp_sc_1, zpow[j + 2]) + crypto.decodeint_into_noreduce(tmp_sc_2, twoN[i & ((1 << logN) - 1)]) + crypto.sc_mul_into(dst_sc, tmp_sc_1, tmp_sc_2) + + +def vector_z_two(N, logN, M, zpow, twoN, zero_twos=None, dynamic=False, **kwargs): + if dynamic: + return KeyVZtwo(N, logN, M, zpow, twoN, **kwargs) + + # Original algorithm from Monero + zero_twos = _ensure_dst_keyvect(zero_twos, M * N) + for i in range(M * N): + zero_twos[i] = ZERO + for j in range(1, M + 1): + if i >= (j - 1) * N and i < j * N: + sc_muladd( + zero_twos[i], zpow[1 + j], twoN[i - (j - 1) * N], zero_twos[i] + ) + return zero_twos + + def hash_cache_mash(dst, hash_cache, *args): dst = _ensure_dst_key(dst) ctx = crypto.get_keccak() @@ -653,6 +803,7 @@ class MultiExpEval(object): Moreover, Monero needs speed for very fast verification for blockchain verification which is not priority in this use case. """ + def __init__(self, size=None): self.size = size if size else None @@ -686,7 +837,7 @@ class MultiExp(MultiExpEval): """ def __init__( - self, size=None, scalars=None, points=None, scalar_fnc=None, point_fnc=None + self, size=None, scalars=None, points=None, scalar_fnc=None, point_fnc=None ): super().__init__(size) self.current_idx = 0 @@ -712,10 +863,16 @@ def add_scalar(self, scalar): self.size = len(self.scalars) def get_scalar(self, idx): - return self.scalar_fnc(idx, None) if idx >= len(self.scalars) else self.scalars[idx] + return ( + self.scalar_fnc(idx, None) + if idx >= len(self.scalars) + else self.scalars[idx] + ) def get_point(self, idx): - return self.point_fnc(idx, None) if idx >= len(self.points) else self.points[idx] + return ( + self.point_fnc(idx, None) if idx >= len(self.points) else self.points[idx] + ) def get_idx(self, idx): return self.get_scalar(idx), self.get_point(idx) @@ -741,9 +898,15 @@ class MultiExpSequential(MultiExp): """ def __init__( - self, size=None, scalars=None, points=None, scalar_fnc=None, point_fnc=None + self, size=None, scalars=None, points=None, scalar_fnc=None, point_fnc=None ): - super().__init__(size, scalars=scalars, points=points, scalar_fnc=scalar_fnc, point_fnc=point_fnc) + super().__init__( + size, + scalars=scalars, + points=points, + scalar_fnc=scalar_fnc, + point_fnc=point_fnc, + ) self.current_idx = 0 self.acc = crypto.identity() self.tmp = _ensure_dst_key() @@ -1039,8 +1202,6 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): nprime = BP_N _tmp_k_1 = _ensure_dst_key() - _tmp_vct_1 = _ensure_dst_keyvect(None, nprime // 2) - _tmp_vct_2 = _ensure_dst_keyvect(None, nprime // 2) tmp = _ensure_dst_key() winv = _ensure_dst_key() @@ -1053,8 +1214,6 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): # PAPER LINE 15 npr2 = nprime nprime >>= 1 - _tmp_vct_1.resize(nprime, chop=True) - _tmp_vct_2.resize(nprime, chop=True) self.gc(22) # PAPER LINES 16-17 @@ -1100,28 +1259,20 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): invert(winv, w[round]) self.gc(26) - vector_scalar2(Gprime.slice_view(0, nprime), winv, _tmp_vct_1) - vector_scalar2(Gprime.slice_view(nprime, npr2), w[round], _tmp_vct_2) - hadamard2(_tmp_vct_1, _tmp_vct_2, Gprime) + hadamard_fold(Gprime, winv, w[round]) self.gc(27) - vector_scalar2(Hprime.slice_view(0, nprime), w[round], _tmp_vct_1) - vector_scalar2(Hprime.slice_view(nprime, npr2), winv, _tmp_vct_2) - hadamard2(_tmp_vct_1, _tmp_vct_2, Hprime) + hadamard_fold(Hprime, w[round], winv) self.gc(28) # PAPER LINES 28-29 - vector_scalar(aprime.slice_view(0, nprime), w[round], _tmp_vct_1) - vector_scalar(aprime.slice_view(nprime, npr2), winv, _tmp_vct_2) - vector_add(_tmp_vct_1, _tmp_vct_2, aprime) + scalar_fold(aprime, w[round], winv) self.gc(29) - vector_scalar(bprime.slice_view(0, nprime), winv, _tmp_vct_1) - vector_scalar(bprime.slice_view(nprime, npr2), w[round], _tmp_vct_2) - vector_add(_tmp_vct_1, _tmp_vct_2, bprime) + scalar_fold(bprime, winv, w[round]) + self.gc(30) round += 1 - self.gc(30) copy_key(aprime0, aprime[0]) copy_key(bprime0, bprime[0]) @@ -1211,7 +1362,7 @@ def prove_batch(self, sv, gamma): def e_xL(idx, d=None, is_a=True): j, i = idx // N, idx % N - if j > num_inp: + if j >= num_inp: return ZERO if is_a else MINUS_ONE elif sv[j][i // 8] & (1 << i % 8): return ONE if is_a else ZERO @@ -1255,6 +1406,7 @@ def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N): S = _ensure_dst_key() add_keys(S, ve, scalarmult_base(tmp_bf_1, rho)) scalarmult_key(S, S, INV_EIGHT) + del (ve) self.gc(12) # PAPER LINES 43-45 @@ -1276,42 +1428,52 @@ def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N): l1 = sL # This computes the ugly sum/concatenation from PAPER LINE 65 - zero_twos = _ensure_dst_keyvect(None, MN) - zpow = vector_powers(z, M + 2) - twoN = self._two_aux(MN) + # r0 = aR + z + r0 = vector_add(aR, zMN) + del (zMN) self.gc(14) - for i in range(MN): - zero_twos[i] = ZERO - for j in range(1, M + 1): - if i >= (j - 1) * N and i < j * N: - sc_muladd( - zero_twos[i], zpow[1 + j], twoN[i - (j - 1) * N], zero_twos[i] - ) - + # r0 = r0 \odot yMN => r0[i] = r0[i] * y^i + # r1 = sR \odot yMN => r1[i] = sR[i] * y^i + yMN = vector_powers(y, MN, dynamic=False) + hadamard(r0, yMN, dst=r0) self.gc(15) - r0 = vector_add(aR, zMN) - del (zMN, twoN) + # r0 = r0 + zero_twos + zpow = vector_powers(z, M + 2) + twoN = self._two_aux(MN) + zero_twos = vector_z_two(N, logN, M, zpow, twoN, dynamic=True, raw=True) + vector_gen( + r0, + len(r0), + lambda i, d: crypto.encodeint_into( + crypto.sc_add_into( + tmp_sc_1, + zero_twos[i], + crypto.decodeint_into_noreduce(tmp_sc_2, r0[i]), + ), + d, + ), + ) - yMN = KeyVPowers(MN, y) # full: vector_powers(y, MN) - hadamard(r0, yMN, dst=r0) - vector_add(r0, zero_twos, dst=r0) - del (zero_twos) - self.gc(16) + del (zero_twos, twoN) + self.gc(15) - r1 = hadamard(yMN, sR) + # Polynomial construction before PAPER LINE 46 + # r1 = KeyVEval(MN, lambda i, d: sc_mul(d, yMN[i], sR[i])) + # r1 optimization possible, but has clashing sc registers. + # Moreover, max memory complexity is 4MN as below (while loop). + r1 = hadamard(yMN, sR, yMN) # re-use yMN vector for r1 del (yMN, sR) self.gc(16) - # Polynomial construction before PAPER LINE 46 - t1_1 = inner_product(l0, r1) - t1_2 = inner_product(l1, r0) - self.gc(17) - - t1 = sc_add(None, t1_1, t1_2) - t2 = inner_product(l1, r1) - del (t1_1, t1_2) + # Inner products + # l0 = aL - z r0 = ((aR + z) \cdot ypow) + zt + # l1 = sL r1 = sR \cdot ypow + # t1_1 = l0 . r1, t1_2 = l1 . r0 + # t1 = t1_1 + t1_2, t2 = l1 . r1 + # l = l0 \odot x*l1 r = r0 \odot x*r1 + t1, t2 = cross_inner_product(l0, r0, l1, r1) self.gc(17) # PAPER LINES 47-48 @@ -1349,11 +1511,17 @@ def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N): del (rho, alpha) # PAPER LINES 54-57 - l = vector_add(l0, KeyVEval(len(l1), lambda i, d: sc_mul(d, l1[i], x)), l0) + # l = l0 \odot x*l1, has to evaluated as it becomes aprime in the loop + l = vector_gen( + l0, len(l0), lambda i, d: sc_add(d, d, sc_mul(tmp_bf_1, l1[i], x)) + ) del (l1, sL) self.gc(19) - r = vector_add(r0, vector_scalar(r1, x, r1), r0) + # r = r0 \odot x*r1, has to evaluated as it becomes bprime in the loop + r = vector_gen( + r0, len(r0), lambda i, d: sc_add(d, d, sc_mul(tmp_bf_1, r1[i], x)) + ) t = inner_product(l, r) del (r1) self.gc(19) @@ -1391,8 +1559,6 @@ def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N): round = 0 _tmp_k_1 = _ensure_dst_key() - _tmp_vct_1 = _ensure_dst_keyvect(None, nprime // 2) - _tmp_vct_2 = _ensure_dst_keyvect(None, nprime // 2) self.gc(21) # PAPER LINE 13 @@ -1400,9 +1566,6 @@ def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N): # PAPER LINE 15 npr2 = nprime nprime >>= 1 - - _tmp_vct_1.resize(nprime, chop=True) - _tmp_vct_2.resize(nprime, chop=True) self.gc(22) # PAPER LINES 16-17 @@ -1451,32 +1614,20 @@ def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N): invert(winv, w[round]) self.gc(26) - vector_scalar2(Gprime.slice_view(0, nprime), winv, _tmp_vct_1) - vector_scalar2(Gprime.slice_view(nprime, npr2), w[round], _tmp_vct_2) - Gprime.resize(nprime, chop=True) - hadamard2(_tmp_vct_1, _tmp_vct_2, Gprime) + hadamard_fold(Gprime, winv, w[round]) self.gc(27) - vector_scalar2(Hprime.slice_view(0, nprime), w[round], _tmp_vct_1) - vector_scalar2(Hprime.slice_view(nprime, npr2), winv, _tmp_vct_2) - Hprime.resize(nprime, chop=True) - hadamard2(_tmp_vct_1, _tmp_vct_2, Hprime) + hadamard_fold(Hprime, w[round], winv) self.gc(28) # PAPER LINES 28-29 - vector_scalar(aprime.slice_view(0, nprime), w[round], _tmp_vct_1) - vector_scalar(aprime.slice_view(nprime, npr2), winv, _tmp_vct_2) - aprime.resize(nprime, chop=True) - vector_add(_tmp_vct_1, _tmp_vct_2, aprime) + scalar_fold(aprime, w[round], winv) self.gc(29) - vector_scalar(bprime.slice_view(0, nprime), winv, _tmp_vct_1) - vector_scalar(bprime.slice_view(nprime, npr2), w[round], _tmp_vct_2) - bprime.resize(nprime, chop=True) - vector_add(_tmp_vct_1, _tmp_vct_2, bprime) + scalar_fold(bprime, winv, w[round]) + self.gc(30) round += 1 - self.gc(30) return ( 1, @@ -1854,9 +2005,11 @@ def verify_batch(self, proofs, single_optim=True): sc_sub(tmp, m_y0, z1) z3p = sc_sub(None, z3, y1) - check2 = crypto.encodepoint(crypto.ge_double_scalarmult_base_vartime( - crypto.decodeint(z3p), crypto.gen_H(), crypto.decodeint(tmp) - )) + check2 = crypto.encodepoint( + crypto.ge_double_scalarmult_base_vartime( + crypto.decodeint(z3p), crypto.gen_H(), crypto.decodeint(tmp) + ) + ) add_keys(muex_acc, muex_acc, check2) if not is_single: # ph4