Skip to content

Commit

Permalink
xmr: use trezor.utils.ensure
Browse files Browse the repository at this point in the history
  • Loading branch information
jpochyla committed Oct 18, 2018
1 parent adf119a commit f82bd9c
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 37 deletions.
5 changes: 0 additions & 5 deletions src/apps/monero/protocol/signing/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,5 @@ def mem_trace(self, x=None, collect=False):
if collect:
gc.collect()

def assrt(self, condition, msg=None):
if condition:
return
raise ValueError("Assertion error%s" % (" : %s" % msg if msg else ""))

def change_address(self):
return self.output_change.addr if self.output_change else None
2 changes: 1 addition & 1 deletion src/apps/monero/protocol/signing/step_05_all_inputs_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def all_inputs_set(state: State):
state.output_masks.append(cur_mask)

if state.rct_type == RctType.Simple:
state.assrt(
utils.ensure(
crypto.sc_eq(state.sumout, state.sumpouts_alphas), "Invalid masks sum"
) # sum check
state.sumout = crypto.sc_init(0)
Expand Down
4 changes: 2 additions & 2 deletions src/apps/monero/protocol/signing/step_06_set_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _range_proof(state, amount, rsig_data):
# thus direct serialization cannot be used.
state.full_message_hasher.rsig_val(bp_obj, True, raw=False)
res = range_signatures.verify_bp(bp_obj, state.output_amounts, masks)
state.assrt(res, "BP verification fail")
utils.ensure(res, "BP verification fail")
state.mem_trace("BP verified" if __debug__ else None, collect=True)
del (bp_obj, range_signatures)

Expand Down Expand Up @@ -370,7 +370,7 @@ def _set_out_derivation(state: State, dst_entr, additional_txkey_priv):


def _check_out_commitment(state: State, amount, mask, C):
state.assrt(
utils.ensure(
crypto.point_eq(
C,
crypto.point_add(crypto.scalarmult_base(mask), crypto.scalarmult_h(amount)),
Expand Down
4 changes: 3 additions & 1 deletion src/apps/monero/protocol/signing/step_07_all_outputs_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import gc

from trezor import utils

from .state import State

from apps.monero.controller import misc
Expand Down Expand Up @@ -76,7 +78,7 @@ def _validate(state: State):

# Test if \sum Alpha == \sum A
if state.rct_type == RctType.Simple:
state.assrt(crypto.sc_eq(state.sumout, state.sumpouts_alphas))
utils.ensure(crypto.sc_eq(state.sumout, state.sumpouts_alphas))

# Fee test
if state.fee != (state.summary_inputs_money - state.summary_outs_money):
Expand Down
6 changes: 4 additions & 2 deletions src/apps/monero/protocol/signing/step_09_sign_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import gc

from trezor import utils

from .state import State

from apps.monero.controller import misc
Expand Down Expand Up @@ -110,14 +112,14 @@ async def sign_input(
kLRki = None # for multisig: src_entr.multisig_kLRki

# Private key correctness test
state.assrt(
utils.ensure(
crypto.point_eq(
crypto.decodepoint(src_entr.outputs[src_entr.real_output].key.dest),
crypto.scalarmult_base(input_secret_key.dest),
),
"Real source entry's destination does not equal spend key's",
)
state.assrt(
utils.ensure(
crypto.point_eq(
crypto.decodepoint(src_entr.outputs[src_entr.real_output].key.commitment),
crypto.gen_commitment(input_secret_key.mask, src_entr.amount),
Expand Down
50 changes: 24 additions & 26 deletions src/apps/monero/xmr/bulletproof.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import gc
from trezorutils import memcpy as _memcpy

from trezor import utils
from trezor.utils import memcpy as _memcpy

from apps.monero.xmr import crypto
from apps.monero.xmr.serialize.int_serialize import dump_uvarint_b_into, uvarint_size
Expand Down Expand Up @@ -929,10 +931,6 @@ def gc(self, *args):
if self.gc_fnc:
self.gc_fnc()

def assrt(self, cond, msg=None, *args, **kwargs):
if not cond:
raise ValueError(msg)

def aX_vcts(self, sv, MN):
num_inp = len(sv)

Expand Down Expand Up @@ -1026,8 +1024,8 @@ def prove(self, sv, gamma, proof_v8=False):
return self.prove_batch([sv], [gamma], proof_v8=proof_v8)

def prove_setup(self, sv, gamma, proof_v8=False):
self.assrt(len(sv) == len(gamma), "|sv| != |gamma|")
self.assrt(len(sv) > 0, "sv empty")
utils.ensure(len(sv) == len(gamma), "|sv| != |gamma|")
utils.ensure(len(sv) > 0, "sv empty")

self.proof_sec = crypto.random_bytes(64)
self._det_mask_init()
Expand Down Expand Up @@ -1368,17 +1366,17 @@ def verify_batch(self, proofs, single_optim=True, proof_v8=False):
"""
max_length = 0
for proof in proofs:
self.assrt(is_reduced(proof.taux), "Input scalar not in range")
self.assrt(is_reduced(proof.mu), "Input scalar not in range")
self.assrt(is_reduced(proof.a), "Input scalar not in range")
self.assrt(is_reduced(proof.b), "Input scalar not in range")
self.assrt(is_reduced(proof.t), "Input scalar not in range")
self.assrt(len(proof.V) >= 1, "V does not have at least one element")
self.assrt(len(proof.L) == len(proof.R), "|L| != |R|")
self.assrt(len(proof.L) > 0, "Empty proof")
utils.ensure(is_reduced(proof.taux), "Input scalar not in range")
utils.ensure(is_reduced(proof.mu), "Input scalar not in range")
utils.ensure(is_reduced(proof.a), "Input scalar not in range")
utils.ensure(is_reduced(proof.b), "Input scalar not in range")
utils.ensure(is_reduced(proof.t), "Input scalar not in range")
utils.ensure(len(proof.V) >= 1, "V does not have at least one element")
utils.ensure(len(proof.L) == len(proof.R), "|L| != |R|")
utils.ensure(len(proof.L) > 0, "Empty proof")
max_length = max(max_length, len(proof.L))

self.assrt(max_length < 32, "At least one proof is too large")
utils.ensure(max_length < 32, "At least one proof is too large")

maxMN = 1 << max_length
logN = 6
Expand All @@ -1405,23 +1403,23 @@ def verify_batch(self, proofs, single_optim=True, proof_v8=False):
logM += 1
M = 1 << logM

self.assrt(len(proof.L) == 6 + logM, "Proof is not the expected size")
utils.ensure(len(proof.L) == 6 + logM, "Proof is not the expected size")
MN = M * N
weight_y = crypto.encodeint(crypto.random_scalar())
weight_z = crypto.encodeint(crypto.random_scalar())

# Reconstruct the challenges
hash_cache = hash_vct_to_scalar(None, proof.V)
y = hash_cache_mash(None, hash_cache, proof.A, proof.S)
self.assrt(y != ZERO, "y == 0")
utils.ensure(y != ZERO, "y == 0")
z = hash_to_scalar(None, y)
copy_key(hash_cache, z)
self.assrt(z != ZERO, "z == 0")
utils.ensure(z != ZERO, "z == 0")

x = hash_cache_mash(None, hash_cache, z, proof.T1, proof.T2)
self.assrt(x != ZERO, "x == 0")
utils.ensure(x != ZERO, "x == 0")
x_ip = hash_cache_mash(None, hash_cache, x, proof.taux, proof.mu, proof.t)
self.assrt(x_ip != ZERO, "x_ip == 0")
utils.ensure(x_ip != ZERO, "x_ip == 0")

# PAPER LINE 61
sc_mulsub(m_y0, proof.taux, weight_y, m_y0)
Expand All @@ -1431,7 +1429,7 @@ def verify_batch(self, proofs, single_optim=True, proof_v8=False):
ip1y = vector_power_sum(y, MN)
sc_mulsub(k, zpow[2], ip1y, ZERO)
for j in range(1, M + 1):
self.assrt(j + 2 < len(zpow), "invalid zpow index")
utils.ensure(j + 2 < len(zpow), "invalid zpow index")
sc_mulsub(k, zpow.to(j + 2), BP_IP12, k)

# VERIFY_line_61rl_new
Expand Down Expand Up @@ -1471,15 +1469,15 @@ def verify_batch(self, proofs, single_optim=True, proof_v8=False):

# Compute the number of rounds for the inner product
rounds = logM + logN
self.assrt(rounds > 0, "Zero rounds")
utils.ensure(rounds > 0, "Zero rounds")

# PAPER LINES 21-22
# The inner product challenges are computed per round
w = _ensure_dst_keyvect(None, rounds)
for i in range(rounds):
hash_cache_mash(tmp_bf_0, hash_cache, proof.L[i], proof.R[i])
w.read(i, tmp_bf_0)
self.assrt(w[i] != ZERO, "w[i] == 0")
utils.ensure(w[i] != ZERO, "w[i] == 0")

# Basically PAPER LINES 24-25
# Compute the curvepoints from G[i] and H[i]
Expand Down Expand Up @@ -1513,8 +1511,8 @@ def verify_batch(self, proofs, single_optim=True, proof_v8=False):

# Adjust the scalars using the exponents from PAPER LINE 62
sc_add(g_scalar, g_scalar, z)
self.assrt(2 + i // N < len(zpow), "invalid zpow index")
self.assrt(i % N < len(twoN), "invalid twoN index")
utils.ensure(2 + i // N < len(zpow), "invalid zpow index")
utils.ensure(i % N < len(twoN), "invalid twoN index")
sc_mul(tmp, zpow.to(2 + i // N), twoN.to(i % N))
sc_muladd(tmp, z, ypow, tmp)
sc_mulsub(h_scalar, tmp, yinvpow, h_scalar)
Expand Down

0 comments on commit f82bd9c

Please sign in to comment.