Skip to content

Commit

Permalink
jwt.decode: Require algorithms keyword argument
Browse files Browse the repository at this point in the history
On decode, require algorithms to be specified to avoid algorithm
confusion when verify_signature is True.

This is similar to what pyJWT is doing in
https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py#L146-L149

See mpdavis#346
  • Loading branch information
danigm committed May 2, 2024
1 parent 34bd82c commit 057bc6f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 42 deletions.
8 changes: 8 additions & 0 deletions jose/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None

verify_signature = defaults.get("verify_signature", True)

# Forbid the usage of the jwt.decode without alogrightms parameter
# See https://github.com/mpdavis/python-jose/issues/346 for more
# information CVE-2024-33663
if verify_signature and algorithms is None:
raise JWTError("It is required that you pass in a value for "
'the "algorithms" argument when calling '
"decode().")

try:
payload = jws.verify(token, key, algorithms, verify=verify_signature)
except JWSError as e:
Expand Down
88 changes: 46 additions & 42 deletions tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from jose import jws, jwt
from jose.constants import ALGORITHMS
from jose.exceptions import JWTError, JWKError


Expand Down Expand Up @@ -56,7 +57,7 @@ def test_no_alg(self, claims, key):
],
)
def test_numeric_key(self, key, token):
token_info = jwt.decode(token, key)
token_info = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED)
assert token_info == {"name": "test"}

def test_invalid_claims_json(self):
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_no_alg_default_headers(self, claims, key, headers):

def test_non_default_headers(self, claims, key, headers):
encoded = jwt.encode(claims, key, headers=headers)
decoded = jwt.decode(encoded, key)
decoded = jwt.decode(encoded, key, algorithms=ALGORITHMS.HS256)
assert claims == decoded
all_headers = jwt.get_unverified_headers(encoded)
for k, v in headers.items():
Expand Down Expand Up @@ -159,7 +160,7 @@ def test_encode(self, claims, key):
def test_decode(self, claims, key):
token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" ".eyJhIjoiYiJ9" ".jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8"

decoded = jwt.decode(token, key)
decoded = jwt.decode(token, key, algorithms=ALGORITHMS.SUPPORTED)

assert decoded == claims

Expand Down Expand Up @@ -190,31 +191,31 @@ def test_leeway_is_timedelta(self, claims, key):
options = {"leeway": leeway}

token = jwt.encode(claims, key)
jwt.decode(token, key, options=options)
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)

def test_iat_not_int(self, key):
claims = {"iat": "test"}

token = jwt.encode(claims, key)

with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_nbf_not_int(self, key):
claims = {"nbf": "test"}

token = jwt.encode(claims, key)

with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_nbf_datetime(self, key):
nbf = datetime.utcnow() - timedelta(seconds=5)

claims = {"nbf": nbf}

token = jwt.encode(claims, key)
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_nbf_with_leeway(self, key):
nbf = datetime.utcnow() + timedelta(seconds=5)
Expand All @@ -226,7 +227,7 @@ def test_nbf_with_leeway(self, key):
options = {"leeway": 10}

token = jwt.encode(claims, key)
jwt.decode(token, key, options=options)
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)

def test_nbf_in_future(self, key):
nbf = datetime.utcnow() + timedelta(seconds=5)
Expand All @@ -236,7 +237,7 @@ def test_nbf_in_future(self, key):
token = jwt.encode(claims, key)

with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_nbf_skip(self, key):
nbf = datetime.utcnow() + timedelta(seconds=5)
Expand All @@ -246,27 +247,27 @@ def test_nbf_skip(self, key):
token = jwt.encode(claims, key)

with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

options = {"verify_nbf": False}

jwt.decode(token, key, options=options)
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)

def test_exp_not_int(self, key):
claims = {"exp": "test"}

token = jwt.encode(claims, key)

with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_exp_datetime(self, key):
exp = datetime.utcnow() + timedelta(seconds=5)

claims = {"exp": exp}

token = jwt.encode(claims, key)
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_exp_with_leeway(self, key):
exp = datetime.utcnow() - timedelta(seconds=5)
Expand All @@ -278,7 +279,7 @@ def test_exp_with_leeway(self, key):
options = {"leeway": 10}

token = jwt.encode(claims, key)
jwt.decode(token, key, options=options)
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)

def test_exp_in_past(self, key):
exp = datetime.utcnow() - timedelta(seconds=5)
Expand All @@ -288,7 +289,7 @@ def test_exp_in_past(self, key):
token = jwt.encode(claims, key)

with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_exp_skip(self, key):
exp = datetime.utcnow() - timedelta(seconds=5)
Expand All @@ -298,35 +299,35 @@ def test_exp_skip(self, key):
token = jwt.encode(claims, key)

with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

options = {"verify_exp": False}

jwt.decode(token, key, options=options)
jwt.decode(token, key, options=options, algorithms=ALGORITHMS.HS256)

def test_aud_string(self, key):
aud = "audience"

claims = {"aud": aud}

token = jwt.encode(claims, key)
jwt.decode(token, key, audience=aud)
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)

def test_aud_list(self, key):
aud = "audience"

claims = {"aud": [aud]}

token = jwt.encode(claims, key)
jwt.decode(token, key, audience=aud)
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)

def test_aud_list_multiple(self, key):
aud = "audience"

claims = {"aud": [aud, "another"]}

token = jwt.encode(claims, key)
jwt.decode(token, key, audience=aud)
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)

def test_aud_list_is_strings(self, key):
aud = "audience"
Expand All @@ -335,7 +336,7 @@ def test_aud_list_is_strings(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, audience=aud)
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)

def test_aud_case_sensitive(self, key):
aud = "audience"
Expand All @@ -344,13 +345,13 @@ def test_aud_case_sensitive(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, audience="AUDIENCE")
jwt.decode(token, key, audience="AUDIENCE", algorithms=ALGORITHMS.HS256)

