From 42b011410b28ea3e3920f40641a57c8893e2e04e Mon Sep 17 00:00:00 2001 From: Mark Adams Date: Thu, 5 May 2016 23:58:32 -0500 Subject: [PATCH] Add JWK support for HMAC and RSA keys - JWKs for RSA and HMAC can be encoded / decoded using the .to_jwk() and .from_jwk() methods on their respective jwt.algorithms instances - Replaced tests.utils ensure_unicode and ensure_bytes with jwt.utils versions --- jwt/algorithms.py | 158 ++++++++++++++++-- jwt/api_jws.py | 14 +- jwt/compat.py | 28 +++- jwt/utils.py | 46 ++++++ tests/contrib/test_algorithms.py | 36 +++-- tests/keys/__init__.py | 30 +--- tests/test_algorithms.py | 270 ++++++++++++++++++++++++++++--- tests/test_api_jws.py | 37 +++-- tests/test_compat.py | 9 +- tests/test_utils.py | 40 +++++ tests/utils.py | 16 -- 11 files changed, 554 insertions(+), 130 deletions(-) create mode 100644 tests/test_utils.py diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 51e8f160..3144f230 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,9 +1,15 @@ import hashlib import hmac +import json -from .compat import binary_type, constant_time_compare, is_string_type + +from .compat import constant_time_compare, string_types from .exceptions import InvalidKeyError -from .utils import der_to_raw_signature, raw_to_der_signature +from .utils import ( + base64url_decode, base64url_encode, der_to_raw_signature, + force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature, + to_base64url_uint +) try: from cryptography.hazmat.primitives import hashes @@ -11,7 +17,8 @@ load_pem_private_key, load_pem_public_key, load_ssh_public_key ) from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPrivateKey, RSAPublicKey + RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers, + rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp ) from cryptography.hazmat.primitives.asymmetric.ec import ( EllipticCurvePrivateKey, EllipticCurvePublicKey @@ -77,6 +84,20 @@ def verify(self, msg, key, sig): """ raise NotImplementedError + @staticmethod + def to_jwk(key_obj): + """ + Serializes a given RSA key into a JWK + """ + raise NotImplementedError + + @staticmethod + def from_jwk(jwk): + """ + Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + """ + raise NotImplementedError + class NoneAlgorithm(Algorithm): """ @@ -112,11 +133,7 @@ def __init__(self, hash_alg): self.hash_alg = hash_alg def prepare_key(self, key): - if not is_string_type(key): - raise TypeError('Expecting a string- or bytes-formatted key.') - - if not isinstance(key, binary_type): - key = key.encode('utf-8') + key = force_bytes(key) invalid_strings = [ b'-----BEGIN PUBLIC KEY-----', @@ -131,6 +148,22 @@ def prepare_key(self, key): return key + @staticmethod + def to_jwk(key_obj): + return json.dumps({ + 'k': force_unicode(base64url_encode(force_bytes(key_obj))), + 'kty': 'oct' + }) + + @staticmethod + def from_jwk(jwk): + obj = json.loads(jwk) + + if obj.get('kty') != 'oct': + raise InvalidKeyError('Not an HMAC key') + + return base64url_decode(obj['k']) + def sign(self, msg, key): return hmac.new(key, msg, self.hash_alg).digest() @@ -156,9 +189,8 @@ def prepare_key(self, key): isinstance(key, RSAPublicKey): return key - if is_string_type(key): - if not isinstance(key, binary_type): - key = key.encode('utf-8') + if isinstance(key, string_types): + key = force_bytes(key) try: if key.startswith(b'ssh-rsa'): @@ -172,6 +204,105 @@ def prepare_key(self, key): return key + @staticmethod + def to_jwk(key_obj): + obj = None + + if getattr(key_obj, 'private_numbers', None): + # Private key + numbers = key_obj.private_numbers() + + obj = { + 'kty': 'RSA', + 'key_ops': ['sign'], + 'n': force_unicode(to_base64url_uint(numbers.public_numbers.n)), + 'e': force_unicode(to_base64url_uint(numbers.public_numbers.e)), + 'd': force_unicode(to_base64url_uint(numbers.d)), + 'p': force_unicode(to_base64url_uint(numbers.p)), + 'q': force_unicode(to_base64url_uint(numbers.q)), + 'dp': force_unicode(to_base64url_uint(numbers.dmp1)), + 'dq': force_unicode(to_base64url_uint(numbers.dmq1)), + 'qi': force_unicode(to_base64url_uint(numbers.iqmp)) + } + + elif getattr(key_obj, 'verifier', None): + # Public key + numbers = key_obj.public_numbers() + + obj = { + 'kty': 'RSA', + 'key_ops': ['verify'], + 'n': force_unicode(to_base64url_uint(numbers.n)), + 'e': force_unicode(to_base64url_uint(numbers.e)) + } + else: + raise InvalidKeyError('Not a public or private key') + + return json.dumps(obj) + + @staticmethod + def from_jwk(jwk): + try: + obj = json.loads(jwk) + except ValueError: + raise InvalidKeyError('Key is not valid JSON') + + if obj.get('kty') != 'RSA': + raise InvalidKeyError('Not an RSA key') + + if 'd' in obj and 'e' in obj and 'n' in obj: + # Private key + if 'oth' in obj: + raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported') + + other_props = ['p', 'q', 'dp', 'dq', 'qi'] + props_found = [prop in obj for prop in other_props] + any_props_found = any(props_found) + + if any_props_found and not all(props_found): + raise InvalidKeyError('RSA key must include all parameters if any are present besides d') + + public_numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + if any_props_found: + numbers = RSAPrivateNumbers( + d=from_base64url_uint(obj['d']), + p=from_base64url_uint(obj['p']), + q=from_base64url_uint(obj['q']), + dmp1=from_base64url_uint(obj['dp']), + dmq1=from_base64url_uint(obj['dq']), + iqmp=from_base64url_uint(obj['qi']), + public_numbers=public_numbers + ) + else: + d = from_base64url_uint(obj['d']) + p, q = rsa_recover_prime_factors( + public_numbers.n, d, public_numbers.e + ) + + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers + ) + + return numbers.private_key(default_backend()) + elif 'n' in obj and 'e' in obj: + # Public key + numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + return numbers.public_key(default_backend()) + else: + raise InvalidKeyError('Not a public or private key') + def sign(self, msg, key): signer = key.signer( padding.PKCS1v15(), @@ -213,9 +344,8 @@ def prepare_key(self, key): isinstance(key, EllipticCurvePublicKey): return key - if is_string_type(key): - if not isinstance(key, binary_type): - key = key.encode('utf-8') + if isinstance(key, string_types): + key = force_bytes(key) # Attempt to load key. We don't know if it's # a Signing Key or a Verifying Key, so we try diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 177f5ff5..f4e58b74 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -7,7 +7,7 @@ from .algorithms import Algorithm, get_default_algorithms # NOQA from .compat import binary_type, string_types, text_type from .exceptions import DecodeError, InvalidAlgorithmError, InvalidTokenError -from .utils import base64url_decode, base64url_encode, merge_dict +from .utils import base64url_decode, base64url_encode, force_bytes, merge_dict class PyJWS(object): @@ -82,11 +82,13 @@ def encode(self, payload, key, algorithm='HS256', headers=None, self._validate_headers(headers) header.update(headers) - json_header = json.dumps( - header, - separators=(',', ':'), - cls=json_encoder - ).encode('utf-8') + json_header = force_bytes( + json.dumps( + header, + separators=(',', ':'), + cls=json_encoder + ) + ) segments.append(base64url_encode(json_header)) segments.append(base64url_encode(payload)) diff --git a/jwt/compat.py b/jwt/compat.py index dafd0c75..b928c7d2 100644 --- a/jwt/compat.py +++ b/jwt/compat.py @@ -3,8 +3,9 @@ versions of python, and compatibility wrappers around optional packages. """ # flake8: noqa -import sys import hmac +import struct +import sys PY3 = sys.version_info[0] == 3 @@ -20,10 +21,6 @@ string_types = (text_type, binary_type) -def is_string_type(val): - return any([isinstance(val, typ) for typ in string_types]) - - def timedelta_total_seconds(delta): try: delta.total_seconds @@ -56,3 +53,24 @@ def constant_time_compare(val1, val2): result |= ord(x) ^ ord(y) return result == 0 + +# Use int.to_bytes if it exists (Python 3) +if getattr(int, 'to_bytes', None): + def bytes_from_int(val): + remaining = val + byte_length = 0 + + while remaining != 0: + remaining = remaining >> 8 + byte_length += 1 + + return val.to_bytes(byte_length, 'big', signed=False) +else: + def bytes_from_int(val): + buf = [] + while val: + val, remainder = divmod(val, 256) + buf.append(remainder) + + buf.reverse() + return struct.pack('%sB' % len(buf), *buf) diff --git a/jwt/utils.py b/jwt/utils.py index 637b8929..4f0a4968 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,5 +1,8 @@ import base64 import binascii +import struct + +from .compat import binary_type, bytes_from_int, text_type try: from cryptography.hazmat.primitives.asymmetric.utils import ( @@ -9,7 +12,28 @@ pass +def force_unicode(value): + if isinstance(value, binary_type): + return value.decode('utf-8') + elif isinstance(value, text_type): + return value + else: + raise TypeError('Expected a string value') + + +def force_bytes(value): + if isinstance(value, text_type): + return value.encode('utf-8') + elif isinstance(value, binary_type): + return value + else: + raise TypeError('Expected a string value') + + def base64url_decode(input): + if isinstance(input, text_type): + input = input.encode('ascii') + rem = len(input) % 4 if rem > 0: @@ -22,6 +46,28 @@ def base64url_encode(input): return base64.urlsafe_b64encode(input).replace(b'=', b'') +def to_base64url_uint(val): + if val < 0: + raise ValueError('Must be a positive integer') + + int_bytes = bytes_from_int(val) + + if len(int_bytes) == 0: + int_bytes = b'\x00' + + return base64url_encode(int_bytes) + + +def from_base64url_uint(val): + if isinstance(val, text_type): + val = val.encode('ascii') + + data = base64url_decode(val) + + buf = struct.unpack('%sB' % len(data), data) + return int(''.join(["%02x" % byte for byte in buf]), 16) + + def merge_dict(original, updates): if not updates: return original diff --git a/tests/contrib/test_algorithms.py b/tests/contrib/test_algorithms.py index 1fbe10d1..6d5ca75b 100644 --- a/tests/contrib/test_algorithms.py +++ b/tests/contrib/test_algorithms.py @@ -1,8 +1,10 @@ import base64 +from jwt.utils import force_bytes, force_unicode + import pytest -from ..utils import ensure_bytes, ensure_unicode, key_path +from ..utils import key_path try: from jwt.contrib.algorithms.pycrypto import RSAAlgorithm @@ -29,7 +31,7 @@ def test_rsa_should_accept_unicode_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path('testkey_rsa'), 'r') as rsa_key: - algo.prepare_key(ensure_unicode(rsa_key.read())) + algo.prepare_key(force_unicode(rsa_key.read())) def test_rsa_should_reject_non_string_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -40,9 +42,9 @@ def test_rsa_should_reject_non_string_key(self): def test_rsa_sign_should_generate_correct_signature_value(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + jwt_message = force_bytes('Hello World!') - expected_sig = base64.b64decode(ensure_bytes( + expected_sig = base64.b64decode(force_bytes( 'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp' '10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl' '2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix' @@ -63,9 +65,9 @@ def test_rsa_sign_should_generate_correct_signature_value(self): def test_rsa_verify_should_return_false_if_signature_invalid(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + jwt_message = force_bytes('Hello World!') - jwt_sig = base64.b64decode(ensure_bytes( + jwt_sig = base64.b64decode(force_bytes( 'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp' '10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl' '2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix' @@ -73,7 +75,7 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self): 'fHJnNUzAEUOXS0WahHVb57D30pcgIji9z923q90p5c7E2cU8V+E1qe8NdCA' 'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA==')) - jwt_sig += ensure_bytes('123') # Signature is now invalid + jwt_sig += force_bytes('123') # Signature is now invalid with open(key_path('testkey_rsa.pub'), 'r') as keyfile: jwt_pub_key = algo.prepare_key(keyfile.read()) @@ -84,9 +86,9 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self): def test_rsa_verify_should_return_true_if_signature_valid(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + jwt_message = force_bytes('Hello World!') - jwt_sig = base64.b64decode(ensure_bytes( + jwt_sig = base64.b64decode(force_bytes( 'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp' '10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl' '2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix' @@ -122,14 +124,14 @@ def test_ec_should_accept_unicode_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path('testkey_ec'), 'r') as ec_key: - algo.prepare_key(ensure_unicode(ec_key.read())) + algo.prepare_key(force_unicode(ec_key.read())) def test_ec_sign_should_generate_correct_signature_value(self): algo = ECAlgorithm(ECAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + jwt_message = force_bytes('Hello World!') - expected_sig = base64.b64decode(ensure_bytes( + expected_sig = base64.b64decode(force_bytes( 'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M' 'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw' 'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65')) @@ -147,14 +149,14 @@ def test_ec_sign_should_generate_correct_signature_value(self): def test_ec_verify_should_return_false_if_signature_invalid(self): algo = ECAlgorithm(ECAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + jwt_message = force_bytes('Hello World!') - jwt_sig = base64.b64decode(ensure_bytes( + jwt_sig = base64.b64decode(force_bytes( 'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M' 'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw' 'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65')) - jwt_sig += ensure_bytes('123') # Signature is now invalid + jwt_sig += force_bytes('123') # Signature is now invalid with open(key_path('testkey_ec.pub'), 'r') as keyfile: jwt_pub_key = algo.prepare_key(keyfile.read()) @@ -165,9 +167,9 @@ def test_ec_verify_should_return_false_if_signature_invalid(self): def test_ec_verify_should_return_true_if_signature_valid(self): algo = ECAlgorithm(ECAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + jwt_message = force_bytes('Hello World!') - jwt_sig = base64.b64decode(ensure_bytes( + jwt_sig = base64.b64decode(force_bytes( 'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M' 'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw' 'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65')) diff --git a/tests/keys/__init__.py b/tests/keys/__init__.py index fad09f57..4fae6878 100644 --- a/tests/keys/__init__.py +++ b/tests/keys/__init__.py @@ -1,15 +1,15 @@ import json import os -from jwt.utils import base64url_decode +from jwt.utils import base64url_decode, force_bytes -from tests.utils import ensure_bytes, int_from_bytes +from tests.utils import int_from_bytes BASE_PATH = os.path.dirname(os.path.abspath(__file__)) def decode_value(val): - decoded = base64url_decode(ensure_bytes(val)) + decoded = base64url_decode(force_bytes(val)) return int_from_bytes(decoded, 'big') @@ -17,13 +17,12 @@ def load_hmac_key(): with open(os.path.join(BASE_PATH, 'jwk_hmac.json'), 'r') as infile: keyobj = json.load(infile) - return base64url_decode(ensure_bytes(keyobj['k'])) + return base64url_decode(force_bytes(keyobj['k'])) try: - from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.backends import default_backend - + from jwt.algorithms import RSAAlgorithm has_crypto = True except ImportError: has_crypto = False @@ -31,26 +30,11 @@ def load_hmac_key(): if has_crypto: def load_rsa_key(): with open(os.path.join(BASE_PATH, 'jwk_rsa_key.json'), 'r') as infile: - keyobj = json.load(infile) - - return rsa.RSAPrivateNumbers( - p=decode_value(keyobj['p']), - q=decode_value(keyobj['q']), - d=decode_value(keyobj['d']), - dmp1=decode_value(keyobj['dp']), - dmq1=decode_value(keyobj['dq']), - iqmp=decode_value(keyobj['qi']), - public_numbers=load_rsa_pub_key().public_numbers() - ).private_key(default_backend()) + return RSAAlgorithm.from_jwk(infile.read()) def load_rsa_pub_key(): with open(os.path.join(BASE_PATH, 'jwk_rsa_pub.json'), 'r') as infile: - keyobj = json.load(infile) - - return rsa.RSAPublicNumbers( - n=decode_value(keyobj['n']), - e=decode_value(keyobj['e']) - ).public_key(default_backend()) + return RSAAlgorithm.from_jwk(infile.read()) def load_ec_key(): with open(os.path.join(BASE_PATH, 'jwk_ec_key.json'), 'r') as infile: diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index e3cf1d0b..97fdc225 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,13 +1,14 @@ import base64 +import json from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm from jwt.exceptions import InvalidKeyError -from jwt.utils import base64url_decode +from jwt.utils import base64url_decode, force_bytes, force_unicode import pytest from .keys import load_hmac_key -from .utils import ensure_bytes, ensure_unicode, key_path +from .utils import key_path try: from jwt.algorithms import RSAAlgorithm, ECAlgorithm, RSAPSSAlgorithm @@ -36,6 +37,18 @@ def test_algorithm_should_throw_exception_if_verify_not_impl(self): with pytest.raises(NotImplementedError): algo.verify('message', 'key', 'signature') + def test_algorithm_should_throw_exception_if_to_jwk_not_impl(self): + algo = Algorithm() + + with pytest.raises(NotImplementedError): + algo.from_jwk('value') + + def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self): + algo = Algorithm() + + with pytest.raises(NotImplementedError): + algo.to_jwk('value') + def test_none_algorithm_should_throw_exception_if_key_is_not_none(self): algo = NoneAlgorithm() @@ -49,12 +62,12 @@ def test_hmac_should_reject_nonstring_key(self): algo.prepare_key(object()) exception = context.value - assert str(exception) == 'Expecting a string- or bytes-formatted key.' + assert str(exception) == 'Expected a string value' def test_hmac_should_accept_unicode_key(self): algo = HMACAlgorithm(HMACAlgorithm.SHA256) - algo.prepare_key(ensure_unicode('awesome')) + algo.prepare_key(force_unicode('awesome')) def test_hmac_should_throw_exception_if_key_is_pem_public_key(self): algo = HMACAlgorithm(HMACAlgorithm.SHA256) @@ -84,6 +97,28 @@ def test_hmac_should_throw_exception_if_key_is_x509_cert(self): with open(key_path('testkey2_rsa.pub.pem'), 'r') as keyfile: algo.prepare_key(keyfile.read()) + def test_hmac_jwk_should_parse_and_verify(self): + algo = HMACAlgorithm(HMACAlgorithm.SHA256) + + with open(key_path('jwk_hmac.json'), 'r') as keyfile: + key = algo.from_jwk(keyfile.read()) + + signature = algo.sign(b'Hello World!', key) + assert algo.verify(b'Hello World!', key, signature) + + def test_hmac_to_jwk_returns_correct_values(self): + algo = HMACAlgorithm(HMACAlgorithm.SHA256) + key = algo.to_jwk('secret') + + assert json.loads(key) == {'kty': 'oct', 'k': 'c2VjcmV0'} + + def test_hmac_from_jwk_should_raise_exception_if_not_hmac_key(self): + algo = HMACAlgorithm(HMACAlgorithm.SHA256) + + with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile: + with pytest.raises(InvalidKeyError): + algo.from_jwk(keyfile.read()) + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_rsa_should_parse_pem_public_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -103,7 +138,7 @@ def test_rsa_should_accept_unicode_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) with open(key_path('testkey_rsa'), 'r') as rsa_key: - algo.prepare_key(ensure_unicode(rsa_key.read())) + algo.prepare_key(force_unicode(rsa_key.read())) @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_rsa_should_reject_non_string_key(self): @@ -116,9 +151,9 @@ def test_rsa_should_reject_non_string_key(self): def test_rsa_verify_should_return_false_if_signature_invalid(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) - message = ensure_bytes('Hello World!') + message = force_bytes('Hello World!') - sig = base64.b64decode(ensure_bytes( + sig = base64.b64decode(force_bytes( 'yS6zk9DBkuGTtcBzLUzSpo9gGJxJFOGvUqN01iLhWHrzBQ9ZEz3+Ae38AXp' '10RWwscp42ySC85Z6zoN67yGkLNWnfmCZSEv+xqELGEvBJvciOKsrhiObUl' '2mveSc1oeO/2ujkGDkkkJ2epn0YliacVjZF5+/uDmImUfAAj8lzjnHlzYix' @@ -126,7 +161,7 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self): 'fHJnNUzAEUOXS0WahHVb57D30pcgIji9z923q90p5c7E2cU8V+E1qe8NdCA' 'APCDzZZ9zQ/dgcMVaBrGrgimrcLbPjueOKFgSO+SSjIElKA==')) - sig += ensure_bytes('123') # Signature is now invalid + sig += force_bytes('123') # Signature is now invalid with open(key_path('testkey_rsa.pub'), 'r') as keyfile: pub_key = algo.prepare_key(keyfile.read()) @@ -134,6 +169,191 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self): result = algo.verify(message, pub_key, sig) assert not result + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + + with open(key_path('jwk_rsa_key.json'), 'r') as keyfile: + priv_key = algo.from_jwk(keyfile.read()) + + signature = algo.sign(force_bytes('Hello World!'), priv_key) + assert algo.verify(force_bytes('Hello World!'), pub_key, signature) + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_private_key_to_jwk_works_with_from_jwk(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('testkey_rsa'), 'r') as rsa_key: + orig_key = algo.prepare_key(force_unicode(rsa_key.read())) + + parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + assert parsed_key.private_numbers() == orig_key.private_numbers() + assert parsed_key.private_numbers().public_numbers == orig_key.private_numbers().public_numbers + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_public_key_to_jwk_works_with_from_jwk(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('testkey_rsa.pub'), 'r') as rsa_key: + orig_key = algo.prepare_key(force_unicode(rsa_key.read())) + + parsed_key = algo.from_jwk(algo.to_jwk(orig_key)) + assert parsed_key.public_numbers() == orig_key.public_numbers() + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_private_key_with_other_primes_is_invalid(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_rsa_key.json'), 'r') as keyfile: + with pytest.raises(InvalidKeyError): + keydata = json.loads(keyfile.read()) + keydata['oth'] = [] + + algo.from_jwk(json.dumps(keydata)) + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_private_key_with_missing_values_is_invalid(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_rsa_key.json'), 'r') as keyfile: + with pytest.raises(InvalidKeyError): + keydata = json.loads(keyfile.read()) + del keydata['p'] + + algo.from_jwk(json.dumps(keydata)) + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_private_key_can_recover_prime_factors(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_rsa_key.json'), 'r') as keyfile: + keybytes = keyfile.read() + control_key = algo.from_jwk(keybytes).private_numbers() + + keydata = json.loads(keybytes) + delete_these = ['p', 'q', 'dp', 'dq', 'qi'] + for field in delete_these: + del keydata[field] + + parsed_key = algo.from_jwk(json.dumps(keydata)).private_numbers() + + assert control_key.d == parsed_key.d + assert control_key.p == parsed_key.p + assert control_key.q == parsed_key.q + assert control_key.dmp1 == parsed_key.dmp1 + assert control_key.dmq1 == parsed_key.dmq1 + assert control_key.iqmp == parsed_key.iqmp + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_private_key_with_missing_required_values_is_invalid(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_rsa_key.json'), 'r') as keyfile: + with pytest.raises(InvalidKeyError): + keydata = json.loads(keyfile.read()) + del keydata['p'] + + algo.from_jwk(json.dumps(keydata)) + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_raises_exception_if_not_a_valid_key(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + # Invalid JSON + with pytest.raises(InvalidKeyError): + algo.from_jwk('{not-a-real-key') + + # Missing key parts + with pytest.raises(InvalidKeyError): + algo.from_jwk('{"kty": "RSA"}') + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_to_jwk_returns_correct_values_for_public_key(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('testkey_rsa.pub'), 'r') as keyfile: + pub_key = algo.prepare_key(keyfile.read()) + + key = algo.to_jwk(pub_key) + + expected = { + 'e': 'AQAB', + 'key_ops': ['verify'], + 'kty': 'RSA', + 'n': ( + '1HgzBfJv2cOjQryCwe8NEelriOTNFWKZUivevUrRhlqcmZJdCvuCJRr-xCN-' + 'OmO8qwgJJR98feNujxVg-J9Ls3_UOA4HcF9nYH6aqVXELAE8Hk_ALvxi96ms' + '1DDuAvQGaYZ-lANxlvxeQFOZSbjkz_9mh8aLeGKwqJLp3p-OhUBQpwvAUAPg' + '82-OUtgTW3nSljjeFr14B8qAneGSc_wl0ni--1SRZUXFSovzcqQOkla3W27r' + 'rLfrD6LXgj_TsDs4vD1PnIm1zcVenKT7TfYI17bsG_O_Wecwz2Nl19pL7gDo' + 'sNruF3ogJWNq1Lyn_ijPQnkPLpZHyhvuiycYcI3DiQ' + ), + } + assert json.loads(key) == expected + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_to_jwk_returns_correct_values_for_private_key(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('testkey_rsa'), 'r') as keyfile: + priv_key = algo.prepare_key(keyfile.read()) + + key = algo.to_jwk(priv_key) + + expected = { + 'key_ops': [u'sign'], + 'kty': 'RSA', + 'e': 'AQAB', + 'n': ( + '1HgzBfJv2cOjQryCwe8NEelriOTNFWKZUivevUrRhlqcmZJdCvuCJRr-xCN-' + 'OmO8qwgJJR98feNujxVg-J9Ls3_UOA4HcF9nYH6aqVXELAE8Hk_ALvxi96ms' + '1DDuAvQGaYZ-lANxlvxeQFOZSbjkz_9mh8aLeGKwqJLp3p-OhUBQpwvAUAPg' + '82-OUtgTW3nSljjeFr14B8qAneGSc_wl0ni--1SRZUXFSovzcqQOkla3W27r' + 'rLfrD6LXgj_TsDs4vD1PnIm1zcVenKT7TfYI17bsG_O_Wecwz2Nl19pL7gDo' + 'sNruF3ogJWNq1Lyn_ijPQnkPLpZHyhvuiycYcI3DiQ' + ), + 'd': ('rfbs8AWdB1RkLJRlC51LukrAvYl5UfU1TE6XRa4o-DTg2-03OXLNEMyVpMr' + 'a47weEnu14StypzC8qXL7vxXOyd30SSFTffLfleaTg-qxgMZSDw-Fb_M-pU' + 'HMPMEDYG-lgGma4l4fd1yTX2ATtoUo9BVOQgWS1LMZqi0ASEOkUfzlBgL04' + 'UoaLhPSuDdLygdlDzgruVPnec0t1uOEObmrcWIkhwU2CGQzeLtuzX6OVgPh' + 'k7xcnjbDurTTVpWH0R0gbZ5ukmQ2P-YuCX8T9iWNMGjPNSkb7h02s2Oe9ZR' + 'zP007xQ0VF-Z7xyLuxk6ASmoX1S39ujSbk2WF0eXNPRgFwQ'), + 'q': ('47hlW2f1ARuWYJf9Dl6MieXjdj2dGx9PL2UH0unVzJYInd56nqXNPrQrc5k' + 'ZU65KApC9n9oKUwIxuqwAAbh8oGNEQDqnuTj-powCkdC6bwA8KH1Y-wotpq' + '_GSjxkNzjWRm2GArJSzZc6Fb8EuObOrAavKJ285-zMPCEfus1WZG0'), + 'p': ('7tr0z929Lp4OHIRJjIKM_rDrWMPtRgnV-51pgWsN6qdpDzns_PgFwrHcoyY' + 'sWIO-4yCdVWPxFOgEZ8xXTM_uwOe4VEmdZhw55Tx7axYZtmZYZbO_RIP4CG' + 'mlJlOFTiYnxpr-2Cx6kIeQmd-hf7fA3tL018aEzwYMbFMcnAGnEg0'), + 'qi': ('djo95mB0LVYikNPa-NgyDwLotLqrueb9IviMmn6zKHCwiOXReqXDX9slB8' + 'RA15uv56bmN04O__NyVFcgJ2ef169GZHiRFIgIy0Pl8LYkMhCYKKhyqM7g' + 'xN-SqGqDTKDC22j00S7jcvCaa1qadn1qbdfukZ4NXv7E2d_LO0Y2Kkc'), + 'dp': ('tgZ2-tJpEdWxu1m1EzeKa644LHVjpTRptk7H0LDc8i6SieADEuWQvkb9df' + 'fpY6tDFaQNQr3fQ6dtdAztmsP7l1b_ynwvT1nDZUcqZvl4ruBgDWFmKbjI' + 'lOCt0v9jX6MEPP5xqBx9axdkw18BnGtUuHrbzHSlUX-yh_rumpVH1SE'), + 'dq': ('xxCIuhD0YlWFbUcwFgGdBWcLIm_WCMGj7SB6aGu1VDTLr4Wu10TFWM0TNu' + 'hc9YPker2gpj5qzAmdAzwcfWSSvXpJTYR43jfulBTMoj8-2o3wCM0anclW' + 'AuKhin-kc4mh9ssDXRQZwlMymZP0QtaxUDw_nlfVrUCZgO7L1_ZsUTk') + } + assert json.loads(key) == expected + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_to_jwk_raises_exception_on_invalid_key(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with pytest.raises(InvalidKeyError): + algo.to_jwk({'not': 'a valid key'}) + + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_from_jwk_raises_exception_on_invalid_key(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_hmac.json'), 'r') as keyfile: + with pytest.raises(InvalidKeyError): + algo.from_jwk(keyfile.read()) + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_ec_should_reject_non_string_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256) @@ -146,7 +366,7 @@ def test_ec_should_accept_unicode_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256) with open(key_path('testkey_ec'), 'r') as ec_key: - algo.prepare_key(ensure_unicode(ec_key.read())) + algo.prepare_key(force_unicode(ec_key.read())) @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_ec_should_accept_pem_private_key_bytes(self): @@ -159,10 +379,10 @@ def test_ec_should_accept_pem_private_key_bytes(self): def test_ec_verify_should_return_false_if_signature_invalid(self): algo = ECAlgorithm(ECAlgorithm.SHA256) - message = ensure_bytes('Hello World!') + message = force_bytes('Hello World!') # Mess up the signature by replacing a known byte - sig = base64.b64decode(ensure_bytes( + sig = base64.b64decode(force_bytes( 'AC+m4Jf/xI3guAC6w0w37t5zRpSCF6F4udEz5LiMiTIjCS4vcVe6dDOxK+M' 'mvkF8PxJuvqxP2CO3TR3okDPCl/NjATTO1jE+qBZ966CRQSSzcCM+tzcHzw' 'LZS5kbvKu0Acd/K6Ol2/W3B1NeV5F/gjvZn/jOwaLgWEUYsg0o4XVrAg65'.replace('r', 's'))) @@ -177,9 +397,9 @@ def test_ec_verify_should_return_false_if_signature_invalid(self): def test_ec_verify_should_return_false_if_signature_wrong_length(self): algo = ECAlgorithm(ECAlgorithm.SHA256) - message = ensure_bytes('Hello World!') + message = force_bytes('Hello World!') - sig = base64.b64decode(ensure_bytes('AC+m4Jf/xI3guAC6w0w3')) + sig = base64.b64decode(force_bytes('AC+m4Jf/xI3guAC6w0w3')) with open(key_path('testkey_ec.pub'), 'r') as keyfile: pub_key = algo.prepare_key(keyfile.read()) @@ -191,7 +411,7 @@ def test_ec_verify_should_return_false_if_signature_wrong_length(self): def test_rsa_pss_sign_then_verify_should_return_true(self): algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256) - message = ensure_bytes('Hello World!') + message = force_bytes('Hello World!') with open(key_path('testkey_rsa'), 'r') as keyfile: priv_key = algo.prepare_key(keyfile.read()) @@ -207,9 +427,9 @@ def test_rsa_pss_sign_then_verify_should_return_true(self): def test_rsa_pss_verify_should_return_false_if_signature_invalid(self): algo = RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256) - jwt_message = ensure_bytes('Hello World!') + jwt_message = force_bytes('Hello World!') - jwt_sig = base64.b64decode(ensure_bytes( + jwt_sig = base64.b64decode(force_bytes( 'ywKAUGRIDC//6X+tjvZA96yEtMqpOrSppCNfYI7NKyon3P7doud5v65oWNu' 'vQsz0fzPGfF7mQFGo9Cm9Vn0nljm4G6PtqZRbz5fXNQBH9k10gq34AtM02c' '/cveqACQ8gF3zxWh6qr9jVqIpeMEaEBIkvqG954E0HT9s9ybHShgHX9mlWk' @@ -217,7 +437,7 @@ def test_rsa_pss_verify_should_return_false_if_signature_invalid(self): 'daOCWqbpZDuLb1imKpmm8Nsm56kAxijMLZnpCcnPgyb7CqG+B93W9GHglA5' 'drUeR1gRtO7vqbZMsCAQ4bpjXxwbYyjQlEVuMl73UL6sOWg==')) - jwt_sig += ensure_bytes('123') # Signature is now invalid + jwt_sig += force_bytes('123') # Signature is now invalid with open(key_path('testkey_rsa.pub'), 'r') as keyfile: jwt_pub_key = algo.prepare_key(keyfile.read()) @@ -239,7 +459,7 @@ def test_hmac_verify_should_return_true_for_test_vector(self): Reference: https://tools.ietf.org/html/rfc7520#section-4.4 """ - signing_input = ensure_bytes( + signing_input = force_bytes( 'eyJhbGciOiJIUzI1NiIsImtpZCI6IjAxOGMwYWU1LTRkOWItNDcxYi1iZmQ2LWVlZ' 'jMxNGJjNzAzNyJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ' '29pbmcgb3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIG' @@ -247,7 +467,7 @@ def test_hmac_verify_should_return_true_for_test_vector(self): 'gd2hlcmUgeW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4' ) - signature = base64url_decode(ensure_bytes( + signature = base64url_decode(force_bytes( 's0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0' )) @@ -265,7 +485,7 @@ def test_rsa_verify_should_return_true_for_test_vector(self): Reference: https://tools.ietf.org/html/rfc7520#section-4.1 """ - signing_input = ensure_bytes( + signing_input = force_bytes( 'eyJhbGciOiJSUzI1NiIsImtpZCI6ImJpbGJvLmJhZ2dpbnNAaG9iYml0b24uZXhhb' 'XBsZSJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ29pbmcgb' '3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIGlmIHlvdS' @@ -273,7 +493,7 @@ def test_rsa_verify_should_return_true_for_test_vector(self): 'geW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4' ) - signature = base64url_decode(ensure_bytes( + signature = base64url_decode(force_bytes( 'MRjdkly7_-oTPTS3AXP41iQIGKa80A0ZmTuV5MEaHoxnW2e5CZ5NlKtainoFmKZop' 'dHM1O2U4mwzJdQx996ivp83xuglII7PNDi84wnB-BDkoBwA78185hX-Es4JIwmDLJ' 'K3lfWRa-XtL0RnltuYv746iYTh_qHRD68BNt1uSNCrUCTJDt5aAE6x8wW1Kt9eRo4' @@ -296,7 +516,7 @@ def test_rsapss_verify_should_return_true_for_test_vector(self): Reference: https://tools.ietf.org/html/rfc7520#section-4.2 """ - signing_input = ensure_bytes( + signing_input = force_bytes( 'eyJhbGciOiJQUzM4NCIsImtpZCI6ImJpbGJvLmJhZ2dpbnNAaG9iYml0b24uZXhhb' 'XBsZSJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ29pbmcgb' '3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIGlmIHlvdS' @@ -304,7 +524,7 @@ def test_rsapss_verify_should_return_true_for_test_vector(self): 'geW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4' ) - signature = base64url_decode(ensure_bytes( + signature = base64url_decode(force_bytes( 'cu22eBqkYDKgIlTpzDXGvaFfz6WGoz7fUDcfT0kkOy42miAh2qyBzk1xEsnk2IpN6' '-tPid6VrklHkqsGqDqHCdP6O8TTB5dDDItllVo6_1OLPpcbUrhiUSMxbbXUvdvWXz' 'g-UD8biiReQFlfz28zGWVsdiNAUf8ZnyPEgVFn442ZdNqiVJRmBqrYRXe8P_ijQ7p' @@ -327,7 +547,7 @@ def test_ec_verify_should_return_true_for_test_vector(self): Reference: https://tools.ietf.org/html/rfc7520#section-4.3 """ - signing_input = ensure_bytes( + signing_input = force_bytes( 'eyJhbGciOiJFUzUxMiIsImtpZCI6ImJpbGJvLmJhZ2dpbnNAaG9iYml0b24uZXhhb' 'XBsZSJ9.SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywgZ29pbmcgb' '3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9hZCwgYW5kIGlmIHlvdS' @@ -335,7 +555,7 @@ def test_ec_verify_should_return_true_for_test_vector(self): 'geW91IG1pZ2h0IGJlIHN3ZXB0IG9mZiB0by4' ) - signature = base64url_decode(ensure_bytes( + signature = base64url_decode(force_bytes( 'AE_R_YZCChjn4791jSQCrdPZCNYqHXCTZH0-JZGYNlaAjP2kqaluUIIUnC9qvbu9P' 'lon7KRTzoNEuT4Va2cmL1eJAQy3mtPBu_u_sDDyYjnAMDxXPn7XrT0lw-kvAD890j' 'l8e2puQens_IEKBpHABlsbEPX6sFY8OcGDqoRuBomu9xQ2' diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index c56ec4b2..855575c3 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -8,12 +8,11 @@ from jwt.exceptions import ( DecodeError, InvalidAlgorithmError, InvalidTokenError ) -from jwt.utils import base64url_decode +from jwt.utils import base64url_decode, force_bytes, force_unicode import pytest from .compat import string_types, text_type -from .utils import ensure_bytes, ensure_unicode try: from cryptography.hazmat.backends import default_backend @@ -34,7 +33,7 @@ def jws(): @pytest.fixture def payload(): """ Creates a sample jws claimset for use as a payload during tests """ - return ensure_bytes('hello world') + return force_bytes('hello world') class TestJWS: @@ -211,7 +210,7 @@ def test_decodes_valid_es384_jws(self, jws): b'HAG0_zxxu0JyINOFT2iqF3URYl9HZ8kZWMeZAtXmn6Cw' b'PXRJD2f7N-f7bJ5JeL9VT5beI2XD3FlK3GgRvI-eE-2Ik') decoded_payload = jws.decode(example_jws, example_pubkey) - json_payload = json.loads(ensure_unicode(decoded_payload)) + json_payload = json.loads(force_unicode(decoded_payload)) assert json_payload == example_payload @@ -236,7 +235,7 @@ def test_decodes_valid_rs384_jws(self, jws): b'uwmrtSWCBUjiN8sqJ00CDgycxKqHfUndZbEAOjcCAhBr' b'qWW3mSVivUfubsYbwUdUG3fSRPjaUPcpe8A') decoded_payload = jws.decode(example_jws, example_pubkey) - json_payload = json.loads(ensure_unicode(decoded_payload)) + json_payload = json.loads(force_unicode(decoded_payload)) assert json_payload == example_payload @@ -410,12 +409,12 @@ def test_get_unverified_header_fails_on_bad_header_types(self, jws, payload): def test_encode_decode_with_rsa_sha256(self, jws, payload): # PEM-formatted RSA key with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: - priv_rsakey = load_pem_private_key(ensure_bytes(rsa_priv_file.read()), + priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()), password=None, backend=default_backend()) jws_message = jws.encode(payload, priv_rsakey, algorithm='RS256') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: - pub_rsakey = load_ssh_public_key(ensure_bytes(rsa_pub_file.read()), + pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()), backend=default_backend()) jws.decode(jws_message, pub_rsakey) @@ -433,12 +432,12 @@ def test_encode_decode_with_rsa_sha256(self, jws, payload): def test_encode_decode_with_rsa_sha384(self, jws, payload): # PEM-formatted RSA key with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: - priv_rsakey = load_pem_private_key(ensure_bytes(rsa_priv_file.read()), + priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()), password=None, backend=default_backend()) jws_message = jws.encode(payload, priv_rsakey, algorithm='RS384') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: - pub_rsakey = load_ssh_public_key(ensure_bytes(rsa_pub_file.read()), + pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()), backend=default_backend()) jws.decode(jws_message, pub_rsakey) @@ -455,12 +454,12 @@ def test_encode_decode_with_rsa_sha384(self, jws, payload): def test_encode_decode_with_rsa_sha512(self, jws, payload): # PEM-formatted RSA key with open('tests/keys/testkey_rsa', 'r') as rsa_priv_file: - priv_rsakey = load_pem_private_key(ensure_bytes(rsa_priv_file.read()), + priv_rsakey = load_pem_private_key(force_bytes(rsa_priv_file.read()), password=None, backend=default_backend()) jws_message = jws.encode(payload, priv_rsakey, algorithm='RS512') with open('tests/keys/testkey_rsa.pub', 'r') as rsa_pub_file: - pub_rsakey = load_ssh_public_key(ensure_bytes(rsa_pub_file.read()), + pub_rsakey = load_ssh_public_key(force_bytes(rsa_pub_file.read()), backend=default_backend()) jws.decode(jws_message, pub_rsakey) @@ -497,12 +496,12 @@ def test_rsa_related_algorithms(self, jws): def test_encode_decode_with_ecdsa_sha256(self, jws, payload): # PEM-formatted EC key with open('tests/keys/testkey_ec', 'r') as ec_priv_file: - priv_eckey = load_pem_private_key(ensure_bytes(ec_priv_file.read()), + priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()), password=None, backend=default_backend()) jws_message = jws.encode(payload, priv_eckey, algorithm='ES256') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: - pub_eckey = load_pem_public_key(ensure_bytes(ec_pub_file.read()), + pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), backend=default_backend()) jws.decode(jws_message, pub_eckey) @@ -520,12 +519,12 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload): # PEM-formatted EC key with open('tests/keys/testkey_ec', 'r') as ec_priv_file: - priv_eckey = load_pem_private_key(ensure_bytes(ec_priv_file.read()), + priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()), password=None, backend=default_backend()) jws_message = jws.encode(payload, priv_eckey, algorithm='ES384') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: - pub_eckey = load_pem_public_key(ensure_bytes(ec_pub_file.read()), + pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), backend=default_backend()) jws.decode(jws_message, pub_eckey) @@ -542,12 +541,12 @@ def test_encode_decode_with_ecdsa_sha384(self, jws, payload): def test_encode_decode_with_ecdsa_sha512(self, jws, payload): # PEM-formatted EC key with open('tests/keys/testkey_ec', 'r') as ec_priv_file: - priv_eckey = load_pem_private_key(ensure_bytes(ec_priv_file.read()), + priv_eckey = load_pem_private_key(force_bytes(ec_priv_file.read()), password=None, backend=default_backend()) jws_message = jws.encode(payload, priv_eckey, algorithm='ES512') with open('tests/keys/testkey_ec.pub', 'r') as ec_pub_file: - pub_eckey = load_pem_public_key(ensure_bytes(ec_pub_file.read()), backend=default_backend()) + pub_eckey = load_pem_public_key(force_bytes(ec_pub_file.read()), backend=default_backend()) jws.decode(jws_message, pub_eckey) # string-formatted key @@ -606,8 +605,8 @@ def default(self, o): token = jws.encode(payload, 'secret', headers=data, json_encoder=CustomJSONEncoder) - header = ensure_bytes(ensure_unicode(token).split('.')[0]) - header = json.loads(ensure_unicode(base64url_decode(header))) + header = force_bytes(force_unicode(token).split('.')[0]) + header = json.loads(force_unicode(base64url_decode(header))) assert 'some_decimal' in header assert header['some_decimal'] == 'it worked' diff --git a/tests/test_compat.py b/tests/test_compat.py index 6f6d6d0e..10beb94e 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,20 +1,19 @@ from jwt.compat import constant_time_compare - -from .utils import ensure_bytes +from jwt.utils import force_bytes class TestCompat: def test_constant_time_compare_returns_true_if_same(self): assert constant_time_compare( - ensure_bytes('abc'), ensure_bytes('abc') + force_bytes('abc'), force_bytes('abc') ) def test_constant_time_compare_returns_false_if_diff_lengths(self): assert not constant_time_compare( - ensure_bytes('abc'), ensure_bytes('abcd') + force_bytes('abc'), force_bytes('abcd') ) def test_constant_time_compare_returns_false_if_totally_different(self): assert not constant_time_compare( - ensure_bytes('abcd'), ensure_bytes('efgh') + force_bytes('abcd'), force_bytes('efgh') ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..7549faa2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,40 @@ +from jwt.utils import ( + force_bytes, force_unicode, from_base64url_uint, to_base64url_uint +) + +import pytest + + +@pytest.mark.parametrize("inputval,expected", [ + (0, b'AA'), + (1, b'AQ'), + (255, b'_w'), + (65537, b'AQAB'), + (123456789, b'B1vNFQ'), + pytest.mark.xfail((-1, ''), raises=ValueError) +]) +def test_to_base64url_uint(inputval, expected): + actual = to_base64url_uint(inputval) + assert actual == expected + + +@pytest.mark.parametrize("inputval,expected", [ + (b'AA', 0), + (b'AQ', 1), + (b'_w', 255), + (b'AQAB', 65537), + (b'B1vNFQ', 123456789, ), +]) +def test_from_base64url_uint(inputval, expected): + actual = from_base64url_uint(inputval) + assert actual == expected + + +def test_force_unicode_raises_error_on_invalid_object(): + with pytest.raises(TypeError): + force_unicode({}) + + +def test_force_bytes_raises_error_on_invalid_object(): + with pytest.raises(TypeError): + force_bytes({}) diff --git a/tests/utils.py b/tests/utils.py index 77227023..2e3f043b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,22 +4,6 @@ from calendar import timegm from datetime import datetime -from .compat import text_type - - -def ensure_bytes(key): - if isinstance(key, text_type): - key = key.encode('utf-8') - - return key - - -def ensure_unicode(key): - if not isinstance(key, text_type): - key = key.decode() - - return key - def utc_timestamp(): return timegm(datetime.utcnow().utctimetuple())