From 98e0bea0c4b23a127ba9a1077bfe31198b23acf1 Mon Sep 17 00:00:00 2001 From: Marko Mrdjenovic Date: Tue, 10 Jan 2017 16:27:55 +0100 Subject: [PATCH 1/4] make algorithms extendable --- jose/constants.py | 32 ++++++++++++++++++++++++++------ jose/jwk.py | 21 +++++++-------------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/jose/constants.py b/jose/constants.py index 01a2fe8c..1536d47a 100644 --- a/jose/constants.py +++ b/jose/constants.py @@ -1,6 +1,7 @@ import hashlib -class ALGORITHMS(object): + +class Algorithms(object): NONE = 'none' HS256 = 'HS256' HS384 = 'HS384' @@ -12,13 +13,13 @@ class ALGORITHMS(object): ES384 = 'ES384' ES512 = 'ES512' - HMAC = (HS256, HS384, HS512) - RSA = (RS256, RS384, RS512) - EC = (ES256, ES384, ES512) + HMAC = set([HS256, HS384, HS512]) + RSA = set([RS256, RS384, RS512]) + EC = set([ES256, ES384, ES512]) - SUPPORTED = HMAC + RSA + EC + SUPPORTED = HMAC.union(RSA).union(EC) - ALL = SUPPORTED + (NONE, ) + ALL = SUPPORTED.union([NONE]) HASHES = { HS256: hashlib.sha256, @@ -31,3 +32,22 @@ class ALGORITHMS(object): ES384: hashlib.sha384, ES512: hashlib.sha512, } + + KEYS = {} + + def get_key(self, algorithm): + from jose.jwk import HMACKey, RSAKey, ECKey + if algorithm in self.KEYS: + return self.KEYS[algorithm] + elif algorithm in self.HMAC: + return HMACKey + elif algorithm in self.RSA: + return RSAKey + elif algorithm in self.EC: + return ECKey + + def register_key(self, algorithm, key_class): + self.KEYS[algorithm] = key_class + self.SUPPORTED.add(algorithm) + +ALGORITHMS = Algorithms() diff --git a/jose/jwk.py b/jose/jwk.py index 5edfc981..9b74d676 100644 --- a/jose/jwk.py +++ b/jose/jwk.py @@ -60,14 +60,10 @@ def construct(key_data, algorithm=None): if not algorithm: raise JWKError('Unable to find a algorithm for key: %s' % key_data) - if algorithm in ALGORITHMS.HMAC: - return HMACKey(key_data, algorithm) - - if algorithm in ALGORITHMS.RSA: - return RSAKey(key_data, algorithm) - - if algorithm in ALGORITHMS.EC: - return ECKey(key_data, algorithm) + key_class = ALGORITHMS.get_key(algorithm) + if not key_class: + raise JWKError('Unable to find a algorithm for key: %s' % key_data) + return key_class(key_data, algorithm) def get_algorithm_object(algorithm): @@ -112,13 +108,12 @@ class HMACKey(Key): SHA256 = hashlib.sha256 SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - valid_hash_algs = ALGORITHMS.HMAC prepared_key = None hash_alg = None def __init__(self, key, algorithm): - if algorithm not in self.valid_hash_algs: + if algorithm not in ALGORITHMS.HMAC: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) self.hash_alg = get_algorithm_object(algorithm) @@ -174,14 +169,13 @@ class RSAKey(Key): SHA256 = Crypto.Hash.SHA256 SHA384 = Crypto.Hash.SHA384 SHA512 = Crypto.Hash.SHA512 - valid_hash_algs = ALGORITHMS.RSA prepared_key = None hash_alg = None def __init__(self, key, algorithm): - if algorithm not in self.valid_hash_algs: + if algorithm not in ALGORITHMS.RSA: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) self.hash_alg = get_algorithm_object(algorithm) @@ -257,7 +251,6 @@ class ECKey(Key): SHA256 = hashlib.sha256 SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - valid_hash_algs = ALGORITHMS.EC curve_map = { SHA256: ecdsa.curves.NIST256p, @@ -270,7 +263,7 @@ class ECKey(Key): curve = None def __init__(self, key, algorithm): - if algorithm not in self.valid_hash_algs: + if algorithm not in ALGORITHMS.EC: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) self.hash_alg = get_algorithm_object(algorithm) From 072cf8be770841e0f3b6fb23438766989445668d Mon Sep 17 00:00:00 2001 From: Marko Mrdjenovic Date: Tue, 10 Jan 2017 17:31:55 +0100 Subject: [PATCH 2/4] cleaned jwk, added class checking on key, added tests --- jose/constants.py | 10 ++++++-- jose/jwk.py | 23 ++++-------------- tests/algorithms/test_base.py | 44 ++++++++++++++++++++++------------- tests/test_jwk.py | 8 +++---- 4 files changed, 45 insertions(+), 40 deletions(-) diff --git a/jose/constants.py b/jose/constants.py index 1536d47a..70f2836d 100644 --- a/jose/constants.py +++ b/jose/constants.py @@ -45,9 +45,15 @@ def get_key(self, algorithm): return RSAKey elif algorithm in self.EC: return ECKey + return None def register_key(self, algorithm, key_class): - self.KEYS[algorithm] = key_class - self.SUPPORTED.add(algorithm) + from jose.jwk import Key + if issubclass(key_class, Key): + self.KEYS[algorithm] = key_class + self.SUPPORTED.add(algorithm) + return True + else: + return False ALGORITHMS = Algorithms() diff --git a/jose/jwk.py b/jose/jwk.py index 9b74d676..e42f41a9 100644 --- a/jose/jwk.py +++ b/jose/jwk.py @@ -87,11 +87,8 @@ class Key(object): """ A simple interface for implementing JWK keys. """ - prepared_key = None - hash_alg = None - - def _process_jwk(self, jwk_dict): - raise NotImplementedError() + def __init__(self, key, algorithm): + pass def sign(self, msg): raise NotImplementedError() @@ -109,9 +106,6 @@ class HMACKey(Key): SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - prepared_key = None - hash_alg = None - def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.HMAC: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) @@ -170,9 +164,6 @@ class RSAKey(Key): SHA384 = Crypto.Hash.SHA384 SHA512 = Crypto.Hash.SHA512 - prepared_key = None - hash_alg = None - def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.RSA: @@ -236,7 +227,7 @@ def verify(self, msg, sig): try: return PKCS1_v1_5.new(self.prepared_key).verify(self.hash_alg.new(msg), sig) except Exception as e: - raise JWKError(e) + return False class ECKey(Key): @@ -252,22 +243,18 @@ class ECKey(Key): SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 - curve_map = { + CURVE_MAP = { SHA256: ecdsa.curves.NIST256p, SHA384: ecdsa.curves.NIST384p, SHA512: ecdsa.curves.NIST521p, } - prepared_key = None - hash_alg = None - curve = None - def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.EC: raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) self.hash_alg = get_algorithm_object(algorithm) - self.curve = self.curve_map.get(self.hash_alg) + self.curve = self.CURVE_MAP.get(self.hash_alg) if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)): self.prepared_key = key diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index 8d3a7f02..aeb737f2 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -1,25 +1,37 @@ -# from jose.jwk import Key -# from jose.exceptions import JOSEError +from jose.jwk import Key, HMACKey, RSAKey, ECKey +from jose.constants import ALGORITHMS -# import pytest +import pytest -# @pytest.fixture -# def alg(): -# return Key() +@pytest.fixture +def alg(): + return Key("key", "ALG") -# class TestBaseAlgorithm: +class TestBaseAlgorithm: -# def test_prepare_key_is_interface(self, alg): -# with pytest.raises(JOSEError): -# alg.prepare_key('secret') + def test_sign_is_interface(self, alg): + with pytest.raises(NotImplementedError): + alg.sign('msg') -# def test_sign_is_interface(self, alg): -# with pytest.raises(JOSEError): -# alg.sign('msg', 'secret') + def test_verify_is_interface(self, alg): + with pytest.raises(NotImplementedError): + alg.verify('msg', 'sig') -# def test_verify_is_interface(self, alg): -# with pytest.raises(JOSEError): -# alg.verify('msg', 'secret', 'sig') + +class TestAlgorithms: + + def test_get_key(self): + assert ALGORITHMS.get_key("HS256") == HMACKey + assert ALGORITHMS.get_key("RS256") == RSAKey + assert ALGORITHMS.get_key("ES256") == ECKey + + assert ALGORITHMS.get_key("NONEXISTENT") == None + + def test_register_key(self): + assert ALGORITHMS.register_key("ALG", Key) == True + assert ALGORITHMS.get_key("ALG") == Key + + assert ALGORITHMS.register_key("ALG", object) == False diff --git a/tests/test_jwk.py b/tests/test_jwk.py index d29e3213..6fe64c3c 100644 --- a/tests/test_jwk.py +++ b/tests/test_jwk.py @@ -35,10 +35,7 @@ class TestJWK: def test_interface(self): - key = jwk.Key() - - with pytest.raises(NotImplementedError): - key._process_jwk(None) + key = jwk.Key("key", "ALG") with pytest.raises(NotImplementedError): key.sign('') @@ -115,3 +112,6 @@ def test_construct_from_jwk_missing_alg(self): with pytest.raises(JWKError): key = jwk.construct(hmac_key) + + with pytest.raises(JWKError): + key = jwk.construct("key", algorithm="NONEXISTENT") From 163b01b8fef83e098821f9c99e07c4927e379bd6 Mon Sep 17 00:00:00 2001 From: Marko Mrdjenovic Date: Sun, 29 Jan 2017 17:30:05 +0100 Subject: [PATCH 3/4] move get_key to jwk, return typeerror in register_key; updated tests --- jose/constants.py | 24 ++++++------------------ jose/jwk.py | 14 +++++++++++++- tests/algorithms/test_base.py | 17 ++++++----------- tests/test_jwk.py | 7 +++++++ 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/jose/constants.py b/jose/constants.py index 70f2836d..30b0b42d 100644 --- a/jose/constants.py +++ b/jose/constants.py @@ -35,25 +35,13 @@ class Algorithms(object): KEYS = {} - def get_key(self, algorithm): - from jose.jwk import HMACKey, RSAKey, ECKey - if algorithm in self.KEYS: - return self.KEYS[algorithm] - elif algorithm in self.HMAC: - return HMACKey - elif algorithm in self.RSA: - return RSAKey - elif algorithm in self.EC: - return ECKey - return None - def register_key(self, algorithm, key_class): from jose.jwk import Key - if issubclass(key_class, Key): - self.KEYS[algorithm] = key_class - self.SUPPORTED.add(algorithm) - return True - else: - return False + if not issubclass(key_class, Key): + raise TypeError("Key class not a subclass of jwk.Key") + self.KEYS[algorithm] = key_class + self.SUPPORTED.add(algorithm) + return True + ALGORITHMS = Algorithms() diff --git a/jose/jwk.py b/jose/jwk.py index e42f41a9..9df8df4c 100644 --- a/jose/jwk.py +++ b/jose/jwk.py @@ -47,6 +47,18 @@ def base64_to_long(data): return int_arr_to_long(struct.unpack('%sB' % len(_d), _d)) +def get_key(algorithm): + if algorithm in ALGORITHMS.KEYS: + return ALGORITHMS.KEYS[algorithm] + elif algorithm in ALGORITHMS.HMAC: + return HMACKey + elif algorithm in ALGORITHMS.RSA: + return RSAKey + elif algorithm in ALGORITHMS.EC: + return ECKey + return None + + def construct(key_data, algorithm=None): """ Construct a Key object for the given algorithm with the given @@ -60,7 +72,7 @@ def construct(key_data, algorithm=None): if not algorithm: raise JWKError('Unable to find a algorithm for key: %s' % key_data) - key_class = ALGORITHMS.get_key(algorithm) + key_class = get_key(algorithm) if not key_class: raise JWKError('Unable to find a algorithm for key: %s' % key_data) return key_class(key_data, algorithm) diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index aeb737f2..f3d9767e 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -23,15 +23,10 @@ def test_verify_is_interface(self, alg): class TestAlgorithms: - def test_get_key(self): - assert ALGORITHMS.get_key("HS256") == HMACKey - assert ALGORITHMS.get_key("RS256") == RSAKey - assert ALGORITHMS.get_key("ES256") == ECKey - - assert ALGORITHMS.get_key("NONEXISTENT") == None - def test_register_key(self): - assert ALGORITHMS.register_key("ALG", Key) == True - assert ALGORITHMS.get_key("ALG") == Key - - assert ALGORITHMS.register_key("ALG", object) == False + assert ALGORITHMS.register_key("ALG", Key) + from jose.jwk import get_key + assert get_key("ALG") == Key + + with pytest.raises(TypeError): + assert ALGORITHMS.register_key("ALG", object) diff --git a/tests/test_jwk.py b/tests/test_jwk.py index 6fe64c3c..0d5a363d 100644 --- a/tests/test_jwk.py +++ b/tests/test_jwk.py @@ -115,3 +115,10 @@ def test_construct_from_jwk_missing_alg(self): with pytest.raises(JWKError): key = jwk.construct("key", algorithm="NONEXISTENT") + + def test_get_key(self): + assert jwk.get_key("HS256") == jwk.HMACKey + assert jwk.get_key("RS256") == jwk.RSAKey + assert jwk.get_key("ES256") == jwk.ECKey + + assert jwk.get_key("NONEXISTENT") == None From fae83ff9333358e2a233d2eac9f7ef5f3ba1269e Mon Sep 17 00:00:00 2001 From: Marko Mrdjenovic Date: Mon, 30 Jan 2017 09:40:26 +0100 Subject: [PATCH 4/4] moved register to jwk --- jose/constants.py | 8 -------- jose/jwk.py | 8 ++++++++ tests/algorithms/test_base.py | 14 +------------- tests/test_jwk.py | 9 +++++++-- 4 files changed, 16 insertions(+), 23 deletions(-) diff --git a/jose/constants.py b/jose/constants.py index 30b0b42d..c446fd8d 100644 --- a/jose/constants.py +++ b/jose/constants.py @@ -35,13 +35,5 @@ class Algorithms(object): KEYS = {} - def register_key(self, algorithm, key_class): - from jose.jwk import Key - if not issubclass(key_class, Key): - raise TypeError("Key class not a subclass of jwk.Key") - self.KEYS[algorithm] = key_class - self.SUPPORTED.add(algorithm) - return True - ALGORITHMS = Algorithms() diff --git a/jose/jwk.py b/jose/jwk.py index 9df8df4c..fbc19071 100644 --- a/jose/jwk.py +++ b/jose/jwk.py @@ -59,6 +59,14 @@ def get_key(algorithm): return None +def register_key(algorithm, key_class): + if not issubclass(key_class, Key): + raise TypeError("Key class not a subclass of jwk.Key") + ALGORITHMS.KEYS[algorithm] = key_class + ALGORITHMS.SUPPORTED.add(algorithm) + return True + + def construct(key_data, algorithm=None): """ Construct a Key object for the given algorithm with the given diff --git a/tests/algorithms/test_base.py b/tests/algorithms/test_base.py index f3d9767e..1d8b0f33 100644 --- a/tests/algorithms/test_base.py +++ b/tests/algorithms/test_base.py @@ -1,6 +1,4 @@ - -from jose.jwk import Key, HMACKey, RSAKey, ECKey -from jose.constants import ALGORITHMS +from jose.jwk import Key import pytest @@ -20,13 +18,3 @@ def test_verify_is_interface(self, alg): with pytest.raises(NotImplementedError): alg.verify('msg', 'sig') - -class TestAlgorithms: - - def test_register_key(self): - assert ALGORITHMS.register_key("ALG", Key) - from jose.jwk import get_key - assert get_key("ALG") == Key - - with pytest.raises(TypeError): - assert ALGORITHMS.register_key("ALG", object) diff --git a/tests/test_jwk.py b/tests/test_jwk.py index 0d5a363d..195b98b1 100644 --- a/tests/test_jwk.py +++ b/tests/test_jwk.py @@ -1,10 +1,8 @@ - from jose import jwk from jose.exceptions import JWKError import pytest - hmac_key = { "kty": "oct", "kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037", @@ -122,3 +120,10 @@ def test_get_key(self): assert jwk.get_key("ES256") == jwk.ECKey assert jwk.get_key("NONEXISTENT") == None + + def test_register_key(self): + assert jwk.register_key("ALG", jwk.Key) + assert jwk.get_key("ALG") == jwk.Key + + with pytest.raises(TypeError): + assert jwk.register_key("ALG", object)