diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cc6ad4e7..3a580064 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ Fixed - Close ``HTTPError`` response to prevent ``ResourceWarning`` on Python 3.14 by @veeceey in `#1133 `__ - Do not keep ``algorithms`` dict in PyJWK instances by @akx in `#1143 `__ - Validate the crit (Critical) Header Parameter defined in RFC 7515 ยง4.1.11. by @dmbs335 in `GHSA-752w-5fwx-jx9f `__ +- Use PyJWK algorithm when encoding without explicit algorithm in `#1148 `__ Added ~~~~~ diff --git a/jwt/api_jws.py b/jwt/api_jws.py index 1fdb8cee..0ab7e4b4 100644 --- a/jwt/api_jws.py +++ b/jwt/api_jws.py @@ -27,6 +27,8 @@ from .algorithms import AllowedPrivateKeys, AllowedPublicKeys from .types import SigOptions +_ALGORITHM_UNSET = object() + class PyJWS: header_typ = "JWT" @@ -119,7 +121,7 @@ def encode( self, payload: bytes, key: AllowedPrivateKeys | PyJWK | str | bytes, - algorithm: str | None = "HS256", + algorithm: str | None = _ALGORITHM_UNSET, # type: ignore[assignment] headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, @@ -128,7 +130,12 @@ def encode( segments: list[bytes] = [] # declare a new var to narrow the type for type checkers - if algorithm is None: + if algorithm is _ALGORITHM_UNSET: + if isinstance(key, PyJWK): + algorithm_ = key.algorithm_name + else: + algorithm_ = "HS256" + elif algorithm is None: if isinstance(key, PyJWK): algorithm_ = key.algorithm_name else: diff --git a/jwt/api_jwt.py b/jwt/api_jwt.py index a26fa9c9..429c2d79 100644 --- a/jwt/api_jwt.py +++ b/jwt/api_jwt.py @@ -8,7 +8,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Union, cast -from .api_jws import PyJWS, _jws_global_obj +from .api_jws import PyJWS, _ALGORITHM_UNSET, _jws_global_obj from .exceptions import ( DecodeError, ExpiredSignatureError, @@ -91,7 +91,7 @@ def encode( self, payload: dict[str, Any], key: AllowedPrivateKeyTypes, - algorithm: str | None = "HS256", + algorithm: str | None = _ALGORITHM_UNSET, # type: ignore[assignment] headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, sort_headers: bool = True, diff --git a/tests/test_api_jws.py b/tests/test_api_jws.py index bfe2b6fa..9f7edc04 100644 --- a/tests/test_api_jws.py +++ b/tests/test_api_jws.py @@ -261,6 +261,27 @@ def test_encode_with_jwk(self, jws: PyJWS, payload: bytes) -> None: ), } + def test_encode_with_jwk_uses_key_algorithm( + self, jws: PyJWS, payload: bytes + ) -> None: + """Test that encoding with a PyJWK key uses the key's algorithm + when no algorithm is explicitly specified. Regression test for #1147.""" + jwk = PyJWK( + { + "kty": "oct", + "alg": "HS384", + "k": "c2VjcmV0", # "secret" + } + ) + # Should use HS384 from the key, not default to HS256 + msg = jws.encode(payload, key=jwk) + header = jws.get_unverified_header(msg) + assert header["alg"] == "HS384" + + # Should also be decodable with the same key + decoded = jws.decode(msg, key=jwk) + assert decoded == payload + def test_decode_algorithm_param_should_be_case_sensitive(self, jws: PyJWS) -> None: example_jws = ( "eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256 diff --git a/tests/test_api_jwt.py b/tests/test_api_jwt.py index cb20e053..8b54204a 100644 --- a/tests/test_api_jwt.py +++ b/tests/test_api_jwt.py @@ -7,6 +7,7 @@ import pytest from jwt.types import Options +from jwt.api_jwk import PyJWK from jwt.api_jwt import PyJWT from jwt.exceptions import ( DecodeError, @@ -45,6 +46,22 @@ def test_jwt_with_options(self) -> None: # assert that verify_signature is respected unless verify_exp is overridden assert jwt.options["verify_exp"] is False + def test_encode_with_jwk_uses_key_algorithm(self, jwt: PyJWT) -> None: + """Test that encoding with a PyJWK key uses the key's algorithm + when no algorithm is explicitly specified. Regression test for #1147.""" + jwk = PyJWK( + { + "kty": "oct", + "alg": "HS384", + "k": "c2VjcmV0", # "secret" + } + ) + payload = {"hello": "world"} + # Should use HS384 from the key, not default to HS256 + token = jwt.encode(payload, jwk) + header = jwt.decode_complete(token, jwk, algorithms=["HS384"])["header"] + assert header["alg"] == "HS384" + def test_decodes_valid_jwt(self, jwt: PyJWT) -> None: example_payload = {"hello": "world"} example_secret = "secret"