diff --git a/src/apps/monero/xmr/mlsag.py b/src/apps/monero/xmr/mlsag.py index 30ce3f025..5cc093c06 100644 --- a/src/apps/monero/xmr/mlsag.py +++ b/src/apps/monero/xmr/mlsag.py @@ -17,7 +17,7 @@ This has one unfortunate effect where `rows` is always equal to 2 and dsRows always to 1, but the algorithm is still written as the numbers -can be arbitrary. That's why there are loops such as `for i in range(dsRows)` +are arbitrary. That's why there are loops such as `for i in range(dsRows)` where it is run only once currently. ---------- @@ -45,39 +45,90 @@ from apps.monero.xmr import crypto -def key_vector(rows): - return [None] * rows +def generate_mlsag_full( + message, pubs, in_sk, out_sk_mask, out_pk_commitments, kLRki, index, txn_fee_key +): + cols = len(pubs) + if cols == 0: + raise ValueError("Empty pubs") + rows = len(pubs[0]) + if rows == 0: + raise ValueError("Empty pub row") + for i in range(cols): + if len(pubs[i]) != rows: + raise ValueError("pub is not rectangular") + if len(in_sk) != rows: + raise ValueError("Bad inSk size") + if len(out_sk_mask) != len(out_pk_commitments): + raise ValueError("Bad outsk/putpk size") -def key_matrix(rows, cols): - """ - first index is columns (so slightly backward from math) - """ - rv = [None] * cols - for i in range(0, cols): - rv[i] = key_vector(rows) - return rv + sk = _key_vector(rows + 1) + M = _key_matrix(rows + 1, cols) + for i in range(rows + 1): + sk[i] = crypto.sc_0() + for i in range(cols): + M[i][rows] = crypto.identity() + for j in range(rows): + M[i][j] = crypto.decodepoint(pubs[i][j].dest) + M[i][rows] = crypto.point_add( + M[i][rows], crypto.decodepoint(pubs[i][j].mask) + ) -def _generate_random_vector(n): - """ - Generates vector of random scalars - """ - return [crypto.random_scalar() for _ in range(0, n)] + sk[rows] = crypto.sc_0() + for j in range(rows): + sk[j] = in_sk[j].dest + sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row + for i in range(cols): + for j in range(len(out_pk_commitments)): + M[i][rows] = crypto.point_sub( + M[i][rows], crypto.decodepoint(out_pk_commitments[j]) + ) # subtract output Ci's in last row + + # Subtract txn fee output in last row + M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key) + + for j in range(len(out_pk_commitments)): + sk[rows] = crypto.sc_sub( + sk[rows], out_sk_mask[j] + ) # subtract output masks in last row + + return generate_mlsag(message, M, sk, kLRki, index, rows) -def hasher_message(message): + +def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index): """ - Returns incremental hasher for MLSAG + MLSAG for RctType.Simple + :param message: the full message to be signed (actually its hash) + :param pubs: vector of MoneroRctKey; this forms the ring; point values in encoded form; (dest, mask) = (P, C) + :param in_sk: CtKey; spending private key with input commitment mask (original); better_name: input_secret_key + :param a: mask from the pseudo output commitment; better name: pseudo_out_alpha + :param cout: pseudo output commitment; point, decoded; better name: pseudo_out_c + :param kLRki: used only in multisig, currently not implemented + :param index: specifies corresponding public key to the `in_sk` in the pubs array + :return: MgSig """ - ctx = crypto.get_keccak() - ctx.update(message) - return ctx + # Monero signs inputs separately, so `rows` always equals 2 (pubkey, commitment) + # and `dsRows` is always 1 (denotes where the pubkeys "end") + rows = 2 + dsRows = 1 + cols = len(pubs) + if cols == 0: + raise ValueError("Empty pubs") + sk = _key_vector(rows) + M = _key_matrix(rows, cols) -def hash_point(hasher, point, tmp_buff): - crypto.encodepoint_into(tmp_buff, point) - hasher.update(tmp_buff) + sk[0] = in_sk.dest + sk[1] = crypto.sc_sub(in_sk.mask, a) + + for i in range(cols): + M[i][0] = crypto.decodepoint(pubs[i].dest) + M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].mask), cout) + + return generate_mlsag(message, M, sk, kLRki, index, dsRows) def gen_mlsag_assert(pk, xx, kLRki, index, dsRows): @@ -108,7 +159,9 @@ def gen_mlsag_assert(pk, xx, kLRki, index, dsRows): return rows, cols -def generate_first_c_and_key_images(message, rv, pk, xx, kLRki, index, dsRows, rows, cols): +def generate_first_c_and_key_images( + message, rv, pk, xx, kLRki, index, dsRows, rows, cols +): """ MLSAG computation - the part with secret keys :param message: the full message to be signed (actually its hash) @@ -121,13 +174,13 @@ def generate_first_c_and_key_images(message, rv, pk, xx, kLRki, index, dsRows, r :param rows: total number of rows :param cols: size of ring """ - Ip = key_vector(dsRows) - rv.II = key_vector(dsRows) - alpha = key_vector(rows) - rv.ss = key_matrix(rows, cols) + Ip = _key_vector(dsRows) + rv.II = _key_vector(dsRows) + alpha = _key_vector(rows) + rv.ss = _key_matrix(rows, cols) tmp_buff = bytearray(32) - hasher = hasher_message(message) + hasher = _hasher_message(message) for i in range(dsRows): # this is somewhat extra as compared to the Ring Confidential Tx paper @@ -149,8 +202,8 @@ def generate_first_c_and_key_images(message, rv, pk, xx, kLRki, index, dsRows, r aHPi = crypto.scalarmult(Hi, alpha[i]) # key image rv.II[i] = crypto.scalarmult(Hi, xx[i]) - hash_point(hasher, aGi, tmp_buff) - hash_point(hasher, aHPi, tmp_buff) + _hash_point(hasher, aGi, tmp_buff) + _hash_point(hasher, aHPi, tmp_buff) Ip[i] = rv.II[i] @@ -161,8 +214,8 @@ def generate_first_c_and_key_images(message, rv, pk, xx, kLRki, index, dsRows, r # for some reasons we omit calculating R here, which seems # contrary to the paper, but it is in the Monero official client # see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191 - hash_point(hasher, pk[index][i], tmp_buff) - hash_point(hasher, aGi, tmp_buff) + _hash_point(hasher, pk[index][i], tmp_buff) + _hash_point(hasher, aGi, tmp_buff) # the first c c_old = hasher.digest() @@ -201,7 +254,7 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows): tmp_buff = bytearray(32) while i != index: rv.ss[i] = _generate_random_vector(rows) - hasher = hasher_message(message) + hasher = _hasher_message(message) for j in range(dsRows): # L = rv.ss[i][j] * G + c_old * pk[i][j] @@ -209,15 +262,15 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows): Hi = crypto.hash_to_point(crypto.encodepoint(pk[i][j])) # R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j] R = crypto.add_keys3(rv.ss[i][j], Hi, c_old, rv.II[j]) - hash_point(hasher, pk[i][j], tmp_buff) - hash_point(hasher, L, tmp_buff) - hash_point(hasher, R, tmp_buff) + _hash_point(hasher, pk[i][j], tmp_buff) + _hash_point(hasher, L, tmp_buff) + _hash_point(hasher, R, tmp_buff) for j in range(dsRows, rows): # again, omitting R here as discussed above L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j]) - hash_point(hasher, pk[i][j], tmp_buff) - hash_point(hasher, L, tmp_buff) + _hash_point(hasher, pk[i][j], tmp_buff) + _hash_point(hasher, L, tmp_buff) c = crypto.decodeint(hasher.digest()) c_old = c @@ -232,87 +285,36 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows): return rv -def generate_mlsag_full( - message, pubs, in_sk, out_sk_mask, out_pk_commitments, kLRki, index, txn_fee_key -): - cols = len(pubs) - if cols == 0: - raise ValueError("Empty pubs") - rows = len(pubs[0]) - if rows == 0: - raise ValueError("Empty pub row") - for i in range(cols): - if len(pubs[i]) != rows: - raise ValueError("pub is not rectangular") - - if len(in_sk) != rows: - raise ValueError("Bad inSk size") - if len(out_sk_mask) != len(out_pk_commitments): - raise ValueError("Bad outsk/putpk size") - - sk = key_vector(rows + 1) - M = key_matrix(rows + 1, cols) - for i in range(rows + 1): - sk[i] = crypto.sc_0() - - for i in range(cols): - M[i][rows] = crypto.identity() - for j in range(rows): - M[i][j] = crypto.decodepoint(pubs[i][j].dest) - M[i][rows] = crypto.point_add( - M[i][rows], crypto.decodepoint(pubs[i][j].mask) - ) - - sk[rows] = crypto.sc_0() - for j in range(rows): - sk[j] = in_sk[j].dest - sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row - - for i in range(cols): - for j in range(len(out_pk_commitments)): - M[i][rows] = crypto.point_sub( - M[i][rows], crypto.decodepoint(out_pk_commitments[j]) - ) # subtract output Ci's in last row - - # Subtract txn fee output in last row - M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key) +def _key_vector(rows): + return [None] * rows - for j in range(len(out_pk_commitments)): - sk[rows] = crypto.sc_sub( - sk[rows], out_sk_mask[j] - ) # subtract output masks in last row - return generate_mlsag(message, M, sk, kLRki, index, rows) +def _key_matrix(rows, cols): + """ + first index is columns (so slightly backward from math) + """ + rv = [None] * cols + for i in range(0, cols): + rv[i] = _key_vector(rows) + return rv -def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index): +def _generate_random_vector(n): """ - MLSAG for RctType.Simple - :param message: the full message to be signed (actually its hash) - :param pubs: vector of MoneroRctKey; this forms the ring; point values in encoded form; (dest, mask) = (P, C) - :param in_sk: CtKey; spending private key with input commitment mask (original); better_name: input_secret_key - :param a: mask from the pseudo output commitment; better name: pseudo_out_alpha - :param cout: pseudo output commitment; point, decoded; better name: pseudo_out_c - :param kLRki: used only in multisig, currently not implemented - :param index: specifies corresponding public key to the `in_sk` in the pubs array - :return: MgSig + Generates vector of random scalars """ - # Monero signs inputs separately, so `rows` always equals 2 (pubkey, commitment) - # and `dsRows` is always 1 (denotes where the pubkeys "end") - rows = 2 - dsRows = 1 - cols = len(pubs) - if cols == 0: - raise ValueError("Empty pubs") + return [crypto.random_scalar() for _ in range(0, n)] - sk = key_vector(rows) - M = key_matrix(rows, cols) - sk[0] = in_sk.dest - sk[1] = crypto.sc_sub(in_sk.mask, a) +def _hasher_message(message): + """ + Returns incremental hasher for MLSAG + """ + ctx = crypto.get_keccak() + ctx.update(message) + return ctx - for i in range(cols): - M[i][0] = crypto.decodepoint(pubs[i].dest) - M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].mask), cout) - return generate_mlsag(message, M, sk, kLRki, index, dsRows) +def _hash_point(hasher, point, tmp_buff): + crypto.encodepoint_into(tmp_buff, point) + hasher.update(tmp_buff)