Skip to content

Commit

Permalink
xmr: bp - memory optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
ph4r05 committed Aug 18, 2018
1 parent d2fcb23 commit a28eb55
Showing 1 changed file with 56 additions and 79 deletions.
135 changes: 56 additions & 79 deletions src/apps/monero/xmr/bulletproof.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#

tmp_bf_1 = bytearray(32)
tmp_bf_2 = bytearray(32)

tmp_pt_1 = crypto.new_point()
tmp_pt_2 = crypto.new_point()
Expand Down Expand Up @@ -75,42 +76,6 @@ def copy_vector(dst, src):
copy_key(dst[i], src[i])


def extended_gcd(aa, bb):
lastremainder, remainder = abs(aa), abs(bb)
x, lastx, y, lasty = 0, 1, 1, 0
gc_ctr = 0
while remainder:
lastremainder, (quotient, remainder) = (
remainder,
divmod(lastremainder, remainder),
)
x, lastx = lastx - quotient * x, x
y, lasty = lasty - quotient * y, y
gc_ctr += 1
if gc_ctr % 10:
gc.collect()
return lastremainder, lastx * (-1 if aa < 0 else 1), lasty * (-1 if bb < 0 else 1)


def modinv(a, m):
g, x, y = extended_gcd(a, m)
if g != 1:
raise ValueError
return x % m


def mul_inverse(x, n):
return pow(x, n - 2, n)


mul_inverse_used = modinv
try:
pow(2, 5, 7)
mul_inverse_used = mul_inverse
except NotImplementedError:
pass


def invert(dst, x):
"""
Modular inversion mod curve order.
Expand Down Expand Up @@ -267,11 +232,12 @@ class KeyV(object):
Constant precomputed buffers = bytes, frozen. Same operation as normal.
"""

