Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Easier extending/replacing of key algorithms #42

Merged
merged 4 commits into from
Mar 5, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions jose/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib

class ALGORITHMS(object):

class Algorithms(object):
NONE = 'none'
HS256 = 'HS256'
HS384 = 'HS384'
Expand All @@ -12,13 +13,13 @@ class ALGORITHMS(object):
ES384 = 'ES384'
ES512 = 'ES512'

HMAC = (HS256, HS384, HS512)
RSA = (RS256, RS384, RS512)
EC = (ES256, ES384, ES512)
HMAC = set([HS256, HS384, HS512])
RSA = set([RS256, RS384, RS512])
EC = set([ES256, ES384, ES512])

SUPPORTED = HMAC + RSA + EC
SUPPORTED = HMAC.union(RSA).union(EC)

ALL = SUPPORTED + (NONE, )
ALL = SUPPORTED.union([NONE])

HASHES = {
HS256: hashlib.sha256,
Expand All @@ -31,3 +32,8 @@ class ALGORITHMS(object):
ES384: hashlib.sha384,
ES512: hashlib.sha512,
}

KEYS = {}


ALGORITHMS = Algorithms()
64 changes: 32 additions & 32 deletions jose/jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ def base64_to_long(data):
return int_arr_to_long(struct.unpack('%sB' % len(_d), _d))


def get_key(algorithm):
if algorithm in ALGORITHMS.KEYS:
return ALGORITHMS.KEYS[algorithm]
elif algorithm in ALGORITHMS.HMAC:
return HMACKey
elif algorithm in ALGORITHMS.RSA:
return RSAKey
elif algorithm in ALGORITHMS.EC:
return ECKey
return None


def register_key(algorithm, key_class):
if not issubclass(key_class, Key):
raise TypeError("Key class not a subclass of jwk.Key")
ALGORITHMS.KEYS[algorithm] = key_class
ALGORITHMS.SUPPORTED.add(algorithm)
return True


def construct(key_data, algorithm=None):
"""
Construct a Key object for the given algorithm with the given
Expand All @@ -60,14 +80,10 @@ def construct(key_data, algorithm=None):
if not algorithm:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)

if algorithm in ALGORITHMS.HMAC:
return HMACKey(key_data, algorithm)

if algorithm in ALGORITHMS.RSA:
return RSAKey(key_data, algorithm)

if algorithm in ALGORITHMS.EC:
return ECKey(key_data, algorithm)
key_class = get_key(algorithm)
if not key_class:
raise JWKError('Unable to find a algorithm for key: %s' % key_data)
return key_class(key_data, algorithm)


def get_algorithm_object(algorithm):
Expand All @@ -91,11 +107,8 @@ class Key(object):
"""
A simple interface for implementing JWK keys.
"""
prepared_key = None
hash_alg = None

def _process_jwk(self, jwk_dict):
raise NotImplementedError()
def __init__(self, key, algorithm):
pass

def sign(self, msg):
raise NotImplementedError()
Expand All @@ -112,13 +125,9 @@ class HMACKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.HMAC

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.HMAC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -174,14 +183,10 @@ class RSAKey(Key):
SHA256 = Crypto.Hash.SHA256
SHA384 = Crypto.Hash.SHA384
SHA512 = Crypto.Hash.SHA512
valid_hash_algs = ALGORITHMS.RSA

prepared_key = None
hash_alg = None

def __init__(self, key, algorithm):

if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.RSA:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

Expand Down Expand Up @@ -242,7 +247,7 @@ def verify(self, msg, sig):
try:
return PKCS1_v1_5.new(self.prepared_key).verify(self.hash_alg.new(msg), sig)
except Exception as e:
raise JWKError(e)
return False


class ECKey(Key):
Expand All @@ -257,24 +262,19 @@ class ECKey(Key):
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
valid_hash_algs = ALGORITHMS.EC

curve_map = {
CURVE_MAP = {
SHA256: ecdsa.curves.NIST256p,
SHA384: ecdsa.curves.NIST384p,
SHA512: ecdsa.curves.NIST521p,
}

prepared_key = None
hash_alg = None
curve = None

def __init__(self, key, algorithm):
if algorithm not in self.valid_hash_algs:
if algorithm not in ALGORITHMS.EC:
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
self.hash_alg = get_algorithm_object(algorithm)

self.curve = self.curve_map.get(self.hash_alg)
self.curve = self.CURVE_MAP.get(self.hash_alg)

if isinstance(key, (ecdsa.SigningKey, ecdsa.VerifyingKey)):
self.prepared_key = key
Expand Down
29 changes: 12 additions & 17 deletions tests/algorithms/test_base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
from jose.jwk import Key

# from jose.jwk import Key
# from jose.exceptions import JOSEError
import pytest

# import pytest

@pytest.fixture
def alg():
return Key("key", "ALG")

# @pytest.fixture
# def alg():
# return Key()

class TestBaseAlgorithm:

# class TestBaseAlgorithm:
def test_sign_is_interface(self, alg):
with pytest.raises(NotImplementedError):
alg.sign('msg')

# def test_prepare_key_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.prepare_key('secret')
def test_verify_is_interface(self, alg):
with pytest.raises(NotImplementedError):
alg.verify('msg', 'sig')

# def test_sign_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.sign('msg', 'secret')

# def test_verify_is_interface(self, alg):
# with pytest.raises(JOSEError):
# alg.verify('msg', 'secret', 'sig')
24 changes: 18 additions & 6 deletions tests/test_jwk.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@

from jose import jwk
from jose.exceptions import JWKError

import pytest


hmac_key = {
"kty": "oct",
"kid": "018c0ae5-4d9b-471b-bfd6-eef314bc7037",
Expand Down Expand Up @@ -35,10 +33,7 @@ class TestJWK:

def test_interface(self):

key = jwk.Key()

with pytest.raises(NotImplementedError):
key._process_jwk(None)
key = jwk.Key("key", "ALG")

with pytest.raises(NotImplementedError):
key.sign('')
Expand Down Expand Up @@ -115,3 +110,20 @@ def test_construct_from_jwk_missing_alg(self):

with pytest.raises(JWKError):
key = jwk.construct(hmac_key)

with pytest.raises(JWKError):
key = jwk.construct("key", algorithm="NONEXISTENT")

def test_get_key(self):
assert jwk.get_key("HS256") == jwk.HMACKey
assert jwk.get_key("RS256") == jwk.RSAKey
assert jwk.get_key("ES256") == jwk.ECKey

assert jwk.get_key("NONEXISTENT") == None

def test_register_key(self):
assert jwk.register_key("ALG", jwk.Key)
assert jwk.get_key("ALG") == jwk.Key

with pytest.raises(TypeError):
assert jwk.register_key("ALG", object)