diff --git a/jose/constants.py b/jose/constants.py index 01a2fe8c..c446fd8d 100644 --- a/jose/constants.py +++ b/jose/constants.py @@ -1,6 +1,7 @@ import hashlib -class ALGORITHMS(object): + +class Algorithms(object): NONE = 'none' HS256 = 'HS256' HS384 = 'HS384' @@ -12,13 +13,13 @@ class ALGORITHMS(object): ES384 = 'ES384' ES512 = 'ES512' - HMAC = (HS256, HS384, HS512) - RSA = (RS256, RS384, RS512) - EC = (ES256, ES384, ES512) + HMAC = set([HS256, HS384, HS512]) + RSA = set([RS256, RS384, RS512]) + EC = set([ES256, ES384, ES512]) - SUPPORTED = HMAC + RSA + EC + SUPPORTED = HMAC.union(RSA).union(EC) - ALL = SUPPORTED + (NONE, ) + ALL = SUPPORTED.union([NONE]) HASHES = { HS256: hashlib.sha256, @@ -31,3 +32,8 @@ class ALGORITHMS(object): ES384: hashlib.sha384, ES512: hashlib.sha512, } + + KEYS = {} + + +ALGORITHMS = Algorithms() diff --git a/jose/jwk.py b/jose/jwk.py index 5edfc981..fbc19071 100644 --- a/jose/jwk.py +++ b/jose/jwk.py @@ -47,6 +47,26 @@ def base64_to_long(data): return int_arr_to_long(struct.unpack('%sB' % len(_d), _d)) +def get_key(algorithm): + if algorithm in ALGORITHMS.KEYS: + return ALGORITHMS.KEYS[algorithm] + elif algorithm in ALGORITHMS.HMAC: + return HMACKey + elif algorithm in ALGORITHMS.RSA: + return RSAKey + elif algorithm in ALGORITHMS.EC: + return ECKey + return None + + +def register_key(algorithm, key_class): + if not issubclass(key_class, Key): + raise TypeError("Key class not a subclass of jwk.Key") + ALGORITHMS.KEYS[algorithm] = key_class + ALGORITHMS.SUPPORTED.add(algorithm) + return True + + def construct(key_data, algorithm=None): """ Construct a Key object for the given algorithm with the given @@ -60,14 +80,10 @@ def construct(key_data, algorithm=None): if not algorithm: raise JWKError('Unable to find a algorithm for key: %s' % key_data) - if algorithm in ALGORITHMS.HMAC: - return HMACKey(key_data, algorithm) - - if algorithm in ALGORITHMS.RSA: - return RSAKey(key_data, algorithm) - - if algorithm in ALGORITHMS.EC: - return ECKey(key_data, algorithm) + key_class = get_key(algorithm) + if not key_class: + raise JWKError('Unable to find a algorithm for key: %s' % key_data) + return key_class(key_data, algorithm) def get_algorithm_object(algorithm): @@ -91,11 +107,8 @@ class Key(object): """ A simple interface for implementing JWK keys. """ - prepared_key = None - hash_alg = None - - def _process_jwk(self, jwk_dict): - raise NotImplementedError() + def __init__(self, key, algorithm): + pass def sign(self, msg): raise NotImplementedError() @@ -112,13 +125,9 @@ class HMACKey(Key): SHA256 = hashlib.sha256 SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - valid_hash_algs = ALGORITHMS.HMAC - - prepared_key = None - hash_alg = None def __init__(self, key, algorithm): - if algorithm not in self.valid_hash_algs: + if algorithm not in ALGORITHMS.HMAC: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) self.hash_alg = get_algorithm_object(algorithm) @@ -174,14 +183,10 @@ class RSAKey(Key): SHA256 = Crypto.Hash.SHA256 SHA384 = Crypto.Hash.SHA384 SHA512 = Crypto.Hash.SHA512 - valid_hash_algs = ALGORITHMS.RSA - - prepared_key = None - hash_alg = None def __init__(self, key, algorithm): - if algorithm not in self.valid_hash_algs: + if algorithm not in ALGORITHMS.RSA: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) self.hash_alg = get_algorithm_object(algorithm) @@ -242,7 +247,7 @@ def verify(self, msg, sig): try: return PKCS1_v1_5.new(self.prepared_key).verify(self.hash_alg.new(msg), sig) except Exception as e: - raise JWKError(e) + return False class ECKey(Key): @@ -257,24 +262,19 @@ class ECKey(Key): SHA256 = hashlib.sha256 SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - valid_hash_algs = ALGORITHMS.EC - curve_map = { + CURVE_MAP = { SHA256: ecdsa.curves.NIST256p, SHA384: ecdsa.curves.NIST384p, SHA512: ecdsa.curves.NIST521p, } - prepared_key = None - hash_alg = None - curve = None - def __init__(self, key, algorithm): - if algorithm not in self.valid_hash_algs: + if algorithm not in ALGORITHMS.EC: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) self.hash_alg = get_algorithm_object(algorithm) - self.curve = self.curve_map.get(self.hash_alg) + self.curve = self.CURVE_MAP.get(self.hash_alg) if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)): self.prepared_key = key diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index 8d3a7f02..1d8b0f33 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -1,25 +1,20 @@ +from jose.jwk import Key -# from jose.jwk import Key -# from jose.exceptions import JOSEError +import pytest -# import pytest +@pytest.fixture +def alg(): + return Key("key", "ALG") -# @pytest.fixture -# def alg(): -# return Key() +class TestBaseAlgorithm: -# class TestBaseAlgorithm: + def test_sign_is_interface(self, alg): + with pytest.raises(NotImplementedError): + alg.sign('msg') -# def test_prepare_key_is_interface(self, alg): -# with pytest.raises(JOSEError): -# alg.prepare_key('secret') + def test_verify_is_interface(self, alg): + with pytest.raises(NotImplementedError): + alg.verify('msg', 'sig') -# def test_sign_is_interface(self, alg): -# with pytest.raises(JOSEError): -# alg.sign('msg', 'secret') - -# def test_verify_is_interface(self, alg): -# with pytest.raises(JOSEError): -# alg.verify('msg', 'secret', 'sig') diff --git a/tests/test_jwk.py b/tests/test_jwk.py index d29e3213..195b98b1 100644 --- a/tests/test_jwk.py +++ b/tests/test_jwk.py @@ -1,10 +1,8 @@ - from jose import jwk from jose.exceptions import JWKError import pytest - hmac_key = { "kty": "oct", "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", @@ -35,10 +33,7 @@ class TestJWK: def test_interface(self): - key = jwk.Key() - - with pytest.raises(NotImplementedError): - key._process_jwk(None) + key = jwk.Key("key", "ALG") with pytest.raises(NotImplementedError): key.sign('') @@ -115,3 +110,20 @@ def test_construct_from_jwk_missing_alg(self): with pytest.raises(JWKError): key = jwk.construct(hmac_key) + + with pytest.raises(JWKError): + key = jwk.construct("key", algorithm="NONEXISTENT") + + def test_get_key(self): + assert jwk.get_key("HS256") == jwk.HMACKey + assert jwk.get_key("RS256") == jwk.RSAKey + assert jwk.get_key("ES256") == jwk.ECKey + + assert jwk.get_key("NONEXISTENT") == None + + def test_register_key(self): + assert jwk.register_key("ALG", jwk.Key) + assert jwk.get_key("ALG") == jwk.Key + + with pytest.raises(TypeError): + assert jwk.register_key("ALG", object)