From a28eb55f380e320154002368737c165bb1b4485a Mon Sep 17 00:00:00 2001 From: Dusan Klinec Date: Fri, 17 Aug 2018 05:23:51 +0200 Subject: [PATCH] xmr: bp - memory optimizations --- src/apps/monero/xmr/bulletproof.py | 135 ++++++++++++----------------- 1 file changed, 56 insertions(+), 79 deletions(-) diff --git a/src/apps/monero/xmr/bulletproof.py b/src/apps/monero/xmr/bulletproof.py index 8f53bd49b..aea255479 100644 --- a/src/apps/monero/xmr/bulletproof.py +++ b/src/apps/monero/xmr/bulletproof.py @@ -41,6 +41,7 @@ # tmp_bf_1 = bytearray(32) +tmp_bf_2 = bytearray(32) tmp_pt_1 = crypto.new_point() tmp_pt_2 = crypto.new_point() @@ -75,42 +76,6 @@ def copy_vector(dst, src): copy_key(dst[i], src[i]) -def extended_gcd(aa, bb): - lastremainder, remainder = abs(aa), abs(bb) - x, lastx, y, lasty = 0, 1, 1, 0 - gc_ctr = 0 - while remainder: - lastremainder, (quotient, remainder) = ( - remainder, - divmod(lastremainder, remainder), - ) - x, lastx = lastx - quotient * x, x - y, lasty = lasty - quotient * y, y - gc_ctr += 1 - if gc_ctr % 10: - gc.collect() - return lastremainder, lastx * (-1 if aa < 0 else 1), lasty * (-1 if bb < 0 else 1) - - -def modinv(a, m): - g, x, y = extended_gcd(a, m) - if g != 1: - raise ValueError - return x % m - - -def mul_inverse(x, n): - return pow(x, n - 2, n) - - -mul_inverse_used = modinv -try: - pow(2, 5, 7) - mul_inverse_used = mul_inverse -except NotImplementedError: - pass - - def invert(dst, x): """ Modular inversion mod curve order. @@ -267,11 +232,12 @@ class KeyV(object): Constant precomputed buffers = bytes, frozen. Same operation as normal. """ - def __init__(self, elems=64, src=None, buffer=None): + def __init__(self, elems=64, src=None, buffer=None, const=False): self.current_idx = 0 self.d = None self.mv = None self.size = elems + self.const = const if src: self.d = bytearray(src.d) self.size = src.size @@ -300,6 +266,8 @@ def __setitem__(self, key, value): :param value: :return: """ + if self.const: + raise ValueError("Constant KeyV") ck = self[key] for i in range(32): ck[i] = value[i] @@ -541,10 +509,10 @@ def __init__(self): self.gamma = None self.gamma_enc = None self.proof_sec = None - self.Gprec = KeyV(buffer=BP_GI_PRE) - self.Hprec = KeyV(buffer=BP_HI_PRE) - self.oneN = KeyV(buffer=BP_ONE_N) - self.twoN = KeyV(buffer=BP_TWO_N) + self.Gprec = KeyV(buffer=BP_GI_PRE, const=True) + self.Hprec = KeyV(buffer=BP_HI_PRE, const=True) + self.oneN = KeyV(buffer=BP_ONE_N, const=True) + self.twoN = KeyV(buffer=BP_TWO_N, const=True) self.ip12 = BP_IP12 self.v_aL = None self.v_aR = None @@ -579,8 +547,8 @@ def aL(self, i, dst=None): def aR(self, i, dst=None): dst = _ensure_dst_key(dst) - a_tmp = self.aL(i) - sc_sub(dst, a_tmp, ONE) + self.aL(i, tmp_bf_1) + sc_sub(dst, tmp_bf_1, ONE) return dst def aL_vct(self): @@ -640,12 +608,12 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): alpha = sc_gen() ve = _ensure_dst_key() self.vector_exponent(self.v_aL, self.v_aR, ve) - add_keys(A, ve, scalarmult_base(None, alpha)) + add_keys(A, ve, scalarmult_base(tmp_bf_1, alpha)) # PAPER LINES 40-42 rho = sc_gen() self.vector_exponent(self.v_sL, self.v_sR, ve) - add_keys(S, ve, scalarmult_base(None, rho)) + add_keys(S, ve, scalarmult_base(tmp_bf_1, rho)) # PAPER LINES 43-45 z = _ensure_dst_key() @@ -684,6 +652,8 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): vector_scalar(self.oneN, z, tmp_vct) aL_vpIz = vector_subtract(self.v_aL, tmp_vct) aR_vpIz = vector_add(self.v_aR, tmp_vct) + self.v_aL = None + self.v_aR = None self.gc(4) # tmp_vct = HyNsR @@ -710,8 +680,12 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): tau1 = sc_gen() tau2 = sc_gen() - add_keys(T1, scalarmult_key(None, XMR_H, t1), scalarmult_base(None, tau1)) - add_keys(T2, scalarmult_key(None, XMR_H, t2), scalarmult_base(None, tau2)) + add_keys( + T1, scalarmult_key(tmp_bf_1, XMR_H, t1), scalarmult_base(tmp_bf_2, tau1) + ) + add_keys( + T2, scalarmult_key(tmp_bf_1, XMR_H, t2), scalarmult_base(tmp_bf_2, tau2) + ) # PAPER LINES 49-51 x = _ensure_dst_key() @@ -784,6 +758,8 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): tmp = _ensure_dst_key() winv = _ensure_dst_key() w = _ensure_dst_keyvect(None, BP_LOG_N) + cL = _ensure_dst_key() + cR = _ensure_dst_key() # PAPER LINE 13 while nprime > 1: @@ -796,15 +772,18 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): self.gc(22) # PAPER LINES 16-17 - cL = inner_product( + inner_product( aprime.slice(_tmp_vct_1, 0, nprime), bprime.slice(_tmp_vct_2, nprime, bprime.size), + cL, ) - cR = inner_product( + inner_product( aprime.slice(_tmp_vct_1, nprime, aprime.size), bprime.slice(_tmp_vct_2, 0, nprime), + cR, ) + self.gc(23) # PAPER LINES 18-19 @@ -839,46 +818,36 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): invert(winv, w[round]) self.gc(26) - hadamard2( - vector_scalar2(Gprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3), - vector_scalar2( - Gprime.slice(_tmp_vct_2, nprime, len(Gprime)), w[round], _tmp_vct_4 - ), - Gprime, + vector_scalar2(Gprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3) + vector_scalar2( + Gprime.slice(_tmp_vct_2, nprime, len(Gprime)), w[round], _tmp_vct_4 ) + hadamard2(_tmp_vct_3, _tmp_vct_4, Gprime) + self.gc(27) - hadamard2( - vector_scalar2( - Hprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3 - ), - vector_scalar2( - Hprime.slice(_tmp_vct_2, nprime, len(Hprime)), winv, _tmp_vct_4 - ), - Hprime, + vector_scalar2(Hprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3) + vector_scalar2( + Hprime.slice(_tmp_vct_2, nprime, len(Hprime)), winv, _tmp_vct_4 ) - self.gc(27) + hadamard2(_tmp_vct_3, _tmp_vct_4, Hprime) + self.gc(28) # PAPER LINES 28-29 - vector_add( - vector_scalar( - aprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3 - ), - vector_scalar( - aprime.slice(_tmp_vct_2, nprime, len(aprime)), winv, _tmp_vct_4 - ), - aprime, + vector_scalar(aprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3) + vector_scalar( + aprime.slice(_tmp_vct_2, nprime, len(aprime)), winv, _tmp_vct_4 ) + vector_add(_tmp_vct_3, _tmp_vct_4, aprime) + self.gc(29) - vector_add( - vector_scalar(bprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3), - vector_scalar( - bprime.slice(_tmp_vct_2, nprime, len(bprime)), w[round], _tmp_vct_4 - ), - bprime, + vector_scalar(bprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3) + vector_scalar( + bprime.slice(_tmp_vct_2, nprime, len(bprime)), w[round], _tmp_vct_4 ) + vector_add(_tmp_vct_3, _tmp_vct_4, bprime) round += 1 - self.gc(28) + self.gc(30) copy_key(aprime0, aprime[0]) copy_key(bprime0, bprime[0]) @@ -974,6 +943,8 @@ def verify(self, proof): k = _ensure_dst_key() yN = vector_powers(y, BP_N) ip1y = inner_product(self.oneN, yN) + del yN + zsq = _ensure_dst_key() sc_mul(zsq, z, z) @@ -1002,6 +973,12 @@ def verify(self, proof): if L61Right != L61Left: raise ValueError("Verification failure 1") + del k + del ip1y + del zcu + del L61Left + del L61Right + # PAPER LINE 62 P = _ensure_dst_key() add_keys(P, proof.A, scalarmult_key(_tmp_k_1, proof.S, x))