Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Easier extending/replacing of key algorithms #42

Merged
merged 4 commits into from
Mar 5, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions jose/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib

class ALGORITHMS(object):

class Algorithms(object):
NONE = 'none'
HS256 = 'HS256'
HS384 = 'HS384'
Expand All @@ -12,13 +13,13 @@ class ALGORITHMS(object):
ES384 = 'ES384'
ES512 = 'ES512'

HMAC = (HS256, HS384, HS512)
RSA = (RS256, RS384, RS512)
EC = (ES256, ES384, ES512)
HMAC = set([HS256, HS384, HS512])
RSA = set([RS256, RS384, RS512])
EC = set([ES256, ES384, ES512])

SUPPORTED = HMAC + RSA + EC
SUPPORTED = HMAC.union(RSA).union(EC)

ALL = SUPPORTED + (NONE, )
ALL = SUPPORTED.union([NONE])

HASHES = {
HS256: hashlib.sha256,
Expand All @@ -31,3 +32,16 @@ class ALGORITHMS(object):
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}

KEYS = {}

def register_key(self, algorithm, key_class):
from jose.jwk import Key
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still has a wonky dependency. Can you move it over as well?

if not issubclass(key_class, Key):
raise TypeError("Key class not a subclass of jwk.Key")
self.KEYS[algorithm] = key_class
self.SUPPORTED.add(algorithm)
return True


ALGORITHMS = Algorithms()
56 changes: 24 additions & 32 deletions jose/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def base64_to_long(data):
return int_arr_to_long(struct.unpack('%sB' % len(_d), _d))


def get_key(algorithm):
if algorithm in ALGORITHMS.KEYS:
return ALGORITHMS.KEYS[algorithm]
elif algorithm in ALGORITHMS.HMAC:
return HMACKey
elif algorithm in ALGORITHMS.RSA:
return RSAKey
elif algorithm in ALGORITHMS.EC:
return ECKey
return None


def construct(key_data, algorithm=None):
"""
Construct a Key object for the given algorithm with the given
Expand All @@ -60,14 +72,10 @@ def construct(key_data, algorithm=None):
if not algorithm:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)

if algorithm in ALGORITHMS.HMAC:
return HMACKey(key_data, algorithm)

if algorithm in ALGORITHMS.RSA:
return RSAKey(key_data, algorithm)

if algorithm in ALGORITHMS.EC:
return ECKey(key_data, algorithm)
key_class = get_key(algorithm)
if not key_class:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)
return key_class(key_data, algorithm)


def get_algorithm_object(algorithm):
Expand All @@ -91,11 +99,8 @@ class Key(object):
"""
A simple interface for implementing JWK keys.
"""
prepared_key = None
hash_alg = None

def _process_jwk(self, jwk_dict):
raise NotImplementedError()
def __init__(self, key, algorithm):
pass

def sign(self, msg):
raise NotImplementedError()
Expand All @@ -112,13 +117,9 @@ class HMACKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.HMAC

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.HMAC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -174,14 +175,10 @@ class RSAKey(Key):
SHA256 = Crypto.Hash.SHA256
SHA384 = Crypto.Hash.SHA384
SHA512 = Crypto.Hash.SHA512
valid_hash_algs = ALGORITHMS.RSA

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):

if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.RSA:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -242,7 +239,7 @@ def verify(self, msg, sig):
try:
return PKCS1_v1_5.new(self.prepared_key).verify(self.hash_alg.new(msg), sig)
except Exception as e:
raise JWKError(e)
return False


class ECKey(Key):
Expand All @@ -257,24 +254,19 @@ class ECKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.EC

curve_map = {
CURVE_MAP = {
SHA256: ecdsa.curves.NIST256p,
SHA384: ecdsa.curves.NIST384p,
SHA512: ecdsa.curves.NIST521p,
}

prepared_key = None
hash_alg = None
curve = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.EC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

self.curve = self.curve_map.get(self.hash_alg)
self.curve = self.CURVE_MAP.get(self.hash_alg)

if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)):
self.prepared_key = key
Expand Down
39 changes: 23 additions & 16 deletions tests/algorithms/test_base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@

# from jose.jwk import Key
# from jose.exceptions import JOSEError
from jose.jwk import Key, HMACKey, RSAKey, ECKey
from jose.constants import ALGORITHMS

# import pytest
import pytest


# @pytest.fixture
# def alg():
# return Key()
@pytest.fixture
def alg():
return Key("key", "ALG")


# class TestBaseAlgorithm:
class TestBaseAlgorithm:

# def test_prepare_key_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.prepare_key('secret')
def test_sign_is_interface(self, alg):
with pytest.raises(NotImplementedError):
alg.sign('msg')

# def test_sign_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.sign('msg', 'secret')
def test_verify_is_interface(self, alg):
with pytest.raises(NotImplementedError):
alg.verify('msg', 'sig')

# def test_verify_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.verify('msg', 'secret', 'sig')

class TestAlgorithms:

def test_register_key(self):
assert ALGORITHMS.register_key("ALG", Key)
from jose.jwk import get_key
assert get_key("ALG") == Key

with pytest.raises(TypeError):
assert ALGORITHMS.register_key("ALG", object)
15 changes: 11 additions & 4 deletions tests/test_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ class TestJWK:

def test_interface(self):

key = jwk.Key()

with pytest.raises(NotImplementedError):
key._process_jwk(None)
key = jwk.Key("key", "ALG")

with pytest.raises(NotImplementedError):
key.sign('')
Expand Down Expand Up @@ -115,3 +112,13 @@ def test_construct_from_jwk_missing_alg(self):

with pytest.raises(JWKError):
key = jwk.construct(hmac_key)

with pytest.raises(JWKError):
key = jwk.construct("key", algorithm="NONEXISTENT")

def test_get_key(self):
assert jwk.get_key("HS256") == jwk.HMACKey
assert jwk.get_key("RS256") == jwk.RSAKey
assert jwk.get_key("ES256") == jwk.ECKey

assert jwk.get_key("NONEXISTENT") == None