Skip to content

Commit

Permalink
bug: Allow sign and verify to use different hashes than sha256
Browse files Browse the repository at this point in the history
  • Loading branch information
jrconlin committed Jan 3, 2017
1 parent b9e9ae5 commit 0744f89
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 68 deletions.
5 changes: 5 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
v1.6.0, 2016-12-21
------------------

- allow other SHA digests (other than SHA_256)

v1.5.7, 2015-08-31
------------------

Expand Down
2 changes: 1 addition & 1 deletion pyelliptic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
# IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

__version__ = '1.5.7'
__version__ = '1.6.0'

__all__ = [
'OpenSSL',
Expand Down
185 changes: 124 additions & 61 deletions pyelliptic/ecc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ class ECC:
"""

def __init__(self, pubkey=None, privkey=None, pubkey_x=None,
pubkey_y=None, raw_privkey=None, curve='sect283r1'):
pubkey_y=None, raw_privkey=None, curve='sect283r1',
hasher='sha256'):
"""
For a normal and High level use, specifie pubkey,
privkey (if you need) and the curve
For a normal and High level use, specify pubkey,
privkey (if you need), the curve, and the hashing method.
"""
if type(curve) == str:
self.curve = OpenSSL.get_curve(curve)
Expand All @@ -88,6 +90,19 @@ def __init__(self, pubkey=None, privkey=None, pubkey_x=None,
else:
self.privkey, self.pubkey_x, self.pubkey_y = self._generate()

_hashers = {
"sha256": OpenSSL.EVP_sha256(),
"sha384": OpenSSL.EVP_sha384(),
"sha512": OpenSSL.EVP_sha512(),
}
try:
self.hasher = _hashers[hasher]
self.hashval = hasher
except KeyError:
valid = '"' + '", "'.join(_hashers.keys()) + '"'
raise Exception("Invalid hasher value specified. Select from " +
valid)

def _set_keys(self, pubkey_x, pubkey_y, privkey):
if self.raw_check_key(privkey, pubkey_x, pubkey_y) < 0:
self.pubkey_x = None
Expand Down Expand Up @@ -212,17 +227,21 @@ def _old_decode_privkey(privkey):
return curve, privkey, i

def _generate(self):
key = pub_key_x = pub_key_y = None
try:
pub_key_x = OpenSSL.BN_new()
pub_key_y = OpenSSL.BN_new()

key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
if key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())
if not key:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... "
+ OpenSSL.get_error())
if (OpenSSL.EC_KEY_generate_key(key)) == 0:
raise Exception("[OpenSSL] EC_KEY_generate_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_generate_key FAIL ... " +
OpenSSL.get_error())
if (OpenSSL.EC_KEY_check_key(key)) == 0:
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " +
OpenSSL.get_error())
priv_key = OpenSSL.EC_KEY_get0_private_key(key)

group = OpenSSL.EC_KEY_get0_group(key)
Expand All @@ -233,9 +252,11 @@ def _generate(self):
pub_key_y, 0
)) == 0:
raise Exception(
"[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())
"[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ... " +
OpenSSL.get_error())

field_size = OpenSSL.EC_GROUP_get_degree(OpenSSL.EC_KEY_get0_group(key))
field_size = OpenSSL.EC_GROUP_get_degree(
OpenSSL.EC_KEY_get0_group(key))
secret_len = int((field_size + 7) / 8)

privkey = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(priv_key))
Expand Down Expand Up @@ -271,48 +292,58 @@ def get_ecdh_key(self, pubkey, format='binary'):
return self.raw_get_ecdh_key(pubkey_x, pubkey_y)

def raw_get_ecdh_key(self, pubkey_x, pubkey_y):
other_key = other_pub_key = other_pub_key_x = other_pub_key_y = None
own_key = own_priv_key = None
try:
ecdh_keybuffer = OpenSSL.malloc(0, 32)

other_key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
if other_key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())
if not other_key:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... "
+ OpenSSL.get_error())

other_pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
other_pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), 0)

other_group = OpenSSL.EC_KEY_get0_group(other_key)
other_pub_key = OpenSSL.EC_POINT_new(other_group)
if (other_pub_key == None):
raise Exception("[OpenSSl] EC_POINT_new FAIL ... " + OpenSSL.get_error())
if other_pub_key is None:
raise Exception("[OpenSSl] EC_POINT_new FAIL ... " +
OpenSSL.get_error())

if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(other_group,
other_pub_key,
other_pub_key_x,
other_pub_key_y,
0)) == 0:
raise Exception(
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ..." + OpenSSL.get_error())
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ..." +
OpenSSL.get_error())
if (OpenSSL.EC_KEY_set_public_key(other_key, other_pub_key)) == 0:
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " +
OpenSSL.get_error())
if (OpenSSL.EC_KEY_check_key(other_key)) == 0:
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " +
OpenSSL.get_error())

own_key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
if own_key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())
if not own_key:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... "
+ OpenSSL.get_error())
own_priv_key = OpenSSL.BN_bin2bn(
self.privkey, len(self.privkey), 0)

if (OpenSSL.EC_KEY_set_private_key(own_key, own_priv_key)) == 0:
raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " +
OpenSSL.get_error())

OpenSSL.ECDH_set_method(own_key, OpenSSL.ECDH_OpenSSL())
ecdh_keylen = OpenSSL.ECDH_compute_key(
ecdh_keybuffer, 32, other_pub_key, own_key, 0)

if ecdh_keylen != 32:
raise Exception("[OpenSSL] ECDH keylen FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] ECDH keylen FAIL ... " +
OpenSSL.get_error())

return ecdh_keybuffer.raw

Expand All @@ -338,10 +369,12 @@ def check_key(self, privkey, pubkey):

def raw_check_key(self, privkey, pubkey_x, pubkey_y):
curve = self.curve
key = pub_key_x = pub_key_y = pub_key = priv_key = None
try:
key = OpenSSL.EC_KEY_new_by_curve_name(curve)
if key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())
if not key:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... "
+ OpenSSL.get_error())
if privkey is not None:
priv_key = OpenSSL.BN_bin2bn(privkey, len(privkey), 0)
pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
Expand All @@ -350,7 +383,8 @@ def raw_check_key(self, privkey, pubkey_x, pubkey_y):
if privkey is not None:
if (OpenSSL.EC_KEY_set_private_key(key, priv_key)) == 0:
raise Exception(
"[OpenSSL] EC_KEY_set_private_key FAIL ... " + OpenSSL.get_error())
"[OpenSSL] EC_KEY_set_private_key FAIL ... " +
OpenSSL.get_error())

group = OpenSSL.EC_KEY_get0_group(key)
pub_key = OpenSSL.EC_POINT_new(group)
Expand All @@ -360,11 +394,14 @@ def raw_check_key(self, privkey, pubkey_x, pubkey_y):
pub_key_y,
0)) == 0:
raise Exception(
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... "
+ OpenSSL.get_error())
if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0:
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " +
OpenSSL.get_error())
if (OpenSSL.EC_KEY_check_key(key)) == 0:
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " +
OpenSSL.get_error())
return 0

finally:
Expand All @@ -379,6 +416,7 @@ def sign(self, inputb):
"""
Sign the input with ECDSA method and returns the signature
"""
key = pub_key = pub_key_x = pub_key_y = priv_key = md_ctx = None
try:
size = len(inputb)
buff = OpenSSL.malloc(inputb, size)
Expand All @@ -389,15 +427,17 @@ def sign(self, inputb):
sig = OpenSSL.malloc(0, 151)

key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
if key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())
if not key:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... "
+ OpenSSL.get_error())

priv_key = OpenSSL.BN_bin2bn(self.privkey, len(self.privkey), 0)
pub_key_x = OpenSSL.BN_bin2bn(self.pubkey_x, len(self.pubkey_x), 0)
pub_key_y = OpenSSL.BN_bin2bn(self.pubkey_y, len(self.pubkey_y), 0)

if (OpenSSL.EC_KEY_set_private_key(key, priv_key)) == 0:
raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ... " +
OpenSSL.get_error())

group = OpenSSL.EC_KEY_get0_group(key)
pub_key = OpenSSL.EC_POINT_new(group)
Expand All @@ -407,22 +447,32 @@ def sign(self, inputb):
pub_key_y,
0)) == 0:
raise Exception(
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())
if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0:
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error())
if (OpenSSL.EC_KEY_check_key(key)) == 0:
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error())

OpenSSL.EVP_MD_CTX_init(md_ctx)
OpenSSL.EVP_DigestInit_ex(md_ctx, OpenSSL.EVP_sha256(), None)

if (OpenSSL.EVP_DigestUpdate(md_ctx, buff, size)) == 0:
raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ... " + OpenSSL.get_error())
OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len)
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " +
OpenSSL.get_error())
if OpenSSL.EC_KEY_set_public_key(key, pub_key) == 0:
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " +
OpenSSL.get_error())
if OpenSSL.EC_KEY_check_key(key) == 0:
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " +
OpenSSL.get_error())

if OpenSSL.EVP_MD_CTX_init(md_ctx) == 0:
raise Exception("[OpenSSL] Failed to create context ..." +
OpenSSL.get_error())
if OpenSSL.EVP_DigestInit_ex(md_ctx, self.hasher, None) == 0:
raise Exception("[OpenSSL] Could not initialize digest ..." +
OpenSSL.get_error())
if OpenSSL.EVP_DigestUpdate(md_ctx, buff, size) == 0:
raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ..." +
OpenSSL.get_error())
if OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len) == 0:
raise Exception("[OpenSSL] Could not finalize digest ..." +
OpenSSL.get_error())
OpenSSL.ECDSA_sign(0, digest, dgst_len.contents, sig, siglen, key)
if (OpenSSL.ECDSA_verify(0, digest, dgst_len.contents, sig,
siglen.contents, key)) != 1:
raise Exception("[OpenSSL] ECDSA_verify FAIL ... " + OpenSSL.get_error())
raise Exception("[OpenSSL] ECDSA_verify FAIL ... " +
OpenSSL.get_error())

return sig.raw[0:siglen.contents.value]

Expand All @@ -439,40 +489,54 @@ def verify(self, sig, inputb):
Verify the signature with the input and the local public key.
Returns a boolean
"""
key = pub_key = pub_key_x = pub_key_y = md_ctx = None
try:
bsig = OpenSSL.malloc(sig, len(sig))
binputb = OpenSSL.malloc(inputb, len(inputb))
digest = OpenSSL.malloc(0, 64)
dgst_len = OpenSSL.pointer(OpenSSL.c_int(0))
md_ctx = OpenSSL.EVP_MD_CTX_create()
if not md_ctx:
raise Exception("[OpenSSL] Failed to create context ..."+
OpenSSL.get_error())

key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)

