Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_onion.c
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,21 @@ static void make_hmac(const struct hop *hops, size_t num_hops,
#endif
}

void _dump_hex(unsigned char *x, size_t s) {
printf(" ");
while (s > 0) {
printf("%02x", *x);
x++; s--;
}
}
#define dump_hex(x) _dump_hex((void*)&x, sizeof(x))
void dump_pkey(secp256k1_context *ctx, secp256k1_pubkey pkey) {
unsigned char tmp[65];
size_t len;
secp256k1_ec_pubkey_serialize(ctx, tmp, &len, &pkey, 0);
dump_hex(tmp);
}

static bool check_hmac(struct onion *onion, const struct hmackey *hmackey)
{
struct sha256 hmac;
Expand Down Expand Up @@ -438,6 +453,7 @@ bool create_onion(const secp256k1_pubkey pubkey[],

gen_keys(ctx, &seckeys[i], &pubkeys[i]);


/* Make shared secret. */
if (!secp256k1_ecdh(ctx, secret, &pubkey[i], seckeys[i].u.u8))
goto fail;
Expand Down Expand Up @@ -639,17 +655,23 @@ int main(int argc, char *argv[])
for (i = 0; i < hops; i++) {
asprintf(&msgs[i], "Message to %zu", i);
random_key(ctx, &seckeys[i], &pubkeys[i]);
printf(" * Keypair %zu:", i);
dump_hex(seckeys[i]);
dump_pkey(ctx, pubkeys[i]);
printf("\n");
}

if (!create_onion(pubkeys, msgs, hops, &onion))
errx(1, "Creating onion packet failed");
printf(" * Message:"); dump_hex(onion); printf("\n");

/* Now parse and peel. */
for (i = 0; i < hops; i++) {
struct enckey enckey;
struct iv pad_iv;

printf("Decrypting with key %zi\n", i);

if (!decrypt_onion(&seckeys[i], &onion, &enckey, &pad_iv, i))
errx(1, "Decrypting onion for hop %zi", i);
if (strcmp((char *)myhop(&onion)->msg, msgs[i]) != 0)
Expand Down
329 changes: 329 additions & 0 deletions test/test_onion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
#!/usr/bin/env python

import sys

from hashlib import sha256
from binascii import hexlify, unhexlify
import hmac
import random

from cryptography.hazmat.primitives.ciphers import Cipher, modes, algorithms
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import CTR
from cryptography.hazmat.backends import default_backend
# http://cryptography.io

from pyelliptic import ecc

class MyEx(Exception): pass

def hmac_sha256(k, m):
return hmac.new(k, m, sha256).digest()






## pyelliptic doesn't support compressed pubkey representations
## so we have to add some code...
from pyelliptic.openssl import OpenSSL
import ctypes

OpenSSL.EC_POINT_set_compressed_coordinates_GFp = \
OpenSSL._lib.EC_POINT_set_compressed_coordinates_GFp
OpenSSL.EC_POINT_set_compressed_coordinates_GFp.restype = ctypes.c_int
OpenSSL.EC_POINT_set_compressed_coordinates_GFp.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int,
ctypes.c_void_p]

def ecc_ecdh_key(sec, pub):
assert isinstance(sec, ecc.ECC)
if isinstance(pub, ecc.ECC):
pub = pub.get_pubkey()
#return sec.get_ecdh_key(pub)

pubkey_x, pubkey_y = ecc.ECC._decode_pubkey(pub, 'binary')

other_key = other_pub_key_x = other_pub_key_y = other_pub_key = None
own_priv_key = res = res_x = res_y = None
try:
other_key = OpenSSL.EC_KEY_new_by_curve_name(sec.curve)
if other_key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())

other_pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
other_pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), 0)

other_group = OpenSSL.EC_KEY_get0_group(other_key)
other_pub_key = OpenSSL.EC_POINT_new(other_group)
if (other_pub_key == None):
raise Exception("[OpenSSl] EC_POINT_new FAIL ... " + OpenSSL.get_error())

if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(other_group,
other_pub_key,
other_pub_key_x,
other_pub_key_y,
0)) == 0:
raise Exception(
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ..." + OpenSSL.get_error())

own_priv_key = OpenSSL.BN_bin2bn(sec.privkey, len(sec.privkey), 0)

res = OpenSSL.EC_POINT_new(other_group)
if (OpenSSL.EC_POINT_mul(other_group, res, 0, other_pub_key, own_priv_key, 0)) == 0:
raise Exception(
"[OpenSSL] EC_POINT_mul FAIL ..." + OpenSSL.get_error())

res_x = OpenSSL.BN_new()
res_y = OpenSSL.BN_new()

if (OpenSSL.EC_POINT_get_affine_coordinates_GFp(other_group, res,
res_x,
res_y, 0
)) == 0:
raise Exception(
"[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())

resx = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(res_x))
resy = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(res_y))

