Skip to content

Commit

Permalink
xmr: step 05 and 06 masks and range proofs review
Browse files Browse the repository at this point in the history
Masks are now always generated in step 5 and stored in state.

Range proofs were reviewed only in a high-level manner and will be
reviewed later.
  • Loading branch information
tsusanka committed Oct 10, 2018
1 parent 67f391c commit b3f1017
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _check_change(state: State, outputs: list):

async def _compute_sec_keys(state: State, tsx_data: MoneroTransactionData):
"""
Generate master key H(TsxData || tx_priv)
Generate master key H( H(TsxData || tx_priv) || rand )
"""
import protobuf
from apps.monero.xmr.sub.keccak_hasher import get_keccak_writer
Expand Down
35 changes: 20 additions & 15 deletions src/apps/monero/protocol/signing/step_05_all_inputs_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,37 @@ async def all_inputs_set(state: State):
)
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData

rsig_data = MoneroTransactionRsigData()
resp = MoneroTransactionAllInputsSetAck(rsig_data=rsig_data)

if not state.rsig_offload:
return resp

# Simple range proof offloading
# Generate random commitment masks that sum to the input mask sum.
# TODO review together with step 6
tmp_buff = bytearray(32)
rsig_data.mask = bytearray(32 * state.output_count)
# Generate random commitment masks to be used in range proofs.
# If SimpleRCT is used the sum of the masks must match the input masks sum.
state.sumout = crypto.sc_init(0)
for i in range(state.output_count):
cur_mask = crypto.new_scalar()
cur_mask = crypto.new_scalar() # new mask for each output
is_last = i + 1 == state.output_count
if is_last and state.use_simple_rct:
# in SimpleRCT the last mask needs to be calculated as an offset of the sum
crypto.sc_sub_into(cur_mask, state.sumpouts_alphas, state.sumout)
else:
crypto.random_scalar(cur_mask)

crypto.sc_add_into(state.sumout, state.sumout, cur_mask)
state.output_masks.append(cur_mask)
crypto.encodeint_into(tmp_buff, cur_mask)
utils.memcpy(rsig_data.mask, 32 * i, tmp_buff, 0, 32)

state.assrt(crypto.sc_eq(state.sumout, state.sumpouts_alphas), "Invalid masks sum")
if state.use_simple_rct:
state.assrt(
crypto.sc_eq(state.sumout, state.sumpouts_alphas), "Invalid masks sum"
) # sum check
state.sumout = crypto.sc_init(0)

rsig_data = MoneroTransactionRsigData()
resp = MoneroTransactionAllInputsSetAck(rsig_data=rsig_data)

# If range proofs are being offloaded, we send the masks to the host, which uses them
# to create the range proof. If not, we do not send any and we use it in the following step.
if state.rsig_offload:
tmp_buff = bytearray(32)
rsig_data.mask = bytearray(32 * state.output_count)
for i in range(state.output_count):
crypto.encodeint_into(tmp_buff, state.output_masks[i])
utils.memcpy(rsig_data.mask, 32 * i, tmp_buff, 0, 32)

return resp
60 changes: 25 additions & 35 deletions src/apps/monero/protocol/signing/step_06_set_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def set_output(state: State, dst_entr, dst_entr_hmac, rsig_data):
utils.unimport_end(mods)
state.mem_trace(5, True)

# Range proof first, memory intensive TODO range proof, TODO mask
# Range proof first, memory intensive
rsig, mask = _range_proof(state, dst_entr.amount, rsig_data)
utils.unimport_end(mods)
state.mem_trace(6, True)
Expand Down Expand Up @@ -93,7 +93,7 @@ async def set_output(state: State, dst_entr, dst_entr_hmac, rsig_data):
return MoneroTransactionSetOutputAck(
tx_out=tx_out_bin,
vouti_hmac=hmac_vouti,
rsig_data=_return_rsig_data(rsig), # TODO
rsig_data=_return_rsig_data(rsig),
out_pk=out_pk_bin,
ecdh_info=ecdh_info_bin,
)
Expand Down Expand Up @@ -137,9 +137,9 @@ async def _set_out_tx_out(state: State, dst_entr, tx_out_key):
return tx_out_bin, hmac_vouti


def _range_proof(state, amount, rsig_data=None): # TODO Heeeere
def _range_proof(state, amount, rsig_data):
"""
Computes rangeproof and related information - out_sk, out_pk, ecdh_info.
Computes rangeproof
In order to optimize incremental transaction build, the mask computation is changed compared
to the official Monero code. In the official code, the input pedersen commitments are computed
after range proof in such a way summed masks for commitments (alpha) and rangeproofs (ai) are equal.
Expand All @@ -151,22 +151,17 @@ def _range_proof(state, amount, rsig_data=None): # TODO Heeeere
"""
from apps.monero.xmr import ring_ct

