Skip to content

Commit

Permalink
refactor: backends design + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
k4black committed May 6, 2024
1 parent 93cbdee commit ad4221d
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 321 deletions.
34 changes: 19 additions & 15 deletions fastapi_jwt/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,23 @@
from fastapi.security import APIKeyCookie, HTTPBearer
from starlette.status import HTTP_401_UNAUTHORIZED

from .jwt_backends import AbstractJWTBackend, AuthlibJWTBackend, PythonJoseJWTBackend
from .jwt_backends import AbstractJWTBackend, authlib_backend, python_jose_backend
from .jwt_backends.abstract_backend import BackendException

DEFAULT_JWT_BACKEND: Optional[Type[AbstractJWTBackend]] = None
if authlib_backend.authlib_jose is not None:
DEFAULT_JWT_BACKEND = authlib_backend.AuthlibJWTBackend
elif python_jose_backend.jose is not None:
DEFAULT_JWT_BACKEND = python_jose_backend.PythonJoseJWTBackend

Check warning on line 19 in fastapi_jwt/jwt.py

View check run for this annotation

Codecov / codecov/patch

fastapi_jwt/jwt.py#L18-L19

Added lines #L18 - L19 were not covered by tests
else: # pragma: nocover
raise ImportError("No JWT backend found, please install 'python-jose' or 'authlib'")


def define_default_jwt_backend(cls: Type[AbstractJWTBackend]) -> None:
def force_jwt_backend(cls: Type[AbstractJWTBackend]) -> None:
global DEFAULT_JWT_BACKEND
DEFAULT_JWT_BACKEND = cls


if AuthlibJWTBackend is not None:
DEFAULT_JWT_BACKEND = AuthlibJWTBackend
elif PythonJoseJWTBackend is not None:
DEFAULT_JWT_BACKEND = PythonJoseJWTBackend
else: # pragma: nocover
raise ImportError("No JWT backend found, please install 'python-jose' or 'authlib'")


def utcnow() -> datetime:
try:
from datetime import UTC
Expand All @@ -39,7 +38,7 @@ def utcnow() -> datetime:


