Skip to content

Commit

Permalink
xmr: step 09 review
Browse files Browse the repository at this point in the history
  • Loading branch information
tsusanka committed Oct 10, 2018
1 parent a510150 commit d8e9937
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 67 deletions.
8 changes: 0 additions & 8 deletions src/apps/monero/layout/confirms.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,6 @@ async def transaction_finished(ctx):
await common.ui_text(text, tm=500 * 1000)


async def transaction_signed(ctx):
"""
Notifies the transaction was completely signed
"""
# todo
pass


async def transaction_step(ctx, step, sub_step=None, sub_step_total=None):
from trezor import ui
from trezor.ui.text import Text
Expand Down
2 changes: 1 addition & 1 deletion src/apps/monero/protocol/signing/step_05_all_inputs_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def all_inputs_set(state: State):
resp = MoneroTransactionAllInputsSetAck(rsig_data=rsig_data)

# If range proofs are being offloaded, we send the masks to the host, which uses them
# to create the range proof. If not, we do not send any and we use it in the following step.
# to create the range proof. If not, we do not send any and we use them in the following step.
if state.rsig_offload:
tmp_buff = bytearray(32)
rsig_data.mask = bytearray(32 * state.output_count)
Expand Down
113 changes: 55 additions & 58 deletions src/apps/monero/protocol/signing/step_09_sign_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,63 +10,68 @@
from apps.monero.layout import confirms
from apps.monero.xmr import common, crypto

if False:
from trezor.messages.MoneroTransactionSourceEntry import (
MoneroTransactionSourceEntry,
)


async def sign_input(
state: State,
src_entr,
vini_bin,
hmac_vini,
pseudo_out,
pseudo_out_hmac,
alpha_enc,
spend_enc,
src_entr: MoneroTransactionSourceEntry,
vini_bin: bytes,
vini_hmac: bytes,
pseudo_out: bytes,
pseudo_out_hmac: bytes,
pseudo_out_alpha_enc: bytes,
spend_enc: bytes,
):
"""
:param state: transaction state
:param src_entr: Source entry
:param vini_bin: tx.vin[i] for the transaction. Contains key image, offsets, amount (usually zero)
:param hmac_vini: HMAC for the tx.vin[i] as returned from Trezor
:param pseudo_out: pedersen commitment for the current input, uses alpha as the mask.
Only in memory offloaded scenario. Tuple containing HMAC, as returned from the Trezor.
:param pseudo_out_hmac:
:param alpha_enc: alpha mask for the current input. Only in memory offloaded scenario,
tuple as returned from the Trezor
:param spend_enc:
:param vini_hmac: HMAC for the tx.vin[i] as returned from Trezor
:param pseudo_out: Pedersen commitment for the current input, uses pseudo_out_alpha
as a mask. Only applicable for RCTTypeSimple.
:param pseudo_out_hmac: HMAC for pseudo_out
:param pseudo_out_alpha_enc: alpha mask used in pseudo_out, only applicable for RCTTypeSimple. Encrypted.
:param spend_enc: one time address spending private key. Encrypted.
:return: Generated signature MGs[i]
"""
from apps.monero.protocol import hmac_encryption_keys

# state.state.set_signature() todo
print("09")
await confirms.transaction_step(
state.ctx, state.STEP_SIGN, state.current_input_index + 1, state.input_count
)

state.current_input_index += 1
if state.current_input_index >= state.input_count:
raise ValueError("Invalid ins")
if state.use_simple_rct and alpha_enc is None:
raise ValueError("Inconsistent1")
raise ValueError("Invalid inputs count")
if state.use_simple_rct and pseudo_out is None:
raise ValueError("Inconsistent2")
raise ValueError("SimpleRCT requires pseudo_out but none provided")
if state.use_simple_rct and pseudo_out_alpha_enc is None:
raise ValueError("SimpleRCT requires pseudo_out's mask but none provided")
if state.current_input_index >= 1 and not state.use_simple_rct:
raise ValueError("Inconsistent3")
raise ValueError("Two and more inputs must imply SimpleRCT")

inv_idx = state.source_permutation[state.current_input_index]
input_position = state.source_permutation[state.current_input_index]

