Skip to content

Commit

Permalink
xmr: bp - verification in log(MN) memory for 1 proof
Browse files Browse the repository at this point in the history
- not allocating MN vectors
- sequential multiexec added for memory efficient verification
- bulletproofs: maintain -z4, -z5, and -y0 to avoid subtractions [8276d25]
- bulletproofs: merge multiexps as per sarang's new python code [acd64d2b]
  • Loading branch information
ph4r05 committed Aug 25, 2018
1 parent 75aa7de commit 2a2b0cb
Showing 1 changed file with 144 additions and 109 deletions.
253 changes: 144 additions & 109 deletions src/apps/monero/xmr/bulletproof.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def is_reduced(sc):
return crypto.encodeint(crypto.decodeint(sc)) == sc


class MultiExp(object):
class MultiExpEval(object):
"""
MultiExp object similar to MultiExp array of [(scalar, point), ]
MultiExp computes simply: res = \sum_i scalar_i * point_i
Expand All @@ -653,11 +653,42 @@ class MultiExp(object):
Moreover, Monero needs speed for very fast verification for blockchain verification which is not
priority in this use case.
"""
def __init__(self, size=None):
self.size = size if size else None

def __len__(self):
return self.size

def __getitem__(self, item):
raise IndexError()

@staticmethod
def eval_data(dst, data, GiHi=False):
dst = _ensure_dst_key(dst)
crypto.identity_into(tmp_pt_1)
for i in range(len(data)):
sci, pti = data[i]
crypto.decodeint_into_noreduce(tmp_sc_1, sci)
crypto.decodepoint_into(tmp_pt_2, pti)
crypto.scalarmult_into(tmp_pt_3, tmp_pt_2, tmp_sc_1)
crypto.point_add_into(tmp_pt_1, tmp_pt_1, tmp_pt_3)
crypto.encodepoint_into(tmp_pt_1, dst)
return dst

def eval(self, dst, GiHi=False):
return MultiExpEval.eval_data(dst, self, GiHi)


class MultiExp(MultiExpEval):
"""
Simple MultiExp holder
Supports on the fly evaluation and/or static array
"""

def __init__(
self, size=None, scalars=None, points=None, scalar_fnc=None, point_fnc=None
self, size=None, scalars=None, points=None, scalar_fnc=None, point_fnc=None
):
self.size = size if size else None
super().__init__(size)
self.current_idx = 0

