diff --git a/project_name/app.py b/project_name/app.py index c3bdfe0..691f176 100644 --- a/project_name/app.py +++ b/project_name/app.py @@ -56,4 +56,4 @@ def read(*paths, **kwargs): @app.on_event("startup") def on_startup(): - create_db_and_tables(engine) + create_db_and_tables(engine) # pragma: no cover diff --git a/project_name/models/content.py b/project_name/models/content.py index 3497122..d6b18b9 100644 --- a/project_name/models/content.py +++ b/project_name/models/content.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Extra from sqlmodel import Field, Relationship, SQLModel -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no coverage from project_name.security import User diff --git a/project_name/routes/content.py b/project_name/routes/content.py index 7ee33ec..55e3fad 100644 --- a/project_name/routes/content.py +++ b/project_name/routes/content.py @@ -21,15 +21,16 @@ async def list_contents(*, session: Session = ActiveSession): async def query_content( *, id_or_slug: Union[str, int], session: Session = ActiveSession ): - content = session.query(Content).where( + query = select(Content).where( or_( Content.id == id_or_slug, Content.slug == id_or_slug, ) ) + content = session.exec(query).one_or_none() if not content: raise HTTPException(status_code=404, detail="Content not found") - return content.first() + return content @router.post( diff --git a/project_name/routes/user.py b/project_name/routes/user.py index 6f98600..cd414d2 100644 --- a/project_name/routes/user.py +++ b/project_name/routes/user.py @@ -70,6 +70,14 @@ async def update_user_password( return user +# Order of these functions matters here +# The /me/ path needs to be higher than the wildcard path below or else +# this function will never be called. +@router.get("/me/", response_model=UserResponse) +async def my_profile(current_user: User = AuthenticatedUser): + return current_user + + @router.get( "/{user_id_or_username}/", response_model=UserResponse, @@ -78,21 +86,16 @@ async def update_user_password( async def query_user( *, session: Session = ActiveSession, user_id_or_username: Union[str, int] ): - user = session.query(User).where( + query = select(User).where( or_( User.id == user_id_or_username, User.username == user_id_or_username, ) ) - + user = session.exec(query).one_or_none() if not user: raise HTTPException(status_code=404, detail="User not found") - return user.first() - - -@router.get("/me/", response_model=UserResponse) -async def my_profile(current_user: User = AuthenticatedUser): - return current_user + return user @router.delete("/{user_id}/", dependencies=[AdminUser]) diff --git a/project_name/security.py b/project_name/security.py index c2aaefb..b418df8 100644 --- a/project_name/security.py +++ b/project_name/security.py @@ -50,7 +50,7 @@ def __get_validators__(cls): @classmethod def validate(cls, v): """Accepts a plain text password and returns a hashed password.""" - if not isinstance(v, str): + if not isinstance(v, str): # pragma: no coverage raise TypeError("string required") hashed_password = get_password_hash(v) @@ -114,7 +114,7 @@ def create_access_token( to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta - else: + else: # pragma: no cover expire = datetime.utcnow() + timedelta(minutes=15) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) @@ -150,7 +150,7 @@ def get_current_user( if authorization := request.headers.get("authorization"): try: token = authorization.split(" ")[1] - except IndexError: + except IndexError: # pragma: no cover raise credentials_exception try: @@ -170,7 +170,7 @@ def get_current_user( async def get_current_active_user( current_user: User = Depends(get_current_user), ) -> User: - if current_user.disabled: + if current_user.disabled: # pragma: no cover raise HTTPException(status_code=400, detail="Inactive user") return current_user @@ -181,7 +181,7 @@ async def get_current_active_user( async def get_current_admin_user( current_user: User = Depends(get_current_user), ) -> User: - if not current_user.superuser: + if not current_user.superuser: # pragma: no cover raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not an admin user" ) diff --git a/tests/conftest.py b/tests/conftest.py index 05abd61..f30d74e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,6 +62,24 @@ def api_client_authenticated(): return client +@pytest.fixture(scope="function") +def api_client_not_superuser(): + + try: + create_user("regular", "regular", superuser=False) + except IntegrityError: + pass + + client = TestClient(app) + token = client.post( + "/token", + data={"username": "regular", "password": "regular"}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ).json()["access_token"] + client.headers["Authorization"] = f"Bearer {token}" + return client + + @pytest.fixture(scope="function") def cli_client(): return CliRunner() diff --git a/tests/test_content_api.py b/tests/test_content_api.py index 5ff8533..e7354e9 100644 --- a/tests/test_content_api.py +++ b/tests/test_content_api.py @@ -18,3 +18,77 @@ def test_content_list(api_client_authenticated): assert response.status_code == 200 result = response.json() assert result[0]["slug"] == "hello-test" + + +def test_content_get_individual(api_client_authenticated): + response = api_client_authenticated.get("/content/hello-test") + assert response.status_code == 200 + result = response.json() + assert result["slug"] == "hello-test" + + +def test_content_get_individual_404(api_client_authenticated): + response = api_client_authenticated.get("/content/does-not-exist/") + assert response.status_code == 404 + + +def test_content_update(api_client_authenticated): + response = api_client_authenticated.post( + "/content/", + json={ + "title": "Test Post 2 for Patch", + "text": "this is just a test", + "published": True, + "tags": ["test", "hello"], + }, + ) + assert response.status_code == 200 + result = response.json() + assert result["slug"] == "test-post-2-for-patch" + response2 = api_client_authenticated.patch( + f"/content/{result['id']}/", + json={ + "published": "false", + }, + ) + assert response2.status_code == 200 + result2 = response2.json() + assert result["slug"] == result2["slug"] + assert result["text"] == result2["text"] + assert result["tags"] == result2["tags"] + assert result["published"] != result2["published"] + + +def test_content_update_404(api_client_authenticated): + response = api_client_authenticated.patch( + "/content/999/", + json={ + "published": "false", + }, + ) + assert response.status_code == 404 + + +def test_content_update_unauthorized(api_client_not_superuser): + response = api_client_not_superuser.patch( + "/content/1/", + json={ + "published": "false", + }, + ) + assert response.status_code == 403 + + +def test_content_delete_404(api_client_authenticated): + response = api_client_authenticated.delete("/content/999/") + assert response.status_code == 404 + + +def test_content_delete_unauthorized(api_client_not_superuser): + response = api_client_not_superuser.delete("/content/2/") + assert response.status_code == 403 + + +def test_content_delete(api_client_authenticated): + response = api_client_authenticated.delete("/content/2/") + assert response.status_code == 200 diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..6cb8da8 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,17 @@ +from project_name import security +from fastapi.exceptions import HTTPException +import pytest + + +def test_get_current_user(): + malformed_token_no_username = security.create_access_token({}) + with pytest.raises(HTTPException): + user = security.get_current_user(token=malformed_token_no_username) + + malformed_token_invalid_username = security.create_access_token( + {"sub": "InvalidUserName"} + ) + with pytest.raises(HTTPException): + user = security.get_current_user( + token=malformed_token_invalid_username + ) diff --git a/tests/test_security_api.py b/tests/test_security_api.py new file mode 100644 index 0000000..1cf96d4 --- /dev/null +++ b/tests/test_security_api.py @@ -0,0 +1,35 @@ +from fastapi.testclient import TestClient +from project_name import app + + +def test_user_login_no_username_match(): + client = TestClient(app) + response = client.post( + "/token", + data={"username": "doesNotExist", "password": "doesNotExist"}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + assert response.status_code == 401 + + +def test_user_login_no_password_match(): + client = TestClient(app) + response = client.post( + "/token", + data={"username": "admin", "password": "incorrect password"}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + assert response.status_code == 401 + + +def test_secure_api_malformed_headers(): + client = TestClient(app) + token = client.post( + "/token", + data={"username": "admin", "password": "admin"}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ).json()["access_token"] + malformed_token = token[:5] + "a" + token[5:] + client.headers["Authorization"] = f"Bearer {malformed_token}" + response = client.get("/user/me") + assert response.status_code == 401 diff --git a/tests/test_user_api.py b/tests/test_user_api.py index c7d547e..7b064d7 100644 --- a/tests/test_user_api.py +++ b/tests/test_user_api.py @@ -18,3 +18,109 @@ def test_user_create(api_client_authenticated): assert response.status_code == 200 result = response.json() assert result["username"] == "foo" + + +def test_get_my_profile(api_client_not_superuser): + response = api_client_not_superuser.get("/user/me") + assert response.status_code == 200 + result = response.json() + assert result["username"] == "regular" + assert result["id"] == 3 + + +def test_get_other_profile_by_id(api_client_not_superuser): + response = api_client_not_superuser.get("/user/1") + assert response.status_code == 200 + result = response.json() + assert result["username"] == "admin2" + assert result["id"] == 1 + + +def test_get_other_profile_by_name(api_client_not_superuser): + response = api_client_not_superuser.get("/user/admin") + assert response.status_code == 200 + result = response.json() + assert result["username"] == "admin" + assert result["id"] == 2 + + +def test_get_other_profile_404(api_client_not_superuser): + response = api_client_not_superuser.get("/user/99999") + result = response.json() + assert response.status_code == 404 + + +def test_change_password_404(api_client_not_superuser): + response = api_client_not_superuser.patch( + "/user/99999/password/", + json={"password": "string", "password_confirm": "string"}, + ) + result = response.text + assert response.status_code == 404 + + +def test_change_password_unauthorised(api_client_not_superuser): + response = api_client_not_superuser.patch( + "/user/1/password/", + json={"password": "string", "password_confirm": "string"}, + ) + result = response.text + assert response.status_code == 403 + + +def test_change_password_no_match(api_client_not_superuser): + my_user = api_client_not_superuser.get("/user/me/").json() + response = api_client_not_superuser.patch( + f"/user/{my_user['id']}/password/", + json={"password": "string", "password_confirm": "string1"}, + ) + assert response.status_code == 400 + result = response.json() + assert result["detail"] == "Passwords don't match" + + +def test_change_password(api_client_not_superuser): + my_user = api_client_not_superuser.get("/user/me/").json() + response = api_client_not_superuser.patch( + f"/user/{my_user['id']}/password/", + json={"password": "string", "password_confirm": "string"}, + ) + assert response.status_code == 200 + result = response.json() + assert result == my_user + + +def test_change_password_by_admin(api_client_authenticated): + regular_user = api_client_authenticated.get("/user/regular/").json() + response = api_client_authenticated.patch( + f"/user/{regular_user['id']}/password/", + json={"password": "string", "password_confirm": "string"}, + ) + assert response.status_code == 200 + result = response.json() + assert result == regular_user + + +def test_delete_user_404(api_client_authenticated): + response = api_client_authenticated.delete( + "/user/99999/", + ) + assert response.status_code == 404 + + +def test_delete_user_self_not_allowed(api_client_authenticated): + my_user = api_client_authenticated.get("/user/me/").json() + response = api_client_authenticated.delete( + f"/user/{my_user['id']}/", + ) + assert response.status_code == 403 + + +def test_delete_user(api_client_authenticated): + user = api_client_authenticated.get("/user/foo/").json() + response = api_client_authenticated.delete( + f"/user/{user['id']}/", + ) + assert response.status_code == 200 + result = response.json() + assert result["ok"] == True