OpenSSL.BN_bn2bin(res_x, resx)
resx = resx.raw
OpenSSL.BN_bn2bin(res_y, resy)
resy = resy.raw

return resx, resy

finally:
if other_key: OpenSSL.EC_KEY_free(other_key)
if other_pub_key_x: OpenSSL.BN_free(other_pub_key_x)
if other_pub_key_y: OpenSSL.BN_free(other_pub_key_y)
if other_pub_key: OpenSSL.EC_POINT_free(other_pub_key)
if own_priv_key: OpenSSL.BN_free(own_priv_key)
if res: OpenSSL.EC_POINT_free(res)
if res_x: OpenSSL.BN_free(res_x)
if res_y: OpenSSL.BN_free(res_y)

def get_pos_y_for_x(pubkey_x, yneg=0):
key = pub_key = pub_key_x = pub_key_y = None
try:
key = OpenSSL.EC_KEY_new_by_curve_name(OpenSSL.get_curve('secp256k1'))
group = OpenSSL.EC_KEY_get0_group(key)
pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
pub_key = OpenSSL.EC_POINT_new(group)

if OpenSSL.EC_POINT_set_compressed_coordinates_GFp(group, pub_key,
pub_key_x, yneg, 0) == 0:
raise Exception("[OpenSSL] EC_POINT_set_compressed_coordinates_GFp FAIL ... " + OpenSSL.get_error())


pub_key_y = OpenSSL.BN_new()
if (OpenSSL.EC_POINT_get_affine_coordinates_GFp(group, pub_key,
pub_key_x,
pub_key_y, 0
)) == 0:
raise Exception("[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())

pubkeyy = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(pub_key_y))
OpenSSL.BN_bn2bin(pub_key_y, pubkeyy)
pubkeyy = pubkeyy.raw
field_size = OpenSSL.EC_GROUP_get_degree(OpenSSL.EC_KEY_get0_group(key))
secret_len = int((field_size + 7) / 8)
if len(pubkeyy) < secret_len:
pubkeyy = pubkeyy.rjust(secret_len, b'\0')
return pubkeyy
finally:
if key is not None: OpenSSL.EC_KEY_free(key)
if pub_key is not None: OpenSSL.EC_POINT_free(pub_key)
if pub_key_x is not None: OpenSSL.BN_free(pub_key_x)
if pub_key_y is not None: OpenSSL.BN_free(pub_key_y)

class Onion(object):
HMAC_LEN = 32
PKEY_LEN = 32
MSG_LEN = 128
ZEROES = b"\x00" * (HMAC_LEN + PKEY_LEN + MSG_LEN)

@staticmethod
def tweak_sha(sha, d):
sha = sha.copy()
sha.update(d)
return sha.digest()

@classmethod
def get_ecdh_secrets(cls, sec, pkey_x, pkey_y):
pkey = unhexlify('04') + pkey_x + pkey_y
tmp_key = ecc.ECC(curve='secp256k1', pubkey=pkey)
sec_x, sec_y = ecc_ecdh_key(sec, tmp_key)

b = '\x02' if ord(sec_y[-1]) % 2 == 0 else '\x03'
sec = sha256(sha256(b + sec_x).digest())

enckey = cls.tweak_sha(sec, b'\x00')[:16]
hmac = cls.tweak_sha(sec, b'\x01')
iv = cls.tweak_sha(sec, b'\x02')[:16]
pad_iv = cls.tweak_sha(sec, b'\x03')[:16]

return enckey, hmac, iv, pad_iv

def enc_pad(self, enckey, pad_iv):
aes = Cipher(AES(enckey), CTR(pad_iv),
default_backend()).encryptor()
return aes.update(self.ZEROES)

class OnionDecrypt(Onion):
def __init__(self, onion, my_ecc):
self.my_ecc = my_ecc

hmac_end = len(onion)
pkey_end = hmac_end - self.HMAC_LEN
self.msg_end = pkey_end - self.PKEY_LEN
self.fwd_end = self.msg_end - self.MSG_LEN

self.onion = onion
self.pkey = onion[self.msg_end:pkey_end]
self.hmac = onion[pkey_end:hmac_end]

self.get_secrets()

def decrypt(self):
pad = self.enc_pad(self.enckey, self.pad_iv)

