diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 9c1a7e80..5cac550c 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,9 +1,13 @@ import hashlib import hmac +import json from .compat import constant_time_compare, string_types, text_type from .exceptions import InvalidKeyError -from .utils import der_to_raw_signature, raw_to_der_signature +from .utils import ( + base64url_decode, base64url_encode, der_to_raw_signature, + from_base64url_uint, raw_to_der_signature, to_base64url_uint +) try: from cryptography.hazmat.primitives import hashes @@ -11,7 +15,8 @@ load_pem_private_key, load_pem_public_key, load_ssh_public_key ) from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPrivateKey, RSAPublicKey + RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers, + rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp ) from cryptography.hazmat.primitives.asymmetric.ec import ( EllipticCurvePrivateKey, EllipticCurvePublicKey @@ -77,6 +82,20 @@ def verify(self, msg, key, sig): """ raise NotImplementedError + @staticmethod + def to_jwk(key_obj): + """ + Serializes a given RSA key into a JWK + """ + raise NotImplementedError + + @staticmethod + def from_jwk(jwk): + """ + Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + """ + raise NotImplementedError + class NoneAlgorithm(Algorithm): """ @@ -131,6 +150,22 @@ def prepare_key(self, key): return key + @staticmethod + def to_jwk(key_obj): + return json.dumps({ + 'k': base64url_encode(key_obj), + 'typ': 'oct' + }) + + @staticmethod + def from_jwk(jwk): + obj = json.loads(jwk) + + if obj.get('kty') != 'oct': + raise InvalidKeyError('Not an HMAC key') + + return base64url_decode(obj['k']) + def sign(self, msg, key): return hmac.new(key, msg, self.hash_alg).digest() @@ -172,6 +207,101 @@ def prepare_key(self, key): return key + @staticmethod + def to_jwk(key_obj): + obj = None + + if getattr(key_obj, 'private_numbers', None): + # Private key + numbers = key_obj.private_numbers() + + obj = { + 'kty': 'RSA', + 'key_ops': ['sign'], + 'd': to_base64url_uint(numbers.d), + 'p': to_base64url_uint(numbers.p), + 'q': to_base64url_uint(numbers.q), + 'dp': to_base64url_uint(numbers.dmp1), + 'dq': to_base64url_uint(numbers.dmq1), + 'qi': to_base64url_uint(numbers.iqmp) + } + + elif getattr(key_obj, 'verifier', None): + # Public key + numbers = key_obj.public_numbers() + + obj = { + 'kty': 'RSA', + 'use': 'sig', + 'key_ops': ['verify'], + 'n': to_base64url_uint(numbers.n), + 'e': to_base64url_uint(numbers.e) + } + else: + raise InvalidKeyError('Not a public or private key') + + return json.dumps(obj) + + @staticmethod + def from_jwk(jwk): + obj = json.loads(jwk) + + if obj.get('kty') != 'RSA': + raise InvalidKeyError('Not an RSA key') + + if 'd' in obj and 'e' in obj and 'n' in obj: + # Private key + if 'oth' in obj: + raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported') + + other_props = ['p', 'q', 'dp', 'dq', 'qi'] + props_found = [True for prop in other_props if prop in obj] + any_props_found = any(props_found) + + if any_props_found and not all(props_found): + raise InvalidKeyError('RSA key must include all parameters if any are present besides d') + + public_numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + if any_props_found: + numbers = RSAPrivateNumbers( + d=from_base64url_uint(obj['d']), + p=from_base64url_uint(obj['p']), + q=from_base64url_uint(obj['q']), + dmp1=from_base64url_uint(obj['dp']), + dmq1=from_base64url_uint(obj['dq']), + iqmp=from_base64url_uint(obj['qi']), + public_numbers=public_numbers + ) + else: + p, q = rsa_recover_prime_factors( + public_numbers.n, public_numbers.d, public_numbers.e + ) + d = from_base64url_uint(obj['d']) + + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers + ) + + return numbers.private_key(default_backend()) + elif 'n' in obj and 'e' in obj: + # Public key + numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + return numbers.public_key(default_backend()) + else: + raise InvalidKeyError('Not a public or private key') + def sign(self, msg, key): signer = key.signer( padding.PKCS1v15(), diff --git a/jwt/compat.py b/jwt/compat.py index 8f6bfed1..73f90446 100644 --- a/jwt/compat.py +++ b/jwt/compat.py @@ -5,7 +5,7 @@ # flake8: noqa import sys import hmac - +import struct PY3 = sys.version_info[0] == 3 @@ -52,3 +52,24 @@ def constant_time_compare(val1, val2): result |= ord(x) ^ ord(y) return result == 0 + +# Use int.to_bytes if it exists (Python 3) +if getattr(int, 'to_bytes', None): + def bytes_from_int(val): + remaining = val + byte_length = 0 + + while remaining != 0: + remaining = remaining >> 8 + byte_length += 1 + + return val.to_bytes(byte_length, 'big', signed=False) +else: + def bytes_from_int(val): + buf = [] + while val: + val, remainder = divmod(val, 256) + buf.append(remainder) + + buf.reverse() + return struct.pack('%sB' % len(buf), *buf) diff --git a/jwt/utils.py b/jwt/utils.py index 637b8929..18311426 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,5 +1,8 @@ import base64 import binascii +import struct + +from .compat import bytes_from_int, text_type try: from cryptography.hazmat.primitives.asymmetric.utils import ( @@ -10,6 +13,9 @@ def base64url_decode(input): + if isinstance(input, text_type): + input = input.encode('ascii') + rem = len(input) % 4 if rem > 0: @@ -22,6 +28,28 @@ def base64url_encode(input): return base64.urlsafe_b64encode(input).replace(b'=', b'') +def to_base64url_uint(val): + if val < 0: + raise ValueError('Must be a positive integer') + + int_bytes = bytes_from_int(val) + + if len(int_bytes) == 0: + int_bytes = b'\x00' + + return base64url_encode(int_bytes) + + +def from_base64url_uint(val): + if isinstance(val, text_type): + val = val.encode('ascii') + + data = base64url_decode(val) + + buf = struct.unpack('%sB' % len(data), data) + return int(''.join(["%02x" % byte for byte in buf]), 16) + + def merge_dict(original, updates): if not updates: return original diff --git a/tests/keys/__init__.py b/tests/keys/__init__.py index fad09f57..47bce70b 100644 --- a/tests/keys/__init__.py +++ b/tests/keys/__init__.py @@ -20,10 +20,9 @@ def load_hmac_key(): return base64url_decode(ensure_bytes(keyobj['k'])) try: - from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.backends import default_backend - + from jwt.algorithms import RSAAlgorithm has_crypto = True except ImportError: has_crypto = False @@ -31,26 +30,11 @@ def load_hmac_key(): if has_crypto: def load_rsa_key(): with open(os.path.join(BASE_PATH, 'jwk_rsa_key.json'), 'r') as infile: - keyobj = json.load(infile) - - return rsa.RSAPrivateNumbers( - p=decode_value(keyobj['p']), - q=decode_value(keyobj['q']), - d=decode_value(keyobj['d']), - dmp1=decode_value(keyobj['dp']), - dmq1=decode_value(keyobj['dq']), - iqmp=decode_value(keyobj['qi']), - public_numbers=load_rsa_pub_key().public_numbers() - ).private_key(default_backend()) + return RSAAlgorithm.from_jwk(infile.read()) def load_rsa_pub_key(): with open(os.path.join(BASE_PATH, 'jwk_rsa_pub.json'), 'r') as infile: - keyobj = json.load(infile) - - return rsa.RSAPublicNumbers( - n=decode_value(keyobj['n']), - e=decode_value(keyobj['e']) - ).public_key(default_backend()) + return RSAAlgorithm.from_jwk(infile.read()) def load_ec_key(): with open(os.path.join(BASE_PATH, 'jwk_ec_key.json'), 'r') as infile: diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 45bb6d11..273338b4 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -36,6 +36,18 @@ def test_algorithm_should_throw_exception_if_verify_not_impl(self): with pytest.raises(NotImplementedError): algo.verify('message', 'key', 'signature') + def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self): + algo = Algorithm() + + with pytest.raises(NotImplementedError): + algo.from_jwk('value') + + def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self): + algo = Algorithm() + + with pytest.raises(NotImplementedError): + algo.to_jwk('value') + def test_none_algorithm_should_throw_exception_if_key_is_not_none(self): algo = NoneAlgorithm() @@ -84,6 +96,15 @@ def test_hmac_should_throw_exception_if_key_is_x509_cert(self): with open(key_path('testkey2_rsa.pub.pem'), 'r') as keyfile: algo.prepare_key(keyfile.read()) + def test_hmac_jwk_public_and_private_keys_should_parse_and_verify(self): + algo = HMACAlgorithm(HMACAlgorithm.SHA256) + + with open(key_path('jwk_hmac.json'), 'r') as keyfile: + key = algo.from_jwk(keyfile.read()) + + signature = algo.sign(b'Hello World!', key) + assert algo.verify(b'Hello World!', key, signature) + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_rsa_should_parse_pem_public_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -127,6 +148,19 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self): result = algo.verify(message, pub_key, sig) assert not result + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + + with open(key_path('jwk_rsa_key.json'), 'r') as keyfile: + priv_key = algo.from_jwk(keyfile.read()) + + signature = algo.sign(ensure_bytes('Hello World!'), priv_key) + assert algo.verify(ensure_bytes('Hello World!'), pub_key, signature) + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_ec_should_reject_non_string_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..c9397cc0 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,28 @@ +from jwt.utils import from_base64url_uint, to_base64url_uint + +import pytest + + +@pytest.mark.parametrize("inputval,expected", [ + (0, b'AA'), + (1, b'AQ'), + (255, b'_w'), + (65537, b'AQAB'), + (123456789, b'B1vNFQ'), + pytest.mark.xfail((-1, ''), raises=ValueError) +]) +def test_to_base64url_uint(inputval, expected): + actual = to_base64url_uint(inputval) + assert actual == expected + + +@pytest.mark.parametrize("inputval,expected", [ + (b'AA', 0), + (b'AQ', 1), + (b'_w', 255), + (b'AQAB', 65537), + (b'B1vNFQ', 123456789, ), +]) +def test_from_base64url_uint(inputval, expected): + actual = from_base64url_uint(inputval) + assert actual == expected