__all__ = [
"define_default_jwt_backend",
"force_jwt_backend",
"JwtAuthorizationCredentials",
"JwtAccessBearer",
"JwtAccessCookie",
Expand Down Expand Up @@ -89,10 +88,9 @@ def __init__(

self.jwt_backend = DEFAULT_JWT_BACKEND(algorithm)
self.secret_key = secret_key
if places:
assert places.issubset({"header", "cookie"}), "only 'header'/'cookie' are supported"

self.places = places or {"header"}
assert self.places.issubset({"header", "cookie"}), "only 'header' and/or 'cookie' places are supported"
self.auto_error = auto_error
self.access_expires_delta = access_expires_delta or timedelta(minutes=15)
self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31)
Expand Down Expand Up @@ -152,7 +150,13 @@ async def _get_payload(
return None

# Try to decode jwt token. auto_error on error
return self.jwt_backend.decode(token, self.secret_key, self.auto_error)
try:
return self.jwt_backend.decode(token, self.secret_key)
except BackendException as e:
if self.auto_error:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=str(e))
else:
return None

def create_access_token(
self,
Expand Down Expand Up @@ -354,7 +358,7 @@ async def _get_credentials(
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Wrong token: 'type' is not 'refresh'",
detail="Invalid token: 'type' is not 'refresh'",
)
else:
return None
Expand Down
13 changes: 3 additions & 10 deletions fastapi_jwt/jwt_backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
try:
from .authlib_backend import AuthlibJWTBackend
except ImportError:
AuthlibJWTBackend = None # type: ignore

try:
from .python_jose_backend import PythonJoseJWTBackend
except ImportError:
PythonJoseJWTBackend = None # type: ignore

from . import abstract_backend, authlib_backend, python_jose_backend # noqa: F401
from .abstract_backend import AbstractJWTBackend # noqa: F401
from .authlib_backend import AuthlibJWTBackend # noqa: F401
from .python_jose_backend import PythonJoseJWTBackend # noqa: F401
22 changes: 6 additions & 16 deletions fastapi_jwt/jwt_backends/abstract_backend.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,14 @@
from abc import ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class AbstractJWTBackend(metaclass=ABCMeta):
# simple "SingletonArgs" implementation to keep a JWTBackend per algorithm
_instances: Dict[Any, "AbstractJWTBackend"] = {}
class BackendException(Exception):
pass

def __new__(cls, algorithm: Optional[str]) -> "AbstractJWTBackend":
instance_key = (cls, algorithm)
if instance_key not in cls._instances:
cls._instances[instance_key] = super(AbstractJWTBackend, cls).__new__(cls)
return cls._instances[instance_key]

class AbstractJWTBackend(ABC):
@abstractmethod
def __init__(self, algorithm: Optional[str]) -> None:
raise NotImplementedError

@property
@abstractmethod
def default_algorithm(self) -> str:
def __init__(self, algorithm: Optional[str] = None) -> None:
raise NotImplementedError

Check warning on line 12 in fastapi_jwt/jwt_backends/abstract_backend.py

View check run for this annotation

Codecov / codecov/patch

fastapi_jwt/jwt_backends/abstract_backend.py#L12

Added line #L12 was not covered by tests

@property
Expand All @@ -31,5 +21,5 @@ def encode(self, to_encode: Dict[str, Any], secret_key: str) -> str:
raise NotImplementedError

Check warning on line 21 in fastapi_jwt/jwt_backends/abstract_backend.py

View check run for this annotation

Codecov / codecov/patch

fastapi_jwt/jwt_backends/abstract_backend.py#L21

Added line #L21 was not covered by tests

@abstractmethod
def decode(self, token: str, secret_key: str, auto_error: bool) -> Optional[Dict[str, Any]]:
def decode(self, token: str, secret_key: str) -> Optional[Dict[str, Any]]:
raise NotImplementedError

Check warning on line 25 in fastapi_jwt/jwt_backends/abstract_backend.py

View check run for this annotation

Codecov / codecov/patch

fastapi_jwt/jwt_backends/abstract_backend.py#L25

Added line #L25 was not covered by tests
22 changes: 7 additions & 15 deletions fastapi_jwt/jwt_backends/authlib_backend.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from typing import Any, Dict, Optional

from fastapi import HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED

try:
import authlib.jose as authlib_jose
import authlib.jose.errors as authlib_jose_errors
except ImportError:
authlib_jose = None

Check warning on line 7 in fastapi_jwt/jwt_backends/authlib_backend.py

View check run for this annotation

Codecov / codecov/patch

fastapi_jwt/jwt_backends/authlib_backend.py#L6-L7

Added lines #L6 - L7 were not covered by tests

from .abstract_backend import AbstractJWTBackend
from .abstract_backend import AbstractJWTBackend, BackendException


class AuthlibJWTBackend(AbstractJWTBackend):
Expand All @@ -18,8 +15,9 @@ def __init__(self, algorithm: Optional[str] = None) -> None:

self._algorithm = algorithm or self.default_algorithm
# from https://github.com/lepture/authlib/blob/85f9ff/authlib/jose/__init__.py#L45
valid_algorithms = authlib_jose.JsonWebSignature.ALGORITHMS_REGISTRY.keys()
assert self._algorithm in valid_algorithms, f"{self._algorithm} algorithm is not supported by authlib"
assert (
self._algorithm in authlib_jose.JsonWebSignature.ALGORITHMS_REGISTRY.keys()
), f"{self._algorithm} algorithm is not supported by authlib"
self.jwt = authlib_jose.JsonWebToken(algorithms=[self._algorithm])

@property
Expand All @@ -34,22 +32,16 @@ def encode(self, to_encode: Dict[str, Any], secret_key: str) -> str:
token = self.jwt.encode(header={"alg": self.algorithm}, payload=to_encode, key=secret_key)
return token.decode() # convert to string

def decode(self, token: str, secret_key: str, auto_error: bool) -> Optional[Dict[str, Any]]:
def decode(self, token: str, secret_key: str) -> Optional[Dict[str, Any]]:
try:
payload = self.jwt.decode(token, secret_key)
payload.validate(leeway=10)
return dict(payload)
except authlib_jose_errors.ExpiredTokenError as e:
if auto_error:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}")
else:
return None
raise BackendException(f"Token time expired: {e}")
except (
authlib_jose_errors.InvalidClaimError,
authlib_jose_errors.InvalidTokenError,
authlib_jose_errors.DecodeError,
) as e:
if auto_error:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}")
else:
return None
raise BackendException(f"Invalid token: {e}")
38 changes: 16 additions & 22 deletions fastapi_jwt/jwt_backends/python_jose_backend.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,46 @@
import warnings
from typing import Any, Dict, Optional

