diff --git a/src/apps/monero/xmr/bulletproof.py b/src/apps/monero/xmr/bulletproof.py index 5ffa5c928..891f123bc 100644 --- a/src/apps/monero/xmr/bulletproof.py +++ b/src/apps/monero/xmr/bulletproof.py @@ -520,6 +520,14 @@ def __init__(self): self.v_sL = None self.v_sR = None self.tmp_sc_1 = crypto.new_scalar() + self.gc_fnc = gc.collect + self.gc_trace = None + + def gc(self, *args): + if self.gc_trace: + self.gc_trace(*args) + if self.gc_fnc: + self.gc_fnc() def set_input(self, value=None, mask=None): self.value = value @@ -605,7 +613,7 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): hash_cache_mash(y, hash_cache, A, S) hash_to_scalar(hash_cache, y) copy_key(z, hash_cache) - gc.collect() + self.gc(1) # Polynomial construction before PAPER LINE 46 t0 = _ensure_dst_key() @@ -613,7 +621,7 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): t2 = _ensure_dst_key() yN = vector_powers(y, BP_N) - gc.collect() + self.gc(2) ip1y = inner_product(self.oneN, yN) sc_muladd(t0, z, ip1y, t0) @@ -630,20 +638,20 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): sc_mul(zcu, zsq, z) sc_mulsub(k, zcu, self.ip12, k) sc_add(t0, t0, k) - gc.collect() + self.gc(3) # step 2 vpIz = vector_scalar(self.oneN, z) aL_vpIz = vector_subtract(self.v_aL, vpIz) aR_vpIz = vector_add(self.v_aR, vpIz) del vpIz - gc.collect() + self.gc(4) HyNsR = hadamard(yN, self.v_sR) ip1 = inner_product(aL_vpIz, HyNsR) ip3 = inner_product(self.v_sL, HyNsR) del HyNsR - gc.collect() + self.gc(5) sc_add(t1, t1, ip1) @@ -656,7 +664,7 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): vector_add(tmp_vct, vp2zsq, tmp_vct) ip2 = inner_product(self.v_sL, tmp_vct) - gc.collect() + self.gc(6) sc_add(t1, t1, ip2) sc_add(t2, t2, ip3) @@ -679,30 +687,30 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r): sc_muladd(taux, tau2, xsq, taux) sc_muladd(taux, self.gamma_enc, zsq, taux) sc_muladd(mu, x, rho, alpha) - gc.collect() + self.gc(7) # PAPER LINES 54-57 vector_add(aL_vpIz, vector_scalar(self.v_sL, x), l) self.v_sL = None del aL_vpIz - gc.collect() + self.gc(8) # Originally: # vector_add(hadamard(yN, vector_add(aR_vpIz, vector_scalar(self.v_sR, x))), vp2zsq, r) vector_scalar(self.v_sR, x, tmp_vct) vector_add(aR_vpIz, tmp_vct, tmp_vct) del aR_vpIz - gc.collect() + self.gc(9) hadamard(yN, tmp_vct, tmp_vct) del yN - gc.collect() + self.gc(10) vector_add(tmp_vct, vp2zsq, r) self.v_sR = None del vp2zsq del tmp_vct - gc.collect() + self.gc(11) inner_product(l, r, t) hash_cache_mash(x_ip, hash_cache, x, taux, mu, t) @@ -721,7 +729,7 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): Gprime[i] = self.Gprec[i] scalarmult_key(Hprime[i], self.Hprec[i], yinvpow) sc_mul(yinvpow, yinvpow, yinv) - gc.collect() + self.gc(20) round = 0 nprime = BP_N @@ -744,7 +752,7 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): _tmp_vct_2.resize(nprime, chop=True) _tmp_vct_3.resize(nprime, chop=True) _tmp_vct_4.resize(nprime, chop=True) - gc.collect() + self.gc(21) # PAPER LINES 16-17 cL = inner_product( @@ -776,7 +784,7 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): bprime.slice(_tmp_vct_4, 0, nprime), R[round], ) - gc.collect() + self.gc(22) sc_mul(tmp, cR, x_ip) add_keys(R[round], R[round], scalarmult_key(_tmp_k_1, XMR_H, tmp)) @@ -803,7 +811,7 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): ), Hprime, ) - gc.collect() + self.gc(23) # PAPER LINES 28-29 vector_add( @@ -825,7 +833,7 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0): ) round += 1 - gc.collect() + self.gc(24) copy_key(aprime0, aprime[0]) copy_key(bprime0, bprime[0]) @@ -858,13 +866,13 @@ def prove(self): r = _ensure_dst_keyvect(None, BP_N) self.init_vct() - gc.collect() + self.gc(50) self.prove_s1(V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r) - gc.collect() + self.gc(51) self.prove_s2(x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0) - gc.collect() + self.gc(52) return Bulletproof( V=[V], @@ -944,7 +952,7 @@ def verify(self, proof): sc_mul(xsq, x, x) scalarmult_key(tmp, proof.T2, xsq) add_keys(L61Right, L61Right, tmp) - gc.collect() + self.gc(60) if L61Right != L61Left: raise ValueError("Verification failure 1") @@ -1010,6 +1018,7 @@ def verify(self, proof): del g_scalar del h_scalar + self.gc(61) # PAPER LINE 26 pprime = _ensure_dst_key() @@ -1030,7 +1039,7 @@ def verify(self, proof): sc_mul(tmp, tmp, x_ip) scalarmult_key(tmp, XMR_H, tmp) add_keys(tmp, tmp, inner_prod) - gc.collect() + self.gc(62) if pprime != tmp: raise ValueError("Verification failure step 2")