# Check HMAC of all inputs
hmac_vini_comp = await hmac_encryption_keys.gen_hmac_vini(
state.key_hmac, src_entr, vini_bin, inv_idx
# Check input's HMAC
vini_hmac_comp = await hmac_encryption_keys.gen_hmac_vini(
state.key_hmac, src_entr, vini_bin, input_position
)
if not common.ct_equal(hmac_vini_comp, hmac_vini):
if not common.ct_equal(vini_hmac_comp, vini_hmac):
raise ValueError("HMAC is not correct")

gc.collect()
state.mem_trace(1)

if state.use_simple_rct:
# both pseudo_out and its mask were offloaded so we need to
# validate pseudo_out's HMAC and decrypt the alpha
pseudo_out_hmac_comp = crypto.compute_hmac(
hmac_encryption_keys.hmac_key_txin_comm(state.key_hmac, inv_idx), pseudo_out
hmac_encryption_keys.hmac_key_txin_comm(state.key_hmac, input_position),
pseudo_out,
)
if not common.ct_equal(pseudo_out_hmac_comp, pseudo_out_hmac):
raise ValueError("HMAC is not correct")
Expand All @@ -76,10 +81,10 @@ async def sign_input(

from apps.monero.xmr.enc import chacha_poly

alpha_c = crypto.decodeint(
pseudo_out_alpha = crypto.decodeint(
chacha_poly.decrypt_pack(
hmac_encryption_keys.enc_key_txin_alpha(state.key_enc, inv_idx),
bytes(alpha_enc),
hmac_encryption_keys.enc_key_txin_alpha(state.key_enc, input_position),
bytes(pseudo_out_alpha_enc),
)
)
pseudo_out_c = crypto.decodepoint(pseudo_out)
Expand All @@ -88,9 +93,10 @@ async def sign_input(
from apps.monero.xmr.enc import chacha_poly
from apps.monero.xmr.serialize_messages.ct_keys import CtKey

input_secret = crypto.decodeint(
spend_key = crypto.decodeint(
chacha_poly.decrypt_pack(
hmac_encryption_keys.enc_key_spend(state.key_enc, inv_idx), bytes(spend_enc)
hmac_encryption_keys.enc_key_spend(state.key_enc, input_position),
bytes(spend_enc),
)
)

Expand All @@ -99,37 +105,42 @@ async def sign_input(

# Basic setup, sanity check
index = src_entr.real_output
in_sk = CtKey(dest=input_secret, mask=crypto.decodeint(src_entr.mask))
input_secret_key = CtKey(dest=spend_key, mask=crypto.decodeint(src_entr.mask))
kLRki = None # for multisig: src_entr.multisig_kLRki

# Private key correctness test
state.assrt(
crypto.point_eq(
crypto.decodepoint(src_entr.outputs[src_entr.real_output].key.dest),
crypto.scalarmult_base(in_sk.dest),
crypto.scalarmult_base(input_secret_key.dest),
),
"a1",
"Real source entry's destination does not equal spend key's",
)
state.assrt(
crypto.point_eq(
crypto.decodepoint(src_entr.outputs[src_entr.real_output].key.mask),
crypto.gen_commitment(in_sk.mask, src_entr.amount),
crypto.gen_commitment(input_secret_key.mask, src_entr.amount),
),
"a2",
"Real source entry's mask does not equal spend key's",
)

gc.collect()
state.mem_trace(4)

# RCT signature
gc.collect()
from apps.monero.xmr import mlsag2

if state.use_simple_rct:
# Simple RingCT
mix_ring = [x.key for x in src_entr.outputs]
mg, msc = mlsag2.prove_rct_mg_simple(
state.full_message, mix_ring, in_sk, alpha_c, pseudo_out_c, kLRki, index
state.full_message,
mix_ring,
input_secret_key,
pseudo_out_alpha,
pseudo_out_c,
kLRki,
index,
)

else:
Expand All @@ -140,7 +151,7 @@ async def sign_input(
mg, msc = mlsag2.prove_rct_mg(
state.full_message,
mix_ring,
[in_sk],
[input_secret_key],
state.output_sk_masks,
state.output_pk_masks,
kLRki,
Expand All @@ -158,14 +169,6 @@ async def sign_input(
gc.collect()
state.mem_trace(6)

# Final state transition
if state.current_input_index + 1 == state.input_count:
# state.state.set_signature_done() todo remove?
await confirms.transaction_signed(state.ctx)

gc.collect()
state.mem_trace()

from trezor.messages.MoneroTransactionSignInputAck import (
MoneroTransactionSignInputAck,
)
Expand All @@ -175,23 +178,17 @@ async def sign_input(
)


def _recode_msg(mgs, encode=True):
def _recode_msg(mgs):
"""
Recodes MGs signatures from raw forms to bytearrays so it works with serialization
:param mgs:
:param encode: if true encodes to byte representation, otherwise decodes from byte representation
:return:
"""
recode_int = crypto.encodeint if encode else crypto.decodeint
recode_point = crypto.encodepoint if encode else crypto.decodepoint

for idx in range(len(mgs)):
mgs[idx].cc = recode_int(mgs[idx].cc)
mgs[idx].cc = crypto.encodeint(mgs[idx].cc)
if hasattr(mgs[idx], "II") and mgs[idx].II:
for i in range(len(mgs[idx].II)):
mgs[idx].II[i] = recode_point(mgs[idx].II[i])
mgs[idx].II[i] = crypto.encodepoint(mgs[idx].II[i])

for i in range(len(mgs[idx].ss)):
for j in range(len(mgs[idx].ss[i])):
mgs[idx].ss[i][j] = recode_int(mgs[idx].ss[i][j])
mgs[idx].ss[i][j] = crypto.encodeint(mgs[idx].ss[i][j])
return mgs

0 comments on commit d8e9937

Please sign in to comment.