from fastapi import HTTPException
from starlette.status import HTTP_401_UNAUTHORIZED

try:
from jose import jwt
import jose
import jose.jwt
except ImportError:
jwt = None # type: ignore
jose = None # type: ignore

Check warning on line 8 in fastapi_jwt/jwt_backends/python_jose_backend.py

View check run for this annotation

Codecov / codecov/patch

fastapi_jwt/jwt_backends/python_jose_backend.py#L7-L8

Added lines #L7 - L8 were not covered by tests

from .abstract_backend import AbstractJWTBackend
from .abstract_backend import AbstractJWTBackend, BackendException


class PythonJoseJWTBackend(AbstractJWTBackend):
def __init__(self, algorithm: Optional[str] = None) -> None:
assert jwt is not None, "To use PythonJoseJWTBackend, you need to install python-jose"
assert jose is not None, "To use PythonJoseJWTBackend, you need to install python-jose"
warnings.warn("PythonJoseJWTBackend is deprecated as python-jose library is not maintained anymore.")

self._algorithm = algorithm or self.default_algorithm
assert (
hasattr(jwt.ALGORITHMS, self._algorithm) is True # type: ignore[attr-defined]
hasattr(jose.jwt.ALGORITHMS, self._algorithm) is True # type: ignore[attr-defined]
), f"{algorithm} algorithm is not supported by python-jose library"

@property
def default_algorithm(self) -> str:
return jwt.ALGORITHMS.HS256
return jose.jwt.ALGORITHMS.HS256 # type: ignore[attr-defined]

@property
def algorithm(self) -> str:
return self._algorithm

def encode(self, to_encode: Dict[str, Any], secret_key: str) -> str:
return jwt.encode(to_encode, secret_key, algorithm=self._algorithm)
return jose.jwt.encode(to_encode, secret_key, algorithm=self._algorithm)

def decode(self, token: str, secret_key: str, auto_error: bool) -> Optional[Dict[str, Any]]:
def decode(self, token: str, secret_key: str) -> Optional[Dict[str, Any]]:
try:
payload: Dict[str, Any] = jwt.decode(
payload: Dict[str, Any] = jose.jwt.decode(
token,
secret_key,
algorithms=[self._algorithm],
options={"leeway": 10},
)
return payload
except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
if auto_error:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=f"Token time expired: {e}")
else:
return None
except jwt.JWTError as e: # type: ignore[attr-defined]
if auto_error:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=f"Wrong token: {e}")
else:
return None
except jose.jwt.ExpiredSignatureError as e: # type: ignore[attr-defined]
raise BackendException(f"Token time expired: {e}")
except jose.jwt.JWTError as e: # type: ignore[attr-defined]
raise BackendException(f"Invalid token: {e}")
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Type

import pytest

from fastapi_jwt.jwt_backends import AbstractJWTBackend, AuthlibJWTBackend, PythonJoseJWTBackend