aes = Cipher(AES(self.enckey), CTR(self.iv),
default_backend()).decryptor()
self.fwd = pad + aes.update(self.onion[:self.fwd_end])
self.msg = aes.update(self.onion[self.fwd_end:self.msg_end])

def get_secrets(self):
pkey_x = self.pkey
pkey_y = get_pos_y_for_x(pkey_x) # always positive by design
enckey, hmac, iv, pad_iv = self.get_ecdh_secrets(self.my_ecc, pkey_x, pkey_y)
if not self.check_hmac(hmac):
raise Exception("HMAC did not verify")
self.enckey = enckey
self.iv = iv
self.pad_iv = pad_iv

def check_hmac(self, hmac_key):
calc = hmac_sha256(hmac_key, self.onion[:-self.HMAC_LEN])
return calc == self.hmac

class OnionEncrypt(Onion):
def __init__(self, msgs, pubkeys):
assert len(msgs) == len(pubkeys)
assert 0 < len(msgs) <= 20
assert all( len(m) <= self.MSG_LEN for m in msgs )

msgs = [m + "\0"*(self.MSG_LEN - len(m)) for m in msgs]
pubkeys = [ecc.ECC(pubkey=pk, curve='secp256k1') for pk in pubkeys]
n = len(msgs)

tmpkeys = []
tmppubkeys = []
for i in range(n):
while True:
t = ecc.ECC(curve='secp256k1')
if ord(t.pubkey_y[-1]) % 2 == 0:
break
# or do the math to "flip" the secret key and pub key
tmpkeys.append(t)
tmppubkeys.append(t.pubkey_x)

enckeys, hmacs, ivs, pad_ivs = zip(*[self.get_ecdh_secrets(tmpkey, pkey.pubkey_x, pkey.pubkey_y)
for tmpkey, pkey in zip(tmpkeys, pubkeys)])

# padding takes the form:
# E_(n-1)(0000s)
# D_(n-1)(
# E(n-2)(0000s)
# D(n-2)(
# ...
# )
# )

padding = ""
for i in range(n-1):
pad = self.enc_pad(enckeys[i], pad_ivs[i])
aes = Cipher(AES(enckeys[i]), CTR(ivs[i]),
default_backend()).decryptor()
padding = pad + aes.update(padding)

if n < 20:
padding += str(bytearray(random.getrandbits(8)
for _ in range(len(self.ZEROES) * (20-n))))

# to encrypt the message we need to bump the counter past all
# the padding, then just encrypt the final message
aes = Cipher(AES(enckeys[-1]), CTR(ivs[-1]),
default_backend()).encryptor()
aes.update(padding) # don't care about cyphertext
msgenc = aes.update(msgs[-1])

msgenc = padding + msgenc + tmppubkeys[-1]
del padding
msgenc += hmac_sha256(hmacs[-1], msgenc)

# *PHEW*
# now iterate

for i in reversed(range(n-1)):
# drop the padding this node will add
msgenc = msgenc[len(self.ZEROES):]
# adding the msg
msgenc += msgs[i]
# encrypt it
aes = Cipher(AES(enckeys[i]), CTR(ivs[i]),
default_backend()).encryptor()
msgenc = aes.update(msgenc)
# add the tmp key
msgenc += tmppubkeys[i]
# add the hmac
msgenc += hmac_sha256(hmacs[i], msgenc)
self.onion = msgenc

def decode_from_file(f):
keys = []
msg = ""
for ln in f.readlines():
if ln.startswith(" * Keypair "):
w = ln.strip().split()
idx = int(w[2].strip(":"))
priv = unhexlify(w[3])
pub = unhexlify(w[4])
assert idx == len(keys)
keys.append(ecc.ECC(privkey=priv, pubkey=pub, curve='secp256k1'))
elif ln.startswith(" * Message:"):
msg = unhexlify(ln[11:].strip())
elif ln.startswith("Decrypting"):
pass
else:
print ln
assert ln.strip() == ""

assert msg != ""
for k in keys:
o = OnionDecrypt(msg, k)
o.decrypt()
print o.msg
msg = o.fwd
print "done"

if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] == "generate":
if len(sys.argv) == 3:
n = int(sys.argv[2])
else:
n = 20
servers = [ecc.ECC(curve='secp256k1') for _ in range(n)]
server_pubs = [s.get_pubkey() for s in servers]
msgs = ["Howzit %d..." % (i,) for i in range(n)]

o = OnionEncrypt(msgs, server_pubs)

for i, s in enumerate(servers):
print " * Keypair %d: %s %s" % (
i, hexlify(s.privkey), hexlify(s.get_pubkey()))
print " * Message: %s" % (hexlify(o.onion))
else:
decode_from_file(sys.stdin)