idx = state.current_output_index
mask = _get_out_mask(state, idx)
provided_rsig = (
rsig_data.rsig
if rsig_data and rsig_data.rsig and len(rsig_data.rsig) > 0
else None
)
mask = state.output_masks[state.current_output_index]
provided_rsig = None
if rsig_data and rsig_data.rsig and len(rsig_data.rsig) > 0:
provided_rsig = rsig_data.rsig
if not state.rsig_offload and provided_rsig:
raise misc.TrezorError("Provided unexpected rsig")
if not state.rsig_offload:
state.output_masks.append(mask)

# Batching
bidx = _get_rsig_batch(state, idx)
bidx = _get_rsig_batch(state, state.current_output_index)
batch_size = state.rsig_grouping[bidx]
last_in_batch = _is_last_in_batch(state, idx, bidx)
last_in_batch = _is_last_in_batch(state, state.current_output_index, bidx)
if state.rsig_offload and provided_rsig and not last_in_batch:
raise misc.TrezorError("Provided rsig too early")
if state.rsig_offload and last_in_batch and not provided_rsig:
Expand All @@ -182,6 +177,7 @@ def _range_proof(state, amount, rsig_data=None): # TODO Heeeere

state.mem_trace("pre-rproof" if __debug__ else None, collect=True)
if not state.rsig_offload and state.use_bulletproof:
"""Bulletproof calculation in trezor"""
rsig = ring_ct.prove_range_bp_batch(state.output_amounts, state.output_masks)
state.mem_trace("post-bp" if __debug__ else None, collect=True)

Expand All @@ -198,6 +194,7 @@ def _range_proof(state, amount, rsig_data=None): # TODO Heeeere
)

elif not state.rsig_offload and not state.use_bulletproof:
"""Borromean calculation in trezor"""
C, mask, rsig = ring_ct.prove_range_chunked(amount, mask)
del (ring_ct)

Expand All @@ -206,12 +203,14 @@ def _range_proof(state, amount, rsig_data=None): # TODO Heeeere
_check_out_commitment(state, amount, mask, C)

elif state.rsig_offload and state.use_bulletproof:
"""Bulletproof calculated on host, verify in trezor"""
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof

masks = [
_get_out_mask(state, 1 + idx - batch_size + ix) for ix in range(batch_size)
# TODO this should be tested
# last_in_batch = True (see above) so this is fine
masks = state.output_masks[
1 + state.current_output_index - batch_size : 1 + state.current_output_index
]

bp_obj = misc.parse_msg(rsig_data.rsig, Bulletproof)
rsig_data.rsig = None

Expand All @@ -225,15 +224,18 @@ def _range_proof(state, amount, rsig_data=None): # TODO Heeeere
del (bp_obj, ring_ct)

elif state.rsig_offload and not state.use_bulletproof:
state.full_message_hasher.rsig_val(rsig_data.rsig, False, raw=True)
rsig_data.rsig = None
"""Borromean offloading not supported"""
raise misc.TrezorError(
"Unsupported rsig state (Borromean offloaded is not supported)"
)

else:
raise misc.TrezorError("Unexpected rsig state")

state.mem_trace("rproof" if __debug__ else None, collect=True)
state.output_amounts = []
if not state.rsig_offload:
if state.current_output_index + 1 == state.output_count:
# output masks and amounts are not needed anymore
state.output_amounts = []
state.output_masks = []
return rsig, mask

Expand Down Expand Up @@ -375,11 +377,10 @@ def _check_out_commitment(state: State, amount, mask, C):
)


def _is_last_in_batch(state: State, idx, bidx=None):
def _is_last_in_batch(state: State, idx, bidx):
"""
Returns true if the current output is last in the rsig batch
"""
bidx = _get_rsig_batch(state, idx) if bidx is None else bidx
batch_size = state.rsig_grouping[bidx]
return (idx - sum(state.rsig_grouping[:bidx])) + 1 == batch_size

Expand All @@ -394,14 +395,3 @@ def _get_rsig_batch(state: State, idx):
c += state.rsig_grouping[r]
r += 1
return r - 1


def _get_out_mask(state: State, idx):
if state.rsig_offload:
return state.output_masks[idx]
else:
is_last = idx + 1 == state.output_count
if is_last:
return crypto.sc_sub(state.sumpouts_alphas, state.sumout)
else:
return crypto.random_scalar()

0 comments on commit b3f1017

Please sign in to comment.