@pytest.fixture(params=[PythonJoseJWTBackend, AuthlibJWTBackend])
def jwt_backend(request: pytest.FixtureRequest) -> Type[AbstractJWTBackend]:
return request.param
46 changes: 15 additions & 31 deletions tests/test_security_jwt_bearer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
import pytest
from typing import Type

from fastapi import FastAPI, Security
from fastapi.testclient import TestClient

from fastapi_jwt import (
AuthlibJWTBackend,
JwtAccessBearer,
JwtAuthorizationCredentials,
JwtRefreshBearer,
PythonJoseJWTBackend,
define_default_jwt_backend,
)
from fastapi_jwt import JwtAccessBearer, JwtAuthorizationCredentials, JwtRefreshBearer, force_jwt_backend
from fastapi_jwt.jwt_backends import AbstractJWTBackend


def create_example_client(jwt_backend: AbstractJWTBackend):
define_default_jwt_backend(jwt_backend)
def create_example_client(jwt_backend: Type[AbstractJWTBackend]):
force_jwt_backend(jwt_backend)
app = FastAPI()

access_security = JwtAccessBearer(secret_key="secret_key")
Expand Down Expand Up @@ -97,23 +91,20 @@ def read_current_user(
}


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_openapi_schema(jwt_backend: AbstractJWTBackend):
def test_openapi_schema(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_auth(jwt_backend: AbstractJWTBackend):
def test_security_jwt_auth(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.post("/auth")
assert response.status_code == 200, response.text


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_access_bearer(jwt_backend: AbstractJWTBackend):
def test_security_jwt_access_bearer(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
access_token = client.post("/auth").json()["access_token"]

Expand All @@ -122,56 +113,49 @@ def test_security_jwt_access_bearer(jwt_backend: AbstractJWTBackend):
assert response.json() == {"username": "username", "role": "user"}


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_access_bearer_wrong(jwt_backend: AbstractJWTBackend):
def test_security_jwt_access_bearer_wrong(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.get("/users/me", headers={"Authorization": "Bearer wrong_access_token"})
assert response.status_code == 401, response.text


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_access_bearer_no_credentials(jwt_backend: AbstractJWTBackend):
def test_security_jwt_access_bearer_no_credentials(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.get("/users/me")
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Credentials are not provided"}


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_access_bearer_incorrect_scheme_credentials(jwt_backend: AbstractJWTBackend):
def test_security_jwt_access_bearer_incorrect_scheme_credentials(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.get("/users/me", headers={"Authorization": "Basic notreally"})
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Credentials are not provided"}
# assert response.json() == {"detail": "Invalid authentication credentials"}


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_refresh_bearer(jwt_backend: AbstractJWTBackend):
def test_security_jwt_refresh_bearer(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
refresh_token = client.post("/auth").json()["refresh_token"]

response = client.post("/refresh", headers={"Authorization": f"Bearer {refresh_token}"})
assert response.status_code == 200, response.text


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_refresh_bearer_wrong(jwt_backend: AbstractJWTBackend):
def test_security_jwt_refresh_bearer_wrong(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.post("/refresh", headers={"Authorization": "Bearer wrong_refresh_token"})
assert response.status_code == 401, response.text


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_refresh_bearer_no_credentials(jwt_backend: AbstractJWTBackend):
def test_security_jwt_refresh_bearer_no_credentials(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.post("/refresh")
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Credentials are not provided"}


@pytest.mark.parametrize("jwt_backend", [AuthlibJWTBackend, PythonJoseJWTBackend])
def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(jwt_backend: AbstractJWTBackend):
def test_security_jwt_refresh_bearer_incorrect_scheme_credentials(jwt_backend: Type[AbstractJWTBackend]):
client = create_example_client(jwt_backend)
response = client.post("/refresh", headers={"Authorization": "Basic notreally"})
assert response.status_code == 401, response.text
Expand Down
Loading

0 comments on commit ad4221d

Please sign in to comment.