def test_aud_empty_claim(self, claims, key):
aud = "audience"

token = jwt.encode(claims, key)
jwt.decode(token, key, audience=aud)
jwt.decode(token, key, audience=aud, algorithms=ALGORITHMS.HS256)

def test_aud_not_string_or_list(self, key):
aud = 1
Expand All @@ -359,7 +360,7 @@ def test_aud_not_string_or_list(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_aud_given_number(self, key):
aud = "audience"
Expand All @@ -368,31 +369,31 @@ def test_aud_given_number(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, audience=1)
jwt.decode(token, key, audience=1, algorithms=ALGORITHMS.HS256)

def test_iss_string(self, key):
iss = "issuer"

claims = {"iss": iss}

token = jwt.encode(claims, key)
jwt.decode(token, key, issuer=iss)
jwt.decode(token, key, issuer=iss, algorithms=ALGORITHMS.HS256)

def test_iss_list(self, key):
iss = "issuer"

claims = {"iss": iss}

token = jwt.encode(claims, key)
jwt.decode(token, key, issuer=["https://issuer", "issuer"])
jwt.decode(token, key, issuer=["https://issuer", "issuer"], algorithms=ALGORITHMS.HS256)

def test_iss_tuple(self, key):
iss = "issuer"

claims = {"iss": iss}

token = jwt.encode(claims, key)
jwt.decode(token, key, issuer=("https://issuer", "issuer"))
jwt.decode(token, key, issuer=("https://issuer", "issuer"), algorithms=ALGORITHMS.HS256)

def test_iss_invalid(self, key):
iss = "issuer"
Expand All @@ -401,15 +402,15 @@ def test_iss_invalid(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, issuer="another")
jwt.decode(token, key, issuer="another", algorithms=ALGORITHMS.HS256)

def test_sub_string(self, key):
sub = "subject"

claims = {"sub": sub}

token = jwt.encode(claims, key)
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_sub_invalid(self, key):
sub = 1
Expand All @@ -418,15 +419,15 @@ def test_sub_invalid(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_sub_correct(self, key):
sub = "subject"

claims = {"sub": sub}

token = jwt.encode(claims, key)
jwt.decode(token, key, subject=sub)
jwt.decode(token, key, subject=sub, algorithms=ALGORITHMS.HS256)

def test_sub_incorrect(self, key):
sub = "subject"
Expand All @@ -435,15 +436,15 @@ def test_sub_incorrect(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, subject="another")
jwt.decode(token, key, subject="another", algorithms=ALGORITHMS.HS256)

def test_jti_string(self, key):
jti = "JWT ID"

claims = {"jti": jti}

token = jwt.encode(claims, key)
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_jti_invalid(self, key):
jti = 1
Expand All @@ -452,33 +453,33 @@ def test_jti_invalid(self, key):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_at_hash(self, claims, key):
access_token = "<ACCESS_TOKEN>"
token = jwt.encode(claims, key, access_token=access_token)
payload = jwt.decode(token, key, access_token=access_token)
payload = jwt.decode(token, key, access_token=access_token, algorithms=ALGORITHMS.HS256)
assert "at_hash" in payload

def test_at_hash_invalid(self, claims, key):
token = jwt.encode(claims, key, access_token="<ACCESS_TOKEN>")
with pytest.raises(JWTError):
jwt.decode(token, key, access_token="<OTHER_TOKEN>")
jwt.decode(token, key, access_token="<OTHER_TOKEN>", algorithms=ALGORITHMS.HS256)

def test_at_hash_missing_access_token(self, claims, key):
token = jwt.encode(claims, key, access_token="<ACCESS_TOKEN>")
with pytest.raises(JWTError):
jwt.decode(token, key)
jwt.decode(token, key, algorithms=ALGORITHMS.HS256)

def test_at_hash_missing_claim(self, claims, key):
token = jwt.encode(claims, key)
payload = jwt.decode(token, key, access_token="<ACCESS_TOKEN>")
payload = jwt.decode(token, key, access_token="<ACCESS_TOKEN>", algorithms=ALGORITHMS.HS256)
assert "at_hash" not in payload

def test_at_hash_unable_to_calculate(self, claims, key):
token = jwt.encode(claims, key, access_token="<ACCESS_TOKEN>")
with pytest.raises(JWTError):
jwt.decode(token, key, access_token="\xe2")
jwt.decode(token, key, access_token="\xe2", algorithms=ALGORITHMS.HS256)

def test_bad_claims(self):
bad_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck"
Expand Down Expand Up @@ -516,12 +517,12 @@ def test_require(self, claims, key, claim, value):

token = jwt.encode(claims, key)
with pytest.raises(JWTError):
jwt.decode(token, key, options=options, audience=str(value))
jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256)

new_claims = dict(claims)
new_claims[claim] = value
token = jwt.encode(new_claims, key)
jwt.decode(token, key, options=options, audience=str(value))
jwt.decode(token, key, options=options, audience=str(value), algorithms=ALGORITHMS.HS256)

def test_CVE_2024_33663(self):
"""Test based on https://github.com/mpdavis/python-jose/issues/346"""
Expand Down Expand Up @@ -554,4 +555,7 @@ def test_CVE_2024_33663(self):
# algorithm field is left unspecified
# but the library will happily still verify without warning, trusting the user-controlled alg field of the token header
with pytest.raises(JWKError):
data = jwt.decode(evil_token, PUBKEY, algorithms=ALGORITHMS.HS256)

with pytest.raises(JWTError, match='.*required.*"algorithms".*'):
data = jwt.decode(evil_token, PUBKEY)

0 comments on commit 057bc6f

Please sign in to comment.