diff --git a/ecdsa/__init__.py b/ecdsa/__init__.py index e834b3a8..5c2c3e7c 100644 --- a/ecdsa/__init__.py +++ b/ecdsa/__init__.py @@ -1,9 +1,12 @@ __all__ = ["curves", "der", "ecdsa", "ellipticcurve", "keys", "numbertheory", "test_pyecdsa", "util", "six"] -from .keys import SigningKey, VerifyingKey, BadSignatureError, BadDigestError +from .keys import SigningKey, VerifyingKey, BadSignatureError, BadDigestError,\ + MalformedPointError from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1 +from .der import UnexpectedDER _hush_pyflakes = [SigningKey, VerifyingKey, BadSignatureError, BadDigestError, + MalformedPointError, UnexpectedDER, NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1] del _hush_pyflakes diff --git a/ecdsa/keys.py b/ecdsa/keys.py index 72c0c18f..45198170 100644 --- a/ecdsa/keys.py +++ b/ecdsa/keys.py @@ -3,6 +3,7 @@ from . import ecdsa from . import der from . import rfc6979 +from . import ellipticcurve from .curves import NIST192p, find_curve from .util import string_to_number, number_to_string, randrange from .util import sigencode_string, sigdecode_string @@ -15,6 +16,11 @@ class BadSignatureError(Exception): class BadDigestError(Exception): pass + +class MalformedPointError(AssertionError): + pass + + class VerifyingKey: def __init__(self, _error__please_use_generate=None): if not _error__please_use_generate: @@ -33,17 +39,21 @@ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1): def from_string(klass, string, curve=NIST192p, hashfunc=sha1, validate_point=True): order = curve.order - assert len(string) == curve.verifying_key_length, \ - (len(string), curve.verifying_key_length) + if len(string) != curve.verifying_key_length: + raise MalformedPointError( + "Malformed encoding of public point. Expected string {0} bytes" + " long, received {1} bytes long string".format( + curve.verifying_key_length, len(string))) xs = string[:curve.baselen] ys = string[curve.baselen:] - assert len(xs) == curve.baselen, (len(xs), curve.baselen) - assert len(ys) == curve.baselen, (len(ys), curve.baselen) + if len(xs) != curve.baselen: + raise MalformedPointError("Unexpected length of encoded x") + if len(ys) != curve.baselen: + raise MalformedPointError("Unexpected length of encoded y") x = string_to_number(xs) y = string_to_number(ys) - if validate_point: - assert ecdsa.point_is_valid(curve.generator, x, y) - from . import ellipticcurve + if validate_point and not ecdsa.point_is_valid(curve.generator, x, y): + raise MalformedPointError("Point does not lie on the curve") point = ellipticcurve.Point(curve.curve, x, y, order) return klass.from_public_point(point, curve, hashfunc) @@ -65,13 +75,18 @@ def from_der(klass, string): if empty != b(""): raise der.UnexpectedDER("trailing junk after DER pubkey objects: %s" % binascii.hexlify(empty)) - assert oid_pk == oid_ecPublicKey, (oid_pk, oid_ecPublicKey) + if oid_pk != oid_ecPublicKey: + raise der.UnexpectedDER( + "Unexpected OID in encoding, received {0}, expected {1}" + .format(oid_pk, oid_ecPublicKey)) curve = find_curve(oid_curve) point_str, empty = der.remove_bitstring(point_str_bitstring) if empty != b(""): raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" % binascii.hexlify(empty)) - assert point_str.startswith(b("\x00\x04")) + if not point_str.startswith(b("\x00\x04")): + raise der.UnexpectedDER( + "Unsupported or invalid encoding of pubcli key") return klass.from_string(point_str[2:], curve) def to_string(self): @@ -137,7 +152,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1): self.default_hashfunc = hashfunc self.baselen = curve.baselen n = curve.order - assert 1 <= secexp < n + if not 1 <= secexp < n: + raise MalformedPointError( + "Invalid value for secexp, expected integer between 1 and {0}" + .format(n)) pubkey_point = curve.generator*secexp pubkey = ecdsa.Public_key(curve.generator, pubkey_point) pubkey.order = n @@ -149,7 +167,10 @@ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1): @classmethod def from_string(klass, string, curve=NIST192p, hashfunc=sha1): - assert len(string) == curve.baselen, (len(string), curve.baselen) + if len(string) != curve.baselen: + raise MalformedPointError( + "Invalid length of private key, received {0}, expected {1}" + .format(len(string), curve.baselen)) secexp = string_to_number(string) return klass.from_secret_exponent(secexp, curve, hashfunc) diff --git a/ecdsa/test_pyecdsa.py b/ecdsa/test_pyecdsa.py index 326cac4a..c750f5a3 100644 --- a/ecdsa/test_pyecdsa.py +++ b/ecdsa/test_pyecdsa.py @@ -1,6 +1,9 @@ from __future__ import with_statement, division -import unittest +try: + import unittest2 as unittest +except ImportError: + import unittest import os import time import shutil @@ -10,15 +13,17 @@ from .six import b, print_, binary_type from .keys import SigningKey, VerifyingKey -from .keys import BadSignatureError +from .keys import BadSignatureError, MalformedPointError, BadDigestError from . import util from .util import sigencode_der, sigencode_strings from .util import sigdecode_der, sigdecode_strings +from .util import encoded_oid_ecPublicKey, MalformedSignature from .curves import Curve, UnknownCurveError from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1 from .ellipticcurve import Point from . import der from . import rfc6979 +from . import ecdsa class SubprocessError(Exception): pass @@ -258,6 +263,47 @@ def order(self): return 123456789 pub2 = VerifyingKey.from_pem(pem) self.assertTruePubkeysEqual(pub1, pub2) + def test_vk_from_der_garbage_after_curve_oid(self): + type_oid_der = encoded_oid_ecPublicKey + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + \ + b('garbage') + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\x00\xff') + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(der.UnexpectedDER): + VerifyingKey.from_der(to_decode) + + def test_vk_from_der_invalid_key_type(self): + type_oid_der = der.encode_oid(*(1, 2, 3)) + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\x00\xff') + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(der.UnexpectedDER): + VerifyingKey.from_der(to_decode) + + def test_vk_from_der_garbage_after_point_string(self): + type_oid_der = encoded_oid_ecPublicKey + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\x00\xff') + b('garbage') + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(der.UnexpectedDER): + VerifyingKey.from_der(to_decode) + + def test_vk_from_der_invalid_bitstring(self): + type_oid_der = encoded_oid_ecPublicKey + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + enc_type_der = der.encode_sequence(type_oid_der, curve_oid_der) + point_der = der.encode_bitstring(b'\x08\xff') + to_decode = der.encode_sequence(enc_type_der, point_der) + + with self.assertRaises(der.UnexpectedDER): + VerifyingKey.from_der(to_decode) + def test_signature_strings(self): priv1 = SigningKey.generate() pub1 = priv1.get_verifying_key() @@ -281,6 +327,86 @@ def test_signature_strings(self): self.assertEqual(type(sig_der), binary_type) self.assertTrue(pub1.verify(sig_der, data, sigdecode=sigdecode_der)) + def test_sig_decode_strings_with_invalid_count(self): + with self.assertRaises(MalformedSignature): + sigdecode_strings([b('one'), b('two'), b('three')], 0xff) + + def test_sig_decode_strings_with_wrong_r_len(self): + with self.assertRaises(MalformedSignature): + sigdecode_strings([b('one'), b('two')], 0xff) + + def test_sig_decode_strings_with_wrong_s_len(self): + with self.assertRaises(MalformedSignature): + sigdecode_strings([b('\xa0'), b('\xb0\xff')], 0xff) + + def test_verify_with_too_long_input(self): + sk = SigningKey.generate() + vk = sk.verifying_key + + with self.assertRaises(BadDigestError): + vk.verify_digest(None, b('\x00') * 128) + + def test_sk_from_secret_exponent_with_wrong_sec_exponent(self): + with self.assertRaises(MalformedPointError): + SigningKey.from_secret_exponent(0) + + def test_sk_from_string_with_wrong_len_string(self): + with self.assertRaises(MalformedPointError): + SigningKey.from_string(b('\x01')) + + def test_sk_from_der_with_junk_after_sequence(self): + ver_der = der.encode_integer(1) + to_decode = der.encode_sequence(ver_der) + b('garbage') + + with self.assertRaises(der.UnexpectedDER): + SigningKey.from_der(to_decode) + + def test_sk_from_der_with_wrong_version(self): + ver_der = der.encode_integer(0) + to_decode = der.encode_sequence(ver_der) + + with self.assertRaises(der.UnexpectedDER): + SigningKey.from_der(to_decode) + + def test_sk_from_der_invalid_const_tag(self): + ver_der = der.encode_integer(1) + privkey_der = der.encode_octet_string(b('\x00\xff')) + curve_oid_der = der.encode_oid(*(1, 2, 3)) + const_der = der.encode_constructed(1, curve_oid_der) + to_decode = der.encode_sequence(ver_der, privkey_der, const_der, + curve_oid_der) + + with self.assertRaises(der.UnexpectedDER): + SigningKey.from_der(to_decode) + + def test_sk_from_der_garbage_after_privkey_oid(self): + ver_der = der.encode_integer(1) + privkey_der = der.encode_octet_string(b('\x00\xff')) + curve_oid_der = der.encode_oid(*(1, 2, 3)) + b('garbage') + const_der = der.encode_constructed(0, curve_oid_der) + to_decode = der.encode_sequence(ver_der, privkey_der, const_der, + curve_oid_der) + + with self.assertRaises(der.UnexpectedDER): + SigningKey.from_der(to_decode) + + def test_sk_from_der_with_short_privkey(self): + ver_der = der.encode_integer(1) + privkey_der = der.encode_octet_string(b('\x00\xff')) + curve_oid_der = der.encode_oid(*(1, 2, 840, 10045, 3, 1, 1)) + const_der = der.encode_constructed(0, curve_oid_der) + to_decode = der.encode_sequence(ver_der, privkey_der, const_der, + curve_oid_der) + + sk = SigningKey.from_der(to_decode) + self.assertEqual(sk.privkey.secret_multiplier, 255) + + def test_sign_with_too_long_hash(self): + sk = SigningKey.from_secret_exponent(12) + + with self.assertRaises(BadDigestError): + sk.sign_digest(b('\xff') * 64) + def test_hashfunc(self): sk = SigningKey.generate(curve=NIST256p, hashfunc=sha256) data = b("security level is 128 bits") @@ -299,6 +425,49 @@ def test_hashfunc(self): curve=NIST256p) self.assertTrue(vk3.verify(sig, data, hashfunc=sha256)) + def test_decoding_with_malformed_uncompressed(self): + enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3' + '\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4' + 'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*') + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(b('\x02') + enc) + + def test_decoding_with_point_not_on_curve(self): + enc = b('\x0c\xe0\x1d\xe0d\x1c\x8eS\x8a\xc0\x9eK\xa8x !\xd5\xc2\xc3' + '\xfd\xc8\xa0c\xff\xfb\x02\xb9\xc4\x84)\x1a\x0f\x8b\x87\xa4' + 'z\x8a#\xb5\x97\xecO\xb6\xa0HQ\x89*') + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(enc[:47] + b('\x00')) + + def test_decoding_with_point_at_infinity(self): + # decoding it is unsupported, as it's not necessary to encode it + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(b('\x00')) + + def test_from_string_with_invalid_curve_too_short_ver_key_len(self): + # both verifying_key_length and baselen are calculated internally + # by the Curve constructor, but since we depend on them verify + # that inconsistent values are detected + curve = Curve("test", ecdsa.curve_192, ecdsa.generator_192, (1, 2)) + curve.verifying_key_length = 16 + curve.baselen = 32 + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(b('\x00')*16, curve) + + def test_from_string_with_invalid_curve_too_long_ver_key_len(self): + # both verifying_key_length and baselen are calculated internally + # by the Curve constructor, but since we depend on them verify + # that inconsistent values are detected + curve = Curve("test", ecdsa.curve_192, ecdsa.generator_192, (1, 2)) + curve.verifying_key_length = 16 + curve.baselen = 16 + + with self.assertRaises(MalformedPointError): + VerifyingKey.from_string(b('\x00')*16, curve) + class OpenSSL(unittest.TestCase): # test interoperability with OpenSSL tools. Note that openssl's ECDSA