From 057bc6f46d14a1e22b2f5aa5430011d7f99f7616 Mon Sep 17 00:00:00 2001 From: Daniel Garcia Moreno Date: Thu, 2 May 2024 09:42:34 +0200 Subject: [PATCH] jwt.decode: Require algorithms keyword argument On decode, require algorithms to be specified to avoid algorithm confusion when verify_signature is True. This is similar to what pyJWT is doing in https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py#L146-L149 See https://github.com/mpdavis/python-jose/issues/346 --- jose/jwt.py | 8 +++++ tests/test_jwt.py | 88 +++++++++++++++++++++++++---------------------- 2 files changed, 54 insertions(+), 42 deletions(-) diff --git a/jose/jwt.py b/jose/jwt.py index b364b4ba..ae7e5b3b 100644 --- a/jose/jwt.py +++ b/jose/jwt.py @@ -141,6 +141,14 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None verify_signature = defaults.get("verify_signature", True) + # Forbid the usage of the jwt.decode without alogrightms parameter + # See https://github.com/mpdavis/python-jose/issues/346 for more + # information CVE-2024-33663 + if verify_signature and algorithms is None: + raise JWTError("It is required that you pass in a value for " + 'the "algorithms" argument when calling ' + "decode().") + try: payload = jws.verify(token, key, algorithms, verify=verify_signature) except JWSError as e: diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 378504f8..d1acbc12 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -5,6 +5,7 @@ import pytest from jose import jws, jwt +from jose.constants import ALGORITHMS from jose.exceptions import JWTError, JWKError @@ -56,7 +57,7 @@ def test_no_alg(self, claims, key): ], ) def test_numeric_key(self, key, token): - token_info = jwt.decode(token, key) + token_info = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED) assert token_info == {"name": "test"} def test_invalid_claims_json(self): @@ -108,7 +109,7 @@ def test_no_alg_default_headers(self, claims, key, headers): def test_non_default_headers(self, claims, key, headers): encoded = jwt.encode(claims, key, headers=headers) - decoded = jwt.decode(encoded, key) + decoded = jwt.decode(encoded, key, algorithms=ALGORITHMS.HS256) assert claims == decoded all_headers = jwt.get_unverified_headers(encoded) for k, v in headers.items(): @@ -159,7 +160,7 @@ def test_encode(self, claims, key): def test_decode(self, claims, key): token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" ".eyJhIjoiYiJ9" ".jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8" - decoded = jwt.decode(token, key) + decoded = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED) assert decoded == claims @@ -190,7 +191,7 @@ def test_leeway_is_timedelta(self, claims, key): options = {"leeway": leeway} token = jwt.encode(claims, key) - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_iat_not_int(self, key): claims = {"iat": "test"} @@ -198,7 +199,7 @@ def test_iat_not_int(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_not_int(self, key): claims = {"nbf": "test"} @@ -206,7 +207,7 @@ def test_nbf_not_int(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_datetime(self, key): nbf = datetime.utcnow() - timedelta(seconds=5) @@ -214,7 +215,7 @@ def test_nbf_datetime(self, key): claims = {"nbf": nbf} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_with_leeway(self, key): nbf = datetime.utcnow() + timedelta(seconds=5) @@ -226,7 +227,7 @@ def test_nbf_with_leeway(self, key): options = {"leeway": 10} token = jwt.encode(claims, key) - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_nbf_in_future(self, key): nbf = datetime.utcnow() + timedelta(seconds=5) @@ -236,7 +237,7 @@ def test_nbf_in_future(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_nbf_skip(self, key): nbf = datetime.utcnow() + timedelta(seconds=5) @@ -246,11 +247,11 @@ def test_nbf_skip(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) options = {"verify_nbf": False} - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_exp_not_int(self, key): claims = {"exp": "test"} @@ -258,7 +259,7 @@ def test_exp_not_int(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_exp_datetime(self, key): exp = datetime.utcnow() + timedelta(seconds=5) @@ -266,7 +267,7 @@ def test_exp_datetime(self, key): claims = {"exp": exp} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_exp_with_leeway(self, key): exp = datetime.utcnow() - timedelta(seconds=5) @@ -278,7 +279,7 @@ def test_exp_with_leeway(self, key): options = {"leeway": 10} token = jwt.encode(claims, key) - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_exp_in_past(self, key): exp = datetime.utcnow() - timedelta(seconds=5) @@ -288,7 +289,7 @@ def test_exp_in_past(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_exp_skip(self, key): exp = datetime.utcnow() - timedelta(seconds=5) @@ -298,11 +299,11 @@ def test_exp_skip(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) options = {"verify_exp": False} - jwt.decode(token, key, options=options) + jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256) def test_aud_string(self, key): aud = "audience" @@ -310,7 +311,7 @@ def test_aud_string(self, key): claims = {"aud": aud} token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_list(self, key): aud = "audience" @@ -318,7 +319,7 @@ def test_aud_list(self, key): claims = {"aud": [aud]} token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_list_multiple(self, key): aud = "audience" @@ -326,7 +327,7 @@ def test_aud_list_multiple(self, key): claims = {"aud": [aud, "another"]} token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_list_is_strings(self, key): aud = "audience" @@ -335,7 +336,7 @@ def test_aud_list_is_strings(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_case_sensitive(self, key): aud = "audience" @@ -344,13 +345,13 @@ def test_aud_case_sensitive(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, audience="AUDIENCE") + jwt.decode(token, key, audience="AUDIENCE", algorithms=ALGORITHMS.HS256) def test_aud_empty_claim(self, claims, key): aud = "audience" token = jwt.encode(claims, key) - jwt.decode(token, key, audience=aud) + jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256) def test_aud_not_string_or_list(self, key): aud = 1 @@ -359,7 +360,7 @@ def test_aud_not_string_or_list(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_aud_given_number(self, key): aud = "audience" @@ -368,7 +369,7 @@ def test_aud_given_number(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, audience=1) + jwt.decode(token, key, audience=1, algorithms=ALGORITHMS.HS256) def test_iss_string(self, key): iss = "issuer" @@ -376,7 +377,7 @@ def test_iss_string(self, key): claims = {"iss": iss} token = jwt.encode(claims, key) - jwt.decode(token, key, issuer=iss) + jwt.decode(token, key, issuer=iss, algorithms=ALGORITHMS.HS256) def test_iss_list(self, key): iss = "issuer" @@ -384,7 +385,7 @@ def test_iss_list(self, key): claims = {"iss": iss} token = jwt.encode(claims, key) - jwt.decode(token, key, issuer=["https://issuer", "issuer"]) + jwt.decode(token, key, issuer=["https://issuer", "issuer"], algorithms=ALGORITHMS.HS256) def test_iss_tuple(self, key): iss = "issuer" @@ -392,7 +393,7 @@ def test_iss_tuple(self, key): claims = {"iss": iss} token = jwt.encode(claims, key) - jwt.decode(token, key, issuer=("https://issuer", "issuer")) + jwt.decode(token, key, issuer=("https://issuer", "issuer"), algorithms=ALGORITHMS.HS256) def test_iss_invalid(self, key): iss = "issuer" @@ -401,7 +402,7 @@ def test_iss_invalid(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, issuer="another") + jwt.decode(token, key, issuer="another", algorithms=ALGORITHMS.HS256) def test_sub_string(self, key): sub = "subject" @@ -409,7 +410,7 @@ def test_sub_string(self, key): claims = {"sub": sub} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_sub_invalid(self, key): sub = 1 @@ -418,7 +419,7 @@ def test_sub_invalid(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_sub_correct(self, key): sub = "subject" @@ -426,7 +427,7 @@ def test_sub_correct(self, key): claims = {"sub": sub} token = jwt.encode(claims, key) - jwt.decode(token, key, subject=sub) + jwt.decode(token, key, subject=sub, algorithms=ALGORITHMS.HS256) def test_sub_incorrect(self, key): sub = "subject" @@ -435,7 +436,7 @@ def test_sub_incorrect(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, subject="another") + jwt.decode(token, key, subject="another", algorithms=ALGORITHMS.HS256) def test_jti_string(self, key): jti = "JWT ID" @@ -443,7 +444,7 @@ def test_jti_string(self, key): claims = {"jti": jti} token = jwt.encode(claims, key) - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_jti_invalid(self, key): jti = 1 @@ -452,33 +453,33 @@ def test_jti_invalid(self, key): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_at_hash(self, claims, key): access_token = "" token = jwt.encode(claims, key, access_token=access_token) - payload = jwt.decode(token, key, access_token=access_token) + payload = jwt.decode(token, key, access_token=access_token, algorithms=ALGORITHMS.HS256) assert "at_hash" in payload def test_at_hash_invalid(self, claims, key): token = jwt.encode(claims, key, access_token="") with pytest.raises(JWTError): - jwt.decode(token, key, access_token="") + jwt.decode(token, key, access_token="", algorithms=ALGORITHMS.HS256) def test_at_hash_missing_access_token(self, claims, key): token = jwt.encode(claims, key, access_token="") with pytest.raises(JWTError): - jwt.decode(token, key) + jwt.decode(token, key, algorithms=ALGORITHMS.HS256) def test_at_hash_missing_claim(self, claims, key): token = jwt.encode(claims, key) - payload = jwt.decode(token, key, access_token="") + payload = jwt.decode(token, key, access_token="", algorithms=ALGORITHMS.HS256) assert "at_hash" not in payload def test_at_hash_unable_to_calculate(self, claims, key): token = jwt.encode(claims, key, access_token="") with pytest.raises(JWTError): - jwt.decode(token, key, access_token="\xe2") + jwt.decode(token, key, access_token="\xe2", algorithms=ALGORITHMS.HS256) def test_bad_claims(self): bad_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck" @@ -516,12 +517,12 @@ def test_require(self, claims, key, claim, value): token = jwt.encode(claims, key) with pytest.raises(JWTError): - jwt.decode(token, key, options=options, audience=str(value)) + jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256) new_claims = dict(claims) new_claims[claim] = value token = jwt.encode(new_claims, key) - jwt.decode(token, key, options=options, audience=str(value)) + jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256) def test_CVE_2024_33663(self): """Test based on https://github.com/mpdavis/python-jose/issues/346""" @@ -554,4 +555,7 @@ def test_CVE_2024_33663(self): # algorithm field is left unspecified # but the library will happily still verify without warning, trusting the user-controlled alg field of the token header with pytest.raises(JWKError): + data = jwt.decode(evil_token, PUBKEY, algorithms=ALGORITHMS.HS256) + + with pytest.raises(JWTError, match='.*required.*"algorithms".*'): data = jwt.decode(evil_token, PUBKEY)