Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Fixed
- Close ``HTTPError`` response to prevent ``ResourceWarning`` on Python 3.14 by @veeceey in `#1133 <https://github.com/jpadilla/pyjwt/pull/1133>`__
- Do not keep ``algorithms`` dict in PyJWK instances by @akx in `#1143 <https://github.com/jpadilla/pyjwt/pull/1143>`__
- Validate the crit (Critical) Header Parameter defined in RFC 7515 §4.1.11. by @dmbs335 in `GHSA-752w-5fwx-jx9f <https://github.com/jpadilla/pyjwt/security/advisories/GHSA-752w-5fwx-jx9f>`__
- Use PyJWK algorithm when encoding without explicit algorithm in `#1148 <https://github.com/jpadilla/pyjwt/pull/1148>`__

Added
~~~~~
Expand Down
11 changes: 9 additions & 2 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
from .types import SigOptions

_ALGORITHM_UNSET = object()


class PyJWS:
header_typ = "JWT"
Expand Down Expand Up @@ -119,7 +121,7 @@ def encode(
self,
payload: bytes,
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = "HS256",
algorithm: str | None = _ALGORITHM_UNSET, # type: ignore[assignment]
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
Expand All @@ -128,7 +130,12 @@ def encode(
segments: list[bytes] = []

# declare a new var to narrow the type for type checkers
if algorithm is None:
if algorithm is _ALGORITHM_UNSET:
if isinstance(key, PyJWK):
algorithm_ = key.algorithm_name
else:
algorithm_ = "HS256"
elif algorithm is None:
if isinstance(key, PyJWK):
algorithm_ = key.algorithm_name
else:
Expand Down
4 changes: 2 additions & 2 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Union, cast

from .api_jws import PyJWS, _jws_global_obj
from .api_jws import PyJWS, _ALGORITHM_UNSET, _jws_global_obj
from .exceptions import (
DecodeError,
ExpiredSignatureError,
Expand Down Expand Up @@ -91,7 +91,7 @@ def encode(
self,
payload: dict[str, Any],
key: AllowedPrivateKeyTypes,
algorithm: str | None = "HS256",
algorithm: str | None = _ALGORITHM_UNSET, # type: ignore[assignment]
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
Expand Down
21 changes: 21 additions & 0 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,27 @@ def test_encode_with_jwk(self, jws: PyJWS, payload: bytes) -> None:
),
}

def test_encode_with_jwk_uses_key_algorithm(
self, jws: PyJWS, payload: bytes
) -> None:
"""Test that encoding with a PyJWK key uses the key's algorithm
when no algorithm is explicitly specified. Regression test for #1147."""
jwk = PyJWK(
{
"kty": "oct",
"alg": "HS384",
"k": "c2VjcmV0", # "secret"
}
)
# Should use HS384 from the key, not default to HS256
msg = jws.encode(payload, key=jwk)
header = jws.get_unverified_header(msg)
assert header["alg"] == "HS384"

# Should also be decodable with the same key
decoded = jws.decode(msg, key=jwk)
assert decoded == payload

def test_decode_algorithm_param_should_be_case_sensitive(self, jws: PyJWS) -> None:
example_jws = (
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256
Expand Down
17 changes: 17 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from jwt.types import Options
from jwt.api_jwk import PyJWK
from jwt.api_jwt import PyJWT
from jwt.exceptions import (
DecodeError,
Expand Down Expand Up @@ -45,6 +46,22 @@ def test_jwt_with_options(self) -> None:
# assert that verify_signature is respected unless verify_exp is overridden
assert jwt.options["verify_exp"] is False

def test_encode_with_jwk_uses_key_algorithm(self, jwt: PyJWT) -> None:
"""Test that encoding with a PyJWK key uses the key's algorithm
when no algorithm is explicitly specified. Regression test for #1147."""
jwk = PyJWK(
{
"kty": "oct",
"alg": "HS384",
"k": "c2VjcmV0", # "secret"
}
)
payload = {"hello": "world"}
# Should use HS384 from the key, not default to HS256
token = jwt.encode(payload, jwk)
header = jwt.decode_complete(token, jwk, algorithms=["HS384"])["header"]
assert header["alg"] == "HS384"

def test_decodes_valid_jwt(self, jwt: PyJWT) -> None:
example_payload = {"hello": "world"}
example_secret = "secret"
Expand Down
Loading