def __init__(self, elems=64, src=None, buffer=None):
def __init__(self, elems=64, src=None, buffer=None, const=False):
self.current_idx = 0
self.d = None
self.mv = None
self.size = elems
self.const = const
if src:
self.d = bytearray(src.d)
self.size = src.size
Expand Down Expand Up @@ -300,6 +266,8 @@ def __setitem__(self, key, value):
:param value:
:return:
"""
if self.const:
raise ValueError("Constant KeyV")
ck = self[key]
for i in range(32):
ck[i] = value[i]
Expand Down Expand Up @@ -541,10 +509,10 @@ def __init__(self):
self.gamma = None
self.gamma_enc = None
self.proof_sec = None
self.Gprec = KeyV(buffer=BP_GI_PRE)
self.Hprec = KeyV(buffer=BP_HI_PRE)
self.oneN = KeyV(buffer=BP_ONE_N)
self.twoN = KeyV(buffer=BP_TWO_N)
self.Gprec = KeyV(buffer=BP_GI_PRE, const=True)
self.Hprec = KeyV(buffer=BP_HI_PRE, const=True)
self.oneN = KeyV(buffer=BP_ONE_N, const=True)
self.twoN = KeyV(buffer=BP_TWO_N, const=True)
self.ip12 = BP_IP12
self.v_aL = None
self.v_aR = None
Expand Down Expand Up @@ -579,8 +547,8 @@ def aL(self, i, dst=None):

def aR(self, i, dst=None):
dst = _ensure_dst_key(dst)
a_tmp = self.aL(i)
sc_sub(dst, a_tmp, ONE)
self.aL(i, tmp_bf_1)
sc_sub(dst, tmp_bf_1, ONE)
return dst

def aL_vct(self):
Expand Down Expand Up @@ -640,12 +608,12 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r):
alpha = sc_gen()
ve = _ensure_dst_key()
self.vector_exponent(self.v_aL, self.v_aR, ve)
add_keys(A, ve, scalarmult_base(None, alpha))
add_keys(A, ve, scalarmult_base(tmp_bf_1, alpha))

# PAPER LINES 40-42
rho = sc_gen()
self.vector_exponent(self.v_sL, self.v_sR, ve)
add_keys(S, ve, scalarmult_base(None, rho))
add_keys(S, ve, scalarmult_base(tmp_bf_1, rho))

# PAPER LINES 43-45
z = _ensure_dst_key()
Expand Down Expand Up @@ -684,6 +652,8 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r):
vector_scalar(self.oneN, z, tmp_vct)
aL_vpIz = vector_subtract(self.v_aL, tmp_vct)
aR_vpIz = vector_add(self.v_aR, tmp_vct)
self.v_aL = None
self.v_aR = None
self.gc(4)

# tmp_vct = HyNsR
Expand All @@ -710,8 +680,12 @@ def prove_s1(self, V, A, S, T1, T2, taux, mu, t, x_ip, y, hash_cache, l, r):
tau1 = sc_gen()
tau2 = sc_gen()

add_keys(T1, scalarmult_key(None, XMR_H, t1), scalarmult_base(None, tau1))
add_keys(T2, scalarmult_key(None, XMR_H, t2), scalarmult_base(None, tau2))
add_keys(
T1, scalarmult_key(tmp_bf_1, XMR_H, t1), scalarmult_base(tmp_bf_2, tau1)
)
add_keys(
T2, scalarmult_key(tmp_bf_1, XMR_H, t2), scalarmult_base(tmp_bf_2, tau2)
)

# PAPER LINES 49-51
x = _ensure_dst_key()
Expand Down Expand Up @@ -784,6 +758,8 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0):
tmp = _ensure_dst_key()
winv = _ensure_dst_key()
w = _ensure_dst_keyvect(None, BP_LOG_N)
cL = _ensure_dst_key()
cR = _ensure_dst_key()

# PAPER LINE 13
while nprime > 1:
Expand All @@ -796,15 +772,18 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0):
self.gc(22)

# PAPER LINES 16-17
cL = inner_product(
inner_product(
aprime.slice(_tmp_vct_1, 0, nprime),
bprime.slice(_tmp_vct_2, nprime, bprime.size),
cL,
)

cR = inner_product(
inner_product(
aprime.slice(_tmp_vct_1, nprime, aprime.size),
bprime.slice(_tmp_vct_2, 0, nprime),
cR,
)

self.gc(23)

# PAPER LINES 18-19
Expand Down Expand Up @@ -839,46 +818,36 @@ def prove_s2(self, x_ip, y, hash_cache, l, r, L, R, aprime0, bprime0):
invert(winv, w[round])
self.gc(26)

hadamard2(
vector_scalar2(Gprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3),
vector_scalar2(
Gprime.slice(_tmp_vct_2, nprime, len(Gprime)), w[round], _tmp_vct_4
),
Gprime,
vector_scalar2(Gprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3)
vector_scalar2(
Gprime.slice(_tmp_vct_2, nprime, len(Gprime)), w[round], _tmp_vct_4
)
hadamard2(_tmp_vct_3, _tmp_vct_4, Gprime)
self.gc(27)

hadamard2(
vector_scalar2(
Hprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3
),
vector_scalar2(
Hprime.slice(_tmp_vct_2, nprime, len(Hprime)), winv, _tmp_vct_4
),
Hprime,
vector_scalar2(Hprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3)
vector_scalar2(
Hprime.slice(_tmp_vct_2, nprime, len(Hprime)), winv, _tmp_vct_4
)
self.gc(27)
hadamard2(_tmp_vct_3, _tmp_vct_4, Hprime)
self.gc(28)

# PAPER LINES 28-29
vector_add(
vector_scalar(
aprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3
),
vector_scalar(
aprime.slice(_tmp_vct_2, nprime, len(aprime)), winv, _tmp_vct_4
),
aprime,
vector_scalar(aprime.slice(_tmp_vct_1, 0, nprime), w[round], _tmp_vct_3)
vector_scalar(
aprime.slice(_tmp_vct_2, nprime, len(aprime)), winv, _tmp_vct_4
)
vector_add(_tmp_vct_3, _tmp_vct_4, aprime)
self.gc(29)

vector_add(
vector_scalar(bprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3),
vector_scalar(
bprime.slice(_tmp_vct_2, nprime, len(bprime)), w[round], _tmp_vct_4
),
bprime,
vector_scalar(bprime.slice(_tmp_vct_1, 0, nprime), winv, _tmp_vct_3)
vector_scalar(
bprime.slice(_tmp_vct_2, nprime, len(bprime)), w[round], _tmp_vct_4
)
vector_add(_tmp_vct_3, _tmp_vct_4, bprime)

round += 1
self.gc(28)
self.gc(30)

copy_key(aprime0, aprime[0])
copy_key(bprime0, bprime[0])
Expand Down Expand Up @@ -974,6 +943,8 @@ def verify(self, proof):
k = _ensure_dst_key()
yN = vector_powers(y, BP_N)
ip1y = inner_product(self.oneN, yN)
del yN

zsq = _ensure_dst_key()
sc_mul(zsq, z, z)

Expand Down Expand Up @@ -1002,6 +973,12 @@ def verify(self, proof):
if L61Right != L61Left:
raise ValueError("Verification failure 1")

del k
del ip1y
del zcu
del L61Left
del L61Right

# PAPER LINE 62
P = _ensure_dst_key()
add_keys(P, proof.A, scalarmult_key(_tmp_k_1, proof.S, x))
Expand Down

0 comments on commit a28eb55

Please sign in to comment.