self.scalars = scalars if scalars else []
Expand All @@ -668,31 +699,26 @@ def __init__(
self.size = max(
len(scalars) if scalars else 0, len(points) if points else 0
)
else:
self.size = 0

def add_pair(self, scalar, point):
self.scalars.append(scalar)
self.points.append(point)
self.size = len(self.points)
self.size = len(self.scalars)

def add_scalar(self, scalar):
self.scalars.append(init_key(scalar))
self.size = len(self.scalars)

def get_idx(self, idx):
dst_scalar = None
dst_point = None

if idx >= len(self.scalars):
dst_scalar = self.scalar_fnc(idx, None)
else:
dst_scalar = self.scalars[idx]
def get_scalar(self, idx):
return self.scalar_fnc(idx, None) if idx >= len(self.scalars) else self.scalars[idx]

if idx >= len(self.points):
dst_point = self.point_fnc(idx, None)
else:
dst_point = self.points[idx]
def get_point(self, idx):
return self.point_fnc(idx, None) if idx >= len(self.points) else self.points[idx]

return dst_scalar, dst_point
def get_idx(self, idx):
return self.get_scalar(idx), self.get_point(idx)

def __getitem__(self, item):
return self.get_idx(item)
Expand All @@ -708,21 +734,42 @@ def __next__(self):
self.current_idx += 1
return self[self.current_idx - 1]

def __len__(self):
return self.size

class MultiExpSequential(MultiExp):
"""
MultiExp holder with sequential evaluation
"""

def multiexp(dst=None, data=None, GiHi=False):
dst = _ensure_dst_key(dst)
crypto.identity_into(tmp_pt_1)
for i in range(len(data)):
sci, pti = data[i]
crypto.decodeint_into_noreduce(tmp_sc_1, sci)
crypto.decodepoint_into(tmp_pt_2, pti)
def __init__(
self, size=None, scalars=None, points=None, scalar_fnc=None, point_fnc=None
):
super().__init__(size, scalars=scalars, points=points, scalar_fnc=scalar_fnc, point_fnc=point_fnc)
self.current_idx = 0
self.acc = crypto.identity()
self.tmp = _ensure_dst_key()
self.eval_idx = 0

def add_pair(self, scalar, point):
self._acc(scalar, point)

def add_scalar(self, scalar):
self._acc(scalar, self.get_point(self.current_idx))

def _acc(self, scalar, point):
crypto.decodeint_into_noreduce(tmp_sc_1, scalar)
crypto.decodepoint_into(tmp_pt_2, point)
crypto.scalarmult_into(tmp_pt_3, tmp_pt_2, tmp_sc_1)
crypto.point_add_into(tmp_pt_1, tmp_pt_1, tmp_pt_3)
crypto.encodepoint_into(tmp_pt_1, dst)
return dst
crypto.point_add_into(self.acc, self.acc, tmp_pt_3)
self.current_idx += 1
self.size += 1

def eval(self, dst, GiHi=False):
dst = _ensure_dst_key(dst)
return crypto.encodepoint_into(self.acc, dst)


def multiexp(dst=None, data=None, GiHi=False):
return data.eval(dst, GiHi)


class BulletProofBuilder(object):
Expand Down Expand Up @@ -1612,10 +1659,11 @@ def verify_testnet(self, proof):
def verify(self, proof):
return self.verify_batch([proof])

def verify_batch(self, proofs):
def verify_batch(self, proofs, single_optim=True):
"""
BP batch verification
:param proofs:
:param single_optim: single proof memory optimization
:return:
"""
max_length = 0
Expand All @@ -1638,17 +1686,17 @@ def verify_batch(self, proofs):
tmp = _ensure_dst_key()

# setup weighted aggregates
Z0 = init_key(ONE)
is_single = len(proofs) == 1 and single_optim # ph4
z1 = init_key(ZERO)
Z2 = init_key(ONE)
z3 = init_key(ZERO)
z4 = vector_dup(ZERO, maxMN)
z5 = vector_dup(ZERO, maxMN)
Y2 = init_key(ONE)
Y3 = init_key(ONE)
Y4 = init_key(ONE)
y0 = init_key(ZERO)
m_z4 = vector_dup(ZERO, maxMN) if not is_single else None
m_z5 = vector_dup(ZERO, maxMN) if not is_single else None
m_y0 = init_key(ZERO)
y1 = init_key(ZERO)
muex_acc = init_key(ONE)

Gprec = self._gprec_aux(maxMN)
Hprec = self._hprec_aux(maxMN)

for proof in proofs:
M = 1
Expand All @@ -1659,7 +1707,8 @@ def verify_batch(self, proofs):

self.assrt(len(proof.L) == 6 + logM, "Proof is not the expected size")
MN = M * N
weight = crypto.encodeint(crypto.random_scalar())
weight_y = crypto.encodeint(crypto.random_scalar())
weight_z = crypto.encodeint(crypto.random_scalar())

# Reconstruct the challenges
hash_cache = hash_vct_to_scalar(None, proof.V)
Expand All @@ -1675,7 +1724,7 @@ def verify_batch(self, proofs):
self.assrt(x_ip != ZERO, "x_ip == 0")

# PAPER LINE 61
sc_muladd(y0, proof.taux, weight, y0)
sc_mulsub(m_y0, proof.taux, weight_y, m_y0)
zpow = vector_powers(z, M + 3)

k = _ensure_dst_key()
Expand All @@ -1688,39 +1737,32 @@ def verify_batch(self, proofs):
# VERIFY_line_61rl_new
sc_muladd(tmp, z, ip1y, k)
sc_sub(tmp, proof.t, tmp)
sc_muladd(y1, tmp, weight, y1)

muex = MultiExp(point_fnc=lambda i, d: proof.V[i])
sc_muladd(y1, tmp, weight_y, y1)
weight_y8 = sc_mul(None, weight_y, EIGHT)

muex = MultiExpSequential(points=[pt for pt in proof.V])
for j in range(len(proof.V)):
sc_mul(tmp, zpow[j + 2], EIGHT)
muex.add_scalar(tmp)
sc_mul(tmp, zpow[j + 2], weight_y8)
muex.add_scalar(init_key(tmp))

sc_mul(tmp, x, weight_y8)
muex.add_pair(init_key(tmp), proof.T1)

add_keys(Y2, Y2, scalarmult_key(None, multiexp(None, muex, False), weight))
weight8 = _ensure_dst_key()
sc_mul(weight8, weight, EIGHT)
sc_mul(tmp, x, weight8)
add_keys(Y3, Y3, scalarmult_key(None, proof.T1, tmp))
xsq = _ensure_dst_key()
sc_mul(xsq, x, x)
sc_mul(tmp, xsq, weight8)
add_keys(Y4, Y4, scalarmult_key(None, proof.T2, tmp))
del weight8

# PAPER LINE 62
sc_mul(tmp, x, EIGHT)
add_keys(
Z0,
Z0,
scalarmult_key(
None,
add_keys(
None,
scalarmult8(None, proof.A),
scalarmult_key(None, proof.S, tmp),
),
weight,
),
)

sc_mul(tmp, xsq, weight_y8)
muex.add_pair(init_key(tmp), proof.T2)

weight_z8 = sc_mul(None, weight_z, EIGHT)
muex.add_pair(weight_z8, proof.A)
sc_mul(tmp, x, weight_z8)
muex.add_pair(init_key(tmp), proof.S)

multiexp(tmp, muex, False)
add_keys(muex_acc, muex_acc, tmp)
del (muex)

# Compute the number of rounds for the inner product
rounds = logM + logN
Expand Down Expand Up @@ -1770,8 +1812,15 @@ def verify_batch(self, proofs):
sc_muladd(tmp, z, ypow, tmp)
sc_mulsub(h_scalar, tmp, yinvpow, h_scalar)

sc_muladd(z4[i], g_scalar, weight, z4[i])
sc_muladd(z5[i], h_scalar, weight, z5[i])
if not is_single: # ph4
sc_mulsub(m_z4[i], g_scalar, weight_z, m_z4[i])
sc_mulsub(m_z5[i], h_scalar, weight_z, m_z5[i])
else:
sc_mul(tmp, g_scalar, weight_z)
sub_keys(muex_acc, muex_acc, scalarmult_key(tmp, Gprec[i], tmp))

sc_mul(tmp, h_scalar, weight_z)
sub_keys(muex_acc, muex_acc, scalarmult_key(tmp, Hprec[i], tmp))

if i != MN - 1:
sc_mul(yinvpow, yinvpow, yinv)
Expand All @@ -1781,58 +1830,44 @@ def verify_batch(self, proofs):
del (g_scalar, h_scalar, twoN)
self.gc(63)

sc_muladd(z1, proof.mu, weight, z1)
muex = MultiExp(
sc_muladd(z1, proof.mu, weight_z, z1)
muex = MultiExpSequential(
point_fnc=lambda i, d: proof.L[i // 2]
if i & 1 == 0
else proof.R[i // 2]
)
for i in range(rounds):
sc_mul(tmp, w[i], w[i])
sc_mul(tmp, tmp, EIGHT)
sc_mul(tmp, tmp, weight_z8)
muex.add_scalar(tmp)
sc_mul(tmp, winv[i], winv[i])
sc_mul(tmp, tmp, EIGHT)
sc_mul(tmp, tmp, weight_z8)
muex.add_scalar(tmp)

acc = multiexp(None, muex, False)
add_keys(Z2, Z2, scalarmult_key(None, acc, weight))
add_keys(muex_acc, muex_acc, acc)

sc_mulsub(tmp, proof.a, proof.b, proof.t)
sc_mul(tmp, tmp, x_ip)
sc_muladd(z3, tmp, weight, z3)

# now check all proofs at once
check1 = _ensure_dst_key()
scalarmult_base(check1, y0)
add_keys(check1, check1, scalarmultH(None, y1))
sub_keys(check1, check1, Y2)
sub_keys(check1, check1, Y3)
sub_keys(check1, check1, Y4)
if check1 != ONE:
raise ValueError("Verification failure at step 1")

sc_sub(tmp, ZERO, z1)
check2 = crypto.ge_double_scalarmult_base_vartime(
crypto.decodeint(z3), crypto.gen_H(), crypto.decodeint(tmp)
)
crypto.point_add_into(check2, check2, crypto.decodepoint(Z0))
crypto.point_add_into(check2, check2, crypto.decodepoint(Z2))
sc_muladd(z3, tmp, weight_z, z3)

Gprec = self._gprec_aux(maxMN)
Hprec = self._hprec_aux(maxMN)
muex = MultiExp(
point_fnc=lambda i, d: Gprec[i // 2] if i & 1 == 0 else Hprec[i // 2]
)
for i in range(maxMN):
sc_sub(tmp, ZERO, z4[i])
muex.add_scalar(tmp)
sc_sub(tmp, ZERO, z5[i])
muex.add_scalar(tmp)

crypto.point_add_into(
check2, check2, crypto.decodepoint(multiexp(None, muex, True))
)
check2_enc = crypto.encodepoint(check2)
if check2_enc != ONE:
sc_sub(tmp, m_y0, z1)
z3p = sc_sub(None, z3, y1)

check2 = crypto.encodepoint(crypto.ge_double_scalarmult_base_vartime(
crypto.decodeint(z3p), crypto.gen_H(), crypto.decodeint(tmp)
))
add_keys(muex_acc, muex_acc, check2)

if not is_single: # ph4
muex = MultiExpSequential(
point_fnc=lambda i, d: Gprec[i // 2] if i & 1 == 0 else Hprec[i // 2]
)
for i in range(maxMN):
muex.add_scalar(m_z4[i])
muex.add_scalar(m_z5[i])
add_keys(muex_acc, muex_acc, multiexp(None, muex, True))

if muex_acc != ONE:
raise ValueError("Verification failure at step 2")
return True

0 comments on commit 2a2b0cb

Please sign in to comment.