if key == 0:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... " + OpenSSL.get_error())
if not key:
raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ... "
+ OpenSSL.get_error())

pub_key_x = OpenSSL.BN_bin2bn(self.pubkey_x, len(self.pubkey_x), 0)
pub_key_y = OpenSSL.BN_bin2bn(self.pubkey_y, len(self.pubkey_y), 0)
group = OpenSSL.EC_KEY_get0_group(key)
pub_key = OpenSSL.EC_POINT_new(group)

if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(group, pub_key,
pub_key_x,
pub_key_y,
0)) == 0:
if OpenSSL.EC_POINT_set_affine_coordinates_GFp(group,
pub_key,
pub_key_x,
pub_key_y,
0) == 0:
raise Exception(
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... " + OpenSSL.get_error())
if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0:
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " + OpenSSL.get_error())
if (OpenSSL.EC_KEY_check_key(key)) == 0:
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " + OpenSSL.get_error())

OpenSSL.EVP_MD_CTX_init(md_ctx)
OpenSSL.EVP_DigestInit_ex(md_ctx, OpenSSL.EVP_sha256(), None)
if (OpenSSL.EVP_DigestUpdate(md_ctx, binputb, len(inputb))) == 0:
raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ... " + OpenSSL.get_error())

OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len)
"[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ... "
+ OpenSSL.get_error())
if OpenSSL.EC_KEY_set_public_key(key, pub_key) == 0:
raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ... " +
OpenSSL.get_error())
if OpenSSL.EC_KEY_check_key(key) == 0:
raise Exception("[OpenSSL] EC_KEY_check_key FAIL ... " +
OpenSSL.get_error())
if OpenSSL.EVP_MD_CTX_init(md_ctx) == 0:
raise Exception("[OpenSSL] Failed to create context ..." +
OpenSSL.get_error())
if OpenSSL.EVP_DigestInit_ex(md_ctx, self.hasher, None) == 0:
raise Exception("[OpenSSL] Could not initialize digest ..." +
OpenSSL.get_error())
if OpenSSL.EVP_DigestUpdate(md_ctx, binputb, len(inputb)) == 0:
raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ... " +
OpenSSL.get_error())
if OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len) == 0:
raise Exception("[OpenSSL] Could not finalize digest ..." +
OpenSSL.get_error())
ret = OpenSSL.ECDSA_verify(
0, digest, dgst_len.contents, bsig, len(sig), key)

Expand All @@ -483,7 +547,6 @@ def verify(self, sig, inputb):
return False # Bad signature !
else:
return True # Good
return False

finally:
OpenSSL.EC_KEY_free(key)
Expand Down
Loading

0 comments on commit 0744f89

Please sign in to comment.