From 0744f89eb0213ede0661d4d73e8fbb87b927be09 Mon Sep 17 00:00:00 2001 From: jr conlin Date: Wed, 21 Dec 2016 17:02:15 -0800 Subject: [PATCH] bug: Allow sign and verify to use different hashes than sha256 closes #52 --- CHANGES.txt | 5 ++ pyelliptic/__init__.py | 2 +- pyelliptic/ecc.py | 185 +++++++++++++++++++++++++++-------------- pyelliptic/hash.py | 30 +++++-- pyelliptic/openssl.py | 12 +++ setup.py | 4 +- 6 files changed, 170 insertions(+), 68 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 7856d0e..dbf8585 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,8 @@ +v1.6.0, 2016-12-21 +------------------ + +- allow other SHA digests (other than SHA_256) + v1.5.7, 2015-08-31 ------------------ diff --git a/pyelliptic/__init__.py b/pyelliptic/__init__.py index 85f8aa0..fd631a1 100644 --- a/pyelliptic/__init__.py +++ b/pyelliptic/__init__.py @@ -31,7 +31,7 @@ # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN # IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -__version__ = '1.5.7' +__version__ = '1.6.0' __all__ = [ 'OpenSSL', diff --git a/pyelliptic/ecc.py b/pyelliptic/ecc.py index 86b0938..cc44c94 100644 --- a/pyelliptic/ecc.py +++ b/pyelliptic/ecc.py @@ -68,10 +68,12 @@ class ECC: """ def __init__(self, pubkey=None, privkey=None, pubkey_x=None, - pubkey_y=None, raw_privkey=None, curve='sect283r1'): + pubkey_y=None, raw_privkey=None, curve='sect283r1', + hasher='sha256'): """ - For a normal and High level use, specifie pubkey, - privkey (if you need) and the curve + For a normal and High level use, specify pubkey, + privkey (if you need), the curve, and the hashing method. + """ if type(curve) == str: self.curve = OpenSSL.get_curve(curve) @@ -88,6 +90,19 @@ def __init__(self, pubkey=None, privkey=None, pubkey_x=None, else: self.privkey, self.pubkey_x, self.pubkey_y = self._generate() + _hashers = { + "sha256": OpenSSL.EVP_sha256(), + "sha384": OpenSSL.EVP_sha384(), + "sha512": OpenSSL.EVP_sha512(), + } + try: + self.hasher = _hashers[hasher] + self.hashval = hasher + except KeyError: + valid = '"' + '", "'.join(_hashers.keys()) + '"' + raise Exception("Invalid hasher value specified. Select from " + + valid) + def _set_keys(self, pubkey_x, pubkey_y, privkey): if self.raw_check_key(privkey, pubkey_x, pubkey_y) < 0: self.pubkey_x = None @@ -212,17 +227,21 @@ def _old_decode_privkey(privkey): return curve, privkey, i def _generate(self): + key = pub_key_x = pub_key_y = None try: pub_key_x = OpenSSL.BN_new() pub_key_y = OpenSSL.BN_new() key = OpenSSL.EC_KEY_new_by_curve_name(self.curve) - if key == 0: - raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error()) + if not key: + raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + + OpenSSL.get_error()) if (OpenSSL.EC_KEY_generate_key(key)) == 0: - raise Exception("[OpenSSL] EC_KEY_generate_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_generate_key FAIL ... " + + OpenSSL.get_error()) if (OpenSSL.EC_KEY_check_key(key)) == 0: - raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + + OpenSSL.get_error()) priv_key = OpenSSL.EC_KEY_get0_private_key(key) group = OpenSSL.EC_KEY_get0_group(key) @@ -233,9 +252,11 @@ def _generate(self): pub_key_y, 0 )) == 0: raise Exception( - "[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error()) + "[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " + + OpenSSL.get_error()) - field_size = OpenSSL.EC_GROUP_get_degree(OpenSSL.EC_KEY_get0_group(key)) + field_size = OpenSSL.EC_GROUP_get_degree( + OpenSSL.EC_KEY_get0_group(key)) secret_len = int((field_size + 7) / 8) privkey = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(priv_key)) @@ -271,20 +292,24 @@ def get_ecdh_key(self, pubkey, format='binary'): return self.raw_get_ecdh_key(pubkey_x, pubkey_y) def raw_get_ecdh_key(self, pubkey_x, pubkey_y): + other_key = other_pub_key = other_pub_key_x = other_pub_key_y = None + own_key = own_priv_key = None try: ecdh_keybuffer = OpenSSL.malloc(0, 32) other_key = OpenSSL.EC_KEY_new_by_curve_name(self.curve) - if other_key == 0: - raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error()) + if not other_key: + raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + + OpenSSL.get_error()) other_pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0) other_pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), 0) other_group = OpenSSL.EC_KEY_get0_group(other_key) other_pub_key = OpenSSL.EC_POINT_new(other_group) - if (other_pub_key == None): - raise Exception("[OpenSSl] EC_POINT_new FAIL ... " + OpenSSL.get_error()) + if other_pub_key is None: + raise Exception("[OpenSSl] EC_POINT_new FAIL ... " + + OpenSSL.get_error()) if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(other_group, other_pub_key, @@ -292,27 +317,33 @@ def raw_get_ecdh_key(self, pubkey_x, pubkey_y): other_pub_key_y, 0)) == 0: raise Exception( - "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ..." + OpenSSL.get_error()) + "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ..." + + OpenSSL.get_error()) if (OpenSSL.EC_KEY_set_public_key(other_key, other_pub_key)) == 0: - raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + + OpenSSL.get_error()) if (OpenSSL.EC_KEY_check_key(other_key)) == 0: - raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + + OpenSSL.get_error()) own_key = OpenSSL.EC_KEY_new_by_curve_name(self.curve) - if own_key == 0: - raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error()) + if not own_key: + raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + + OpenSSL.get_error()) own_priv_key = OpenSSL.BN_bin2bn( self.privkey, len(self.privkey), 0) if (OpenSSL.EC_KEY_set_private_key(own_key, own_priv_key)) == 0: - raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " + + OpenSSL.get_error()) OpenSSL.ECDH_set_method(own_key, OpenSSL.ECDH_OpenSSL()) ecdh_keylen = OpenSSL.ECDH_compute_key( ecdh_keybuffer, 32, other_pub_key, own_key, 0) if ecdh_keylen != 32: - raise Exception("[OpenSSL] ECDH keylen FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] ECDH keylen FAIL ... " + + OpenSSL.get_error()) return ecdh_keybuffer.raw @@ -338,10 +369,12 @@ def check_key(self, privkey, pubkey): def raw_check_key(self, privkey, pubkey_x, pubkey_y): curve = self.curve + key = pub_key_x = pub_key_y = pub_key = priv_key = None try: key = OpenSSL.EC_KEY_new_by_curve_name(curve) - if key == 0: - raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error()) + if not key: + raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + + OpenSSL.get_error()) if privkey is not None: priv_key = OpenSSL.BN_bin2bn(privkey, len(privkey), 0) pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0) @@ -350,7 +383,8 @@ def raw_check_key(self, privkey, pubkey_x, pubkey_y): if privkey is not None: if (OpenSSL.EC_KEY_set_private_key(key, priv_key)) == 0: raise Exception( - "[OpenSSL] EC_KEY_set_private_key FAIL ... " + OpenSSL.get_error()) + "[OpenSSL] EC_KEY_set_private_key FAIL ... " + + OpenSSL.get_error()) group = OpenSSL.EC_KEY_get0_group(key) pub_key = OpenSSL.EC_POINT_new(group) @@ -360,11 +394,14 @@ def raw_check_key(self, privkey, pubkey_x, pubkey_y): pub_key_y, 0)) == 0: raise Exception( - "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error()) + "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + + OpenSSL.get_error()) if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0: - raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + + OpenSSL.get_error()) if (OpenSSL.EC_KEY_check_key(key)) == 0: - raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + + OpenSSL.get_error()) return 0 finally: @@ -379,6 +416,7 @@ def sign(self, inputb): """ Sign the input with ECDSA method and returns the signature """ + key = pub_key = pub_key_x = pub_key_y = priv_key = md_ctx = None try: size = len(inputb) buff = OpenSSL.malloc(inputb, size) @@ -389,15 +427,17 @@ def sign(self, inputb): sig = OpenSSL.malloc(0, 151) key = OpenSSL.EC_KEY_new_by_curve_name(self.curve) - if key == 0: - raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error()) + if not key: + raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + + OpenSSL.get_error()) priv_key = OpenSSL.BN_bin2bn(self.privkey, len(self.privkey), 0) pub_key_x = OpenSSL.BN_bin2bn(self.pubkey_x, len(self.pubkey_x), 0) pub_key_y = OpenSSL.BN_bin2bn(self.pubkey_y, len(self.pubkey_y), 0) if (OpenSSL.EC_KEY_set_private_key(key, priv_key)) == 0: - raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " + + OpenSSL.get_error()) group = OpenSSL.EC_KEY_get0_group(key) pub_key = OpenSSL.EC_POINT_new(group) @@ -407,22 +447,32 @@ def sign(self, inputb): pub_key_y, 0)) == 0: raise Exception( - "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error()) - if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0: - raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error()) - if (OpenSSL.EC_KEY_check_key(key)) == 0: - raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error()) - - OpenSSL.EVP_MD_CTX_init(md_ctx) - OpenSSL.EVP_DigestInit_ex(md_ctx, OpenSSL.EVP_sha256(), None) - - if (OpenSSL.EVP_DigestUpdate(md_ctx, buff, size)) == 0: - raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ... " + OpenSSL.get_error()) - OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len) + "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + + OpenSSL.get_error()) + if OpenSSL.EC_KEY_set_public_key(key, pub_key) == 0: + raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + + OpenSSL.get_error()) + if OpenSSL.EC_KEY_check_key(key) == 0: + raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + + OpenSSL.get_error()) + + if OpenSSL.EVP_MD_CTX_init(md_ctx) == 0: + raise Exception("[OpenSSL] Failed to create context ..." + + OpenSSL.get_error()) + if OpenSSL.EVP_DigestInit_ex(md_ctx, self.hasher, None) == 0: + raise Exception("[OpenSSL] Could not initialize digest ..." + + OpenSSL.get_error()) + if OpenSSL.EVP_DigestUpdate(md_ctx, buff, size) == 0: + raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ..." + + OpenSSL.get_error()) + if OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len) == 0: + raise Exception("[OpenSSL] Could not finalize digest ..." + + OpenSSL.get_error()) OpenSSL.ECDSA_sign(0, digest, dgst_len.contents, sig, siglen, key) if (OpenSSL.ECDSA_verify(0, digest, dgst_len.contents, sig, siglen.contents, key)) != 1: - raise Exception("[OpenSSL] ECDSA_verify FAIL ... " + OpenSSL.get_error()) + raise Exception("[OpenSSL] ECDSA_verify FAIL ... " + + OpenSSL.get_error()) return sig.raw[0:siglen.contents.value] @@ -439,40 +489,54 @@ def verify(self, sig, inputb): Verify the signature with the input and the local public key. Returns a boolean """ + key = pub_key = pub_key_x = pub_key_y = md_ctx = None try: bsig = OpenSSL.malloc(sig, len(sig)) binputb = OpenSSL.malloc(inputb, len(inputb)) digest = OpenSSL.malloc(0, 64) dgst_len = OpenSSL.pointer(OpenSSL.c_int(0)) md_ctx = OpenSSL.EVP_MD_CTX_create() + if not md_ctx: + raise Exception("[OpenSSL] Failed to create context ..."+ + OpenSSL.get_error()) key = OpenSSL.EC_KEY_new_by_curve_name(self.curve) - if key == 0: - raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error()) + if not key: + raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + + OpenSSL.get_error()) pub_key_x = OpenSSL.BN_bin2bn(self.pubkey_x, len(self.pubkey_x), 0) pub_key_y = OpenSSL.BN_bin2bn(self.pubkey_y, len(self.pubkey_y), 0) group = OpenSSL.EC_KEY_get0_group(key) pub_key = OpenSSL.EC_POINT_new(group) - if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(group, pub_key, - pub_key_x, - pub_key_y, - 0)) == 0: + if OpenSSL.EC_POINT_set_affine_coordinates_GFp(group, + pub_key, + pub_key_x, + pub_key_y, + 0) == 0: raise Exception( - "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error()) - if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0: - raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error()) - if (OpenSSL.EC_KEY_check_key(key)) == 0: - raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error()) - - OpenSSL.EVP_MD_CTX_init(md_ctx) - OpenSSL.EVP_DigestInit_ex(md_ctx, OpenSSL.EVP_sha256(), None) - if (OpenSSL.EVP_DigestUpdate(md_ctx, binputb, len(inputb))) == 0: - raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ... " + OpenSSL.get_error()) - - OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len) + "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + + OpenSSL.get_error()) + if OpenSSL.EC_KEY_set_public_key(key, pub_key) == 0: + raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + + OpenSSL.get_error()) + if OpenSSL.EC_KEY_check_key(key) == 0: + raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + + OpenSSL.get_error()) + if OpenSSL.EVP_MD_CTX_init(md_ctx) == 0: + raise Exception("[OpenSSL] Failed to create context ..." + + OpenSSL.get_error()) + if OpenSSL.EVP_DigestInit_ex(md_ctx, self.hasher, None) == 0: + raise Exception("[OpenSSL] Could not initialize digest ..." + + OpenSSL.get_error()) + if OpenSSL.EVP_DigestUpdate(md_ctx, binputb, len(inputb)) == 0: + raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ... " + + OpenSSL.get_error()) + if OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len) == 0: + raise Exception("[OpenSSL] Could not finalize digest ..." + + OpenSSL.get_error()) ret = OpenSSL.ECDSA_verify( 0, digest, dgst_len.contents, bsig, len(sig), key) @@ -483,7 +547,6 @@ def verify(self, sig, inputb): return False # Bad signature ! else: return True # Good - return False finally: OpenSSL.EC_KEY_free(key) diff --git a/pyelliptic/hash.py b/pyelliptic/hash.py index 98d89e3..6a5936f 100644 --- a/pyelliptic/hash.py +++ b/pyelliptic/hash.py @@ -66,7 +66,23 @@ def hmac_sha256(k, m): d = OpenSSL.malloc(m, len(m)) md = OpenSSL.malloc(0, 32) i = OpenSSL.pointer(OpenSSL.c_int(0)) - OpenSSL.HMAC(OpenSSL.EVP_sha256(), key, len(k), d, len(m), md, i) + if not OpenSSL.HMAC(OpenSSL.EVP_sha256(), key, len(k), d, len(m), md, i): + raise Exception("[OpenSSL] sha256 HMAC failed ..." + + OpenSSL.get_error()) + return md.raw + + +def hmac_sha384(k, m): + """ + Compute the key and the message with HMAC SHA384 + """ + key = OpenSSL.malloc(k, len(k)) + d = OpenSSL.malloc(m, len(m)) + md = OpenSSL.malloc(0, 64) + i = OpenSSL.pointer(OpenSSL.c_int(0)) + if not OpenSSL.HMAC(OpenSSL.EVP_sha384(), key, len(k), d, len(m), md, i): + raise Exception("[OpenSSL] sha384 HMAC failed ..." + + OpenSSL.get_error()) return md.raw @@ -78,7 +94,9 @@ def hmac_sha512(k, m): d = OpenSSL.malloc(m, len(m)) md = OpenSSL.malloc(0, 64) i = OpenSSL.pointer(OpenSSL.c_int(0)) - OpenSSL.HMAC(OpenSSL.EVP_sha512(), key, len(k), d, len(m), md, i) + if not OpenSSL.HMAC(OpenSSL.EVP_sha512(), key, len(k), d, len(m), md, i): + raise Exception("[OpenSSL] sha512 HMAC failed ..." + + OpenSSL.get_error()) return md.raw @@ -88,7 +106,9 @@ def pbkdf2(password, salt=None, i=10000, keylen=64): p_password = OpenSSL.malloc(password, len(password)) p_salt = OpenSSL.malloc(salt, len(salt)) output = OpenSSL.malloc(0, keylen) - OpenSSL.PKCS5_PBKDF2_HMAC(p_password, len(password), p_salt, - len(p_salt), i, OpenSSL.EVP_sha256(), - keylen, output) + if not OpenSSL.PKCS5_PBKDF2_HMAC(p_password, len(password), p_salt, + len(p_salt), i, OpenSSL.EVP_sha256(), + keylen, output): + raise Exception("[OpenSSL] PBKDF2_HMAC failed ..." + + OpenSSL.get_error()) return salt, output.raw diff --git a/pyelliptic/openssl.py b/pyelliptic/openssl.py index c91b272..320169c 100644 --- a/pyelliptic/openssl.py +++ b/pyelliptic/openssl.py @@ -100,6 +100,14 @@ def __init__(self, library): self.BN_bin2bn.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p] + # The following is useful for debugging and development work + # for libraries that use pyelliptic. This converts a big number + # pointer struct into it's decimal wich will allow developers + # to trace values more thoroughly. + self.BN_bn2dec = self._lib.BN_bn2dec + self.BN_bn2dec.restype = ctypes.c_char_p + self.BN_bn2dec.argtypes = [ctypes.c_void_p] + self.EC_GROUP_get_degree = self._lib.EC_GROUP_get_degree self.EC_GROUP_get_degree.restype = ctypes.c_int self.EC_GROUP_get_degree.argtypes = [ctypes.c_void_p] @@ -344,6 +352,10 @@ def __init__(self, library): self.EVP_sha256.restype = ctypes.c_void_p self.EVP_sha256.argtypes = [] + self.EVP_sha384 = self._lib.EVP_sha384 + self.EVP_sha384.restype = ctypes.c_void_p + self.EVP_sha384.argtypes = [] + self.i2o_ECPublicKey = self._lib.i2o_ECPublicKey self.i2o_ECPublicKey.restype = ctypes.c_int self.i2o_ECPublicKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p] diff --git a/setup.py b/setup.py index a43f08a..ddca82d 100644 --- a/setup.py +++ b/setup.py @@ -31,9 +31,11 @@ from setuptools import setup, find_packages +from pyelliptic import __version__ + setup( name="pyelliptic", - version='1.5.7', + version=__version__, url='https://github.com/yann2192/pyelliptic', license='BSD', description=