diff --git a/tests/unit/app/test_database.py b/tests/unit/app/test_database.py index d8ce1103c..89b33a938 100644 --- a/tests/unit/app/test_database.py +++ b/tests/unit/app/test_database.py @@ -2,10 +2,11 @@ # pylint: disable=protected-access +from typing import Generator from pathlib import Path import tempfile import pytest -from pytest_mock import MockerFixture +from pytest_mock import MockerFixture, MockType from sqlalchemy.engine.base import Engine from sqlalchemy.orm import Session @@ -14,7 +15,7 @@ @pytest.fixture(name="reset_database_state") -def reset_database_state_fixture(): +def reset_database_state_fixture() -> Generator: """Reset global database state before and after tests.""" original_engine = database.engine original_session_local = database.session_local @@ -31,7 +32,7 @@ def reset_database_state_fixture(): @pytest.fixture(name="base_postgres_config") -def base_postgres_config_fixture(): +def base_postgres_config_fixture() -> PostgreSQLDatabaseConfiguration: """Provide base PostgreSQL configuration for tests.""" return PostgreSQLDatabaseConfiguration( host="localhost", @@ -47,7 +48,7 @@ def base_postgres_config_fixture(): class TestGetEngine: """Test cases for get_engine function.""" - def test_get_engine_when_initialized(self, mocker: MockerFixture): + def test_get_engine_when_initialized(self, mocker: MockerFixture) -> None: """Test get_engine returns engine when initialized.""" mock_engine = mocker.MagicMock(spec=Engine) database.engine = mock_engine @@ -56,7 +57,7 @@ def test_get_engine_when_initialized(self, mocker: MockerFixture): assert result is mock_engine - def test_get_engine_when_not_initialized(self): + def test_get_engine_when_not_initialized(self) -> None: """Test get_engine raises RuntimeError when not initialized.""" database.engine = None @@ -68,7 +69,7 @@ def test_get_engine_when_not_initialized(self): class TestGetSession: """Test cases for get_session function.""" - def test_get_session_when_initialized(self, mocker: MockerFixture): + def test_get_session_when_initialized(self, mocker: MockerFixture) -> None: """Test get_session returns session when initialized.""" mock_session_local = mocker.MagicMock() mock_session = mocker.MagicMock(spec=Session) @@ -80,7 +81,7 @@ def test_get_session_when_initialized(self, mocker: MockerFixture): assert result is mock_session mock_session_local.assert_called_once() - def test_get_session_when_not_initialized(self): + def test_get_session_when_not_initialized(self) -> None: """Test get_session raises RuntimeError when not initialized.""" database.session_local = None @@ -91,7 +92,7 @@ def test_get_session_when_not_initialized(self): class TestCreateTables: """Test cases for create_tables function.""" - def test_create_tables_success(self, mocker: MockerFixture): + def test_create_tables_success(self, mocker: MockerFixture) -> None: """Test create_tables calls Base.metadata.create_all with engine.""" mock_base = mocker.patch("app.database.Base") mock_get_engine = mocker.patch("app.database.get_engine") @@ -103,7 +104,9 @@ def test_create_tables_success(self, mocker: MockerFixture): mock_get_engine.assert_called_once() mock_base.metadata.create_all.assert_called_once_with(mock_engine) - def test_create_tables_when_engine_not_initialized(self, mocker: MockerFixture): + def test_create_tables_when_engine_not_initialized( + self, mocker: MockerFixture + ) -> None: """Test create_tables raises error when engine not initialized.""" mock_get_engine = mocker.patch("app.database.get_engine") mock_get_engine.side_effect = RuntimeError("Database engine not initialized") @@ -115,7 +118,7 @@ def test_create_tables_when_engine_not_initialized(self, mocker: MockerFixture): class TestCreateSqliteEngine: """Test cases for _create_sqlite_engine function.""" - def test_create_sqlite_engine_success(self): + def test_create_sqlite_engine_success(self) -> None: """Test _create_sqlite_engine creates engine successfully.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test.db" @@ -126,7 +129,7 @@ def test_create_sqlite_engine_success(self): assert isinstance(engine, Engine) assert f"sqlite:///{db_path}" in str(engine.url) - def test_create_sqlite_engine_directory_not_exists(self): + def test_create_sqlite_engine_directory_not_exists(self) -> None: """Test _create_sqlite_engine raises error when directory doesn't exist.""" config = SQLiteDatabaseConfiguration(db_path="/nonexistent/path/test.db") @@ -135,7 +138,7 @@ def test_create_sqlite_engine_directory_not_exists(self): ): database._create_sqlite_engine(config) - def test_create_sqlite_engine_creation_failure(self, mocker: MockerFixture): + def test_create_sqlite_engine_creation_failure(self, mocker: MockerFixture) -> None: """Test _create_sqlite_engine handles engine creation failure.""" mock_create_engine = mocker.patch("app.database.create_engine") with tempfile.TemporaryDirectory() as temp_dir: @@ -151,8 +154,10 @@ class TestCreatePostgresEngine: """Test cases for _create_postgres_engine function.""" def test_create_postgres_engine_success_default_schema( - self, mocker: MockerFixture, base_postgres_config - ): + self, + mocker: MockerFixture, + base_postgres_config: PostgreSQLDatabaseConfiguration, + ) -> None: """Test _create_postgres_engine creates engine successfully with default schema.""" mock_create_engine = mocker.patch("app.database.create_engine") mock_engine = mocker.MagicMock(spec=Engine) @@ -171,8 +176,10 @@ def test_create_postgres_engine_success_default_schema( assert expected_url == call_args[0][0] def test_create_postgres_engine_success_custom_schema( - self, mocker: MockerFixture, base_postgres_config - ): + self, + mocker: MockerFixture, + base_postgres_config: PostgreSQLDatabaseConfiguration, + ) -> None: """Test _create_postgres_engine creates engine successfully with custom schema.""" mock_create_engine = mocker.patch("app.database.create_engine") mock_engine = mocker.MagicMock(spec=Engine) @@ -193,8 +200,10 @@ def test_create_postgres_engine_success_custom_schema( mock_connection.commit.assert_called_once() def test_create_postgres_engine_with_ca_cert( - self, mocker: MockerFixture, base_postgres_config - ): + self, + mocker: MockerFixture, + base_postgres_config: PostgreSQLDatabaseConfiguration, + ) -> None: """Test _create_postgres_engine with CA certificate path.""" mock_create_engine = mocker.patch("app.database.create_engine") mock_engine = mocker.MagicMock(spec=Engine) @@ -212,8 +221,10 @@ def test_create_postgres_engine_with_ca_cert( assert call_args[1]["connect_args"]["sslrootcert"] == cert_file.name def test_create_postgres_engine_creation_failure( - self, mocker: MockerFixture, base_postgres_config - ): + self, + mocker: MockerFixture, + base_postgres_config: PostgreSQLDatabaseConfiguration, + ) -> None: """Test _create_postgres_engine handles engine creation failure.""" mock_create_engine = mocker.patch("app.database.create_engine") mock_create_engine.side_effect = Exception("Connection failed") @@ -222,8 +233,10 @@ def test_create_postgres_engine_creation_failure( database._create_postgres_engine(base_postgres_config) def test_create_postgres_engine_schema_creation_failure( - self, mocker: MockerFixture, base_postgres_config - ): + self, + mocker: MockerFixture, + base_postgres_config: PostgreSQLDatabaseConfiguration, + ) -> None: """Test _create_postgres_engine handles schema creation failure.""" mock_create_engine = mocker.patch("app.database.create_engine") mock_engine = mocker.MagicMock(spec=Engine) @@ -246,10 +259,10 @@ def _setup_common_mocks( self, *, mocker: MockerFixture, - mock_sessionmaker, - mock_logger, - enable_debug=False, - ): + mock_sessionmaker: MockType, + mock_logger: MockType, + enable_debug: bool = False, + ) -> tuple[MockType, MockType]: """Setup common mocks for initialize_database tests.""" mock_engine = mocker.MagicMock(spec=Engine) mock_session_local = mocker.MagicMock() @@ -258,8 +271,12 @@ def _setup_common_mocks( return mock_engine, mock_session_local def _verify_common_assertions( - self, *, mock_sessionmaker, mock_engine, mock_session_local - ): + self, + *, + mock_sessionmaker: MockType, + mock_engine: MockType, + mock_session_local: MockType, + ) -> None: """Verify common assertions for initialize_database tests.""" mock_sessionmaker.assert_called_once_with( autocommit=False, autoflush=False, bind=mock_engine @@ -270,7 +287,7 @@ def _verify_common_assertions( def test_initialize_database_sqlite( self, mocker: MockerFixture, - ): + ) -> None: """Test initialize_database with SQLite configuration.""" # Setup mocks mock_configuration = mocker.patch("app.database.configuration") @@ -304,8 +321,8 @@ def test_initialize_database_sqlite( def test_initialize_database_postgres( self, mocker: MockerFixture, - base_postgres_config, - ): + base_postgres_config: PostgreSQLDatabaseConfiguration, + ) -> None: """Test initialize_database with PostgreSQL configuration.""" # Setup mocks mock_configuration = mocker.patch("app.database.configuration") diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index e466fca44..7a608f6f9 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -1,6 +1,6 @@ """Unit tests for routers.py.""" -from typing import Any, Optional +from typing import Any, Optional, Sequence, Callable from fastapi import FastAPI @@ -37,14 +37,14 @@ def include_router( # pylint: disable=too-many-arguments router: Any, *, prefix: str = "", - tags=None, - dependencies=None, - responses=None, - deprecated=None, - include_in_schema=None, - default_response_class=None, - callbacks=None, - generate_unique_id_function=None, + tags: Optional[list] = None, + dependencies: Optional[Sequence] = None, + responses: Optional[dict] = None, + deprecated: Optional[bool] = None, + include_in_schema: Optional[bool] = None, + default_response_class: Optional[Any] = None, + callbacks: Optional[list] = None, + generate_unique_id_function: Optional[Callable] = None, ) -> None: """Register new router.""" self.routers.append((router, prefix)) diff --git a/tests/unit/authentication/test_jwk_token.py b/tests/unit/authentication/test_jwk_token.py index 0dc16f1ad..fb06ba51a 100644 --- a/tests/unit/authentication/test_jwk_token.py +++ b/tests/unit/authentication/test_jwk_token.py @@ -4,6 +4,9 @@ import time +from typing import Any, Generator +from pytest_mock import MockerFixture + import pytest from fastapi import HTTPException, Request from pydantic import AnyHttpUrl @@ -18,13 +21,13 @@ @pytest.fixture -def token_header(single_key_set): +def token_header(single_key_set: list[dict[str, Any]]) -> dict[str, Any]: """A sample token header.""" return {"alg": "RS256", "typ": "JWT", "kid": single_key_set[0]["kid"]} @pytest.fixture -def token_payload(): +def token_payload() -> dict[str, Any]: """A sample token payload with the default user_id and username claims.""" return { "user_id": TEST_USER_ID, @@ -34,7 +37,7 @@ def token_payload(): } -def make_key(): +def make_key() -> dict[str, Any]: """Generate a key pair for testing purposes.""" key = JsonWebKey.generate_key("RSA", 2048, is_private=True) return { @@ -45,19 +48,21 @@ def make_key(): @pytest.fixture -def single_key_set(): +def single_key_set() -> list[dict[str, Any]]: """Default single-key set for signing tokens.""" return [make_key()] @pytest.fixture -def another_single_key_set(): +def another_single_key_set() -> list[dict[str, Any]]: """Same as single_key_set, but generates a different key pair by being its own fixture.""" return [make_key()] @pytest.fixture -def valid_token(single_key_set, token_header, token_payload): +def valid_token( + single_key_set: list[dict[str, Any]], token_header: str, token_payload: str +) -> str: """A token that is valid and signed with the signing keys.""" jwt_instance = JsonWebToken(algorithms=["RS256"]) return jwt_instance.encode( @@ -66,14 +71,16 @@ def valid_token(single_key_set, token_header, token_payload): @pytest.fixture(autouse=True) -def clear_jwk_cache(): +def clear_jwk_cache() -> Generator: """Clear the global JWK cache before each test.""" _jwk_cache.clear() yield _jwk_cache.clear() -def make_signing_server(mocker, key_set, algorithms): +def make_signing_server( + mocker: MockerFixture, key_set: list[dict[str, Any]], algorithms: list[str] +) -> Any: """A fake server to serve our signing keys as JWKs.""" mock_session_class = mocker.patch("aiohttp.ClientSession") mock_response = mocker.AsyncMock() @@ -111,13 +118,15 @@ def make_signing_server(mocker, key_set, algorithms): @pytest.fixture -def mocked_signing_keys_server(mocker, single_key_set): +def mocked_signing_keys_server( + mocker: MockerFixture, single_key_set: list[dict[str, Any]] +) -> None: """Single-key signing server.""" return make_signing_server(mocker, single_key_set, ["RS256"]) @pytest.fixture -def default_jwk_configuration(): +def default_jwk_configuration() -> JwkConfiguration: """Default JwkConfiguration for testing.""" return JwkConfiguration( url=AnyHttpUrl("https://this#isgonnabemocked.com/jwks.json"), @@ -128,7 +137,7 @@ def default_jwk_configuration(): ) -def dummy_request(token): +def dummy_request(token: str) -> Request: """Generate a dummy request with a given token.""" return Request( scope={ @@ -140,7 +149,7 @@ def dummy_request(token): @pytest.fixture -def no_token_request(): +def no_token_request() -> Request: """Dummy request with no token.""" return Request( scope={ @@ -152,7 +161,7 @@ def no_token_request(): @pytest.fixture -def not_bearer_token_request(): +def not_bearer_token_request() -> Request: """Dummy request with no token.""" return Request( scope={ @@ -163,7 +172,7 @@ def not_bearer_token_request(): ) -def set_auth_header(request: Request, token: str): +def set_auth_header(request: Request, token: str) -> None: """Helper function to set the Authorization header in a request.""" new_headers = [ (k, v) for k, v in request.scope["headers"] if k.lower() != b"authorization" @@ -172,7 +181,7 @@ def set_auth_header(request: Request, token: str): request.scope["headers"] = new_headers -def ensure_test_user_id_and_name(auth_tuple, expected_token): +def ensure_test_user_id_and_name(auth_tuple: tuple, expected_token: str) -> None: """Utility to ensure that the values in the auth tuple match the test values.""" user_id, username, skip_userid_check, token = auth_tuple assert user_id == TEST_USER_ID @@ -182,10 +191,10 @@ def ensure_test_user_id_and_name(auth_tuple, expected_token): async def test_valid( - default_jwk_configuration, - mocked_signing_keys_server, - valid_token, -): + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + valid_token: str, +) -> None: """Test with a valid token.""" _ = mocked_signing_keys_server @@ -197,7 +206,9 @@ async def test_valid( @pytest.fixture -def expired_token(single_key_set, token_header, token_payload): +def expired_token( + single_key_set: list[dict[str, Any]], token_header: dict, token_payload: dict +) -> str: """An well-signed yet expired token.""" jwt_instance = JsonWebToken(algorithms=["RS256"]) token_payload["exp"] = int(time.time()) - 3600 # Set expiration in the past @@ -207,10 +218,10 @@ def expired_token(single_key_set, token_header, token_payload): async def test_expired( - default_jwk_configuration, - mocked_signing_keys_server, - expired_token, -): + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + expired_token: str, +) -> None: """Test with an expired token.""" _ = mocked_signing_keys_server @@ -225,7 +236,11 @@ async def test_expired( @pytest.fixture -def invalid_token(another_single_key_set, token_header, token_payload): +def invalid_token( + another_single_key_set: list[dict[str, Any]], + token_header: dict, + token_payload: dict, +) -> str: """A token that is signed with different keys than the signing keys.""" jwt_instance = JsonWebToken(algorithms=["RS256"]) return jwt_instance.encode( @@ -234,10 +249,10 @@ def invalid_token(another_single_key_set, token_header, token_payload): async def test_invalid( - default_jwk_configuration, - mocked_signing_keys_server, - invalid_token, -): + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + invalid_token: str, +) -> None: """Test with an invalid token.""" _ = mocked_signing_keys_server @@ -251,10 +266,10 @@ async def test_invalid( async def test_no_auth_header( - default_jwk_configuration, - mocked_signing_keys_server, - no_token_request, -): + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + no_token_request: str, +) -> None: """Test with no Authorization header.""" _ = mocked_signing_keys_server @@ -271,10 +286,10 @@ async def test_no_auth_header( async def test_no_bearer( - default_jwk_configuration, - mocked_signing_keys_server, - not_bearer_token_request, -): + default_jwk_configuration: JwtConfiguration, + mocked_signing_keys_server: Any, + not_bearer_token_request: str, +) -> None: """Test with Authorization header that does not start with Bearer.""" _ = mocked_signing_keys_server @@ -288,7 +303,11 @@ async def test_no_bearer( @pytest.fixture -def no_user_id_token(single_key_set, token_payload, token_header): +def no_user_id_token( + single_key_set: list[dict[str, Any]], + token_payload: dict[str, Any], + token_header: dict[str, Any], +) -> str: """Token without a user_id claim.""" jwt_instance = JsonWebToken(algorithms=["RS256"]) # Modify the token payload to include different claims @@ -300,10 +319,10 @@ def no_user_id_token(single_key_set, token_payload, token_header): async def test_no_user_id( - default_jwk_configuration, - mocked_signing_keys_server, - no_user_id_token, -): + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + no_user_id_token: str, +) -> None: """Test with a token that has no user_id claim.""" _ = mocked_signing_keys_server @@ -319,7 +338,11 @@ async def test_no_user_id( @pytest.fixture -def no_username_token(single_key_set, token_payload, token_header): +def no_username_token( + single_key_set: list[dict[str, Any]], + token_payload: dict[str, Any], + token_header: dict[str, Any], +) -> str: """Token without a username claim.""" jwt_instance = JsonWebToken(algorithms=["RS256"]) # Modify the token payload to include different claims @@ -331,10 +354,10 @@ def no_username_token(single_key_set, token_payload, token_header): async def test_no_username( - default_jwk_configuration, - mocked_signing_keys_server, - no_username_token, -): + default_jwk_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + no_username_token: str, +) -> None: """Test with a token that has no username claim.""" _ = mocked_signing_keys_server @@ -350,7 +373,9 @@ async def test_no_username( @pytest.fixture -def custom_claims_token(single_key_set, token_payload, token_header): +def custom_claims_token( + single_key_set: list[dict[str, Any]], token_payload: dict, token_header: dict +) -> str: """Token with custom claims.""" jwt_instance = JsonWebToken(algorithms=["RS256"]) @@ -367,7 +392,9 @@ def custom_claims_token(single_key_set, token_payload, token_header): @pytest.fixture -def custom_claims_configuration(default_jwk_configuration): +def custom_claims_configuration( + default_jwk_configuration: JwkConfiguration, +) -> JwkConfiguration: """Configuration for custom claims.""" # Create a copy of the default configuration custom_config = default_jwk_configuration.model_copy() @@ -380,10 +407,10 @@ def custom_claims_configuration(default_jwk_configuration): async def test_custom_claims( - custom_claims_configuration, - mocked_signing_keys_server, - custom_claims_token, -): + custom_claims_configuration: JwkConfiguration, + mocked_signing_keys_server: Any, + custom_claims_token: str, +) -> None: """Test with a token that has custom claims.""" _ = mocked_signing_keys_server @@ -396,49 +423,49 @@ async def test_custom_claims( @pytest.fixture -def token_header_256_1(multi_key_set): +def token_header_256_1(multi_key_set: list[dict[str, Any]]) -> dict[str, Any]: """A sample token header for RS256 using multi_key_set.""" return {"alg": "RS256", "typ": "JWT", "kid": multi_key_set[0]["kid"]} @pytest.fixture -def token_header_256_2(multi_key_set): +def token_header_256_2(multi_key_set: list[dict[str, Any]]) -> dict[str, Any]: """A sample token header for RS256 using multi_key_set.""" return {"alg": "RS256", "typ": "JWT", "kid": multi_key_set[1]["kid"]} @pytest.fixture -def token_header_384(multi_key_set): +def token_header_384(multi_key_set: list[dict[str, Any]]) -> dict[str, Any]: """A sample token header.""" return {"alg": "RS384", "typ": "JWT", "kid": multi_key_set[2]["kid"]} @pytest.fixture -def token_header_256_no_kid(): +def token_header_256_no_kid() -> dict[str, Any]: """RS256 no kid.""" return {"alg": "RS256", "typ": "JWT"} @pytest.fixture -def token_header_384_no_kid(): +def token_header_384_no_kid() -> dict[str, Any]: """RS384 no kid.""" return {"alg": "RS384", "typ": "JWT"} @pytest.fixture -def multi_key_set(): +def multi_key_set() -> list[dict[str, Any]]: """Default multi-key set for signing tokens.""" return [make_key(), make_key(), make_key()] @pytest.fixture def valid_tokens( - multi_key_set, - token_header_256_1, - token_header_256_2, - token_payload, - token_header_384, -): + multi_key_set: list[dict[str, Any]], + token_header_256_1: dict[str, Any], + token_header_256_2: dict[str, Any], + token_payload: dict[str, Any], + token_header_384: dict[str, Any], +) -> tuple[str, str, str]: """Generate valid tokens for each key in the multi-key set.""" key_for_256_1 = multi_key_set[0] key_for_256_2 = multi_key_set[1] @@ -464,8 +491,11 @@ def valid_tokens( @pytest.fixture def valid_tokens_no_kid( - multi_key_set, token_header_256_no_kid, token_payload, token_header_384_no_kid -): + multi_key_set: list[dict[str, Any]], + token_header_256_no_kid: dict[str, Any], + token_payload: dict[str, Any], + token_header_384_no_kid: dict[str, Any], +) -> tuple[str, str, str]: """Generate valid tokens for each key in the multi-key set without a kid.""" key_for_256_1 = multi_key_set[0] key_for_256_2 = multi_key_set[1] @@ -490,16 +520,18 @@ def valid_tokens_no_kid( @pytest.fixture -def multi_key_signing_server(mocker, multi_key_set): +def multi_key_signing_server( + mocker: MockerFixture, multi_key_set: list[dict[str, Any]] +) -> Any: """Multi-key signing server.""" return make_signing_server(mocker, multi_key_set, ["RS256", "RS256", "RS384"]) async def test_multi_key_valid( - default_jwk_configuration, - multi_key_signing_server, - valid_tokens, -): + default_jwk_configuration: JwkConfiguration, + multi_key_signing_server: Any, + valid_tokens: tuple[str, str, str], +) -> None: """Test with valid tokens from a multi-key set.""" _ = multi_key_signing_server @@ -517,10 +549,10 @@ async def test_multi_key_valid( async def test_multi_key_no_kid( - default_jwk_configuration, - multi_key_signing_server, - valid_tokens_no_kid, -): + default_jwk_configuration: JwkConfiguration, + multi_key_signing_server: Any, + valid_tokens_no_kid: tuple[str, str, str], +) -> None: """Test with valid tokens from a multi-key set without a kid.""" _ = multi_key_signing_server diff --git a/tests/unit/authentication/test_k8s.py b/tests/unit/authentication/test_k8s.py index fd0efcb7d..47cae381d 100644 --- a/tests/unit/authentication/test_k8s.py +++ b/tests/unit/authentication/test_k8s.py @@ -4,6 +4,7 @@ import os +from typing import Optional import pytest from pytest_mock import MockerFixture @@ -27,10 +28,18 @@ class MockK8sResponseStatus: and user information if authenticated. """ - def __init__(self, authenticated, allowed, username=None, uid=None, groups=None): + def __init__( + self, + authenticated: Optional[bool], + allowed: Optional[bool], + username: Optional[str] = None, + uid: Optional[str] = None, + groups: Optional[list[str]] = None, + ) -> None: """Init function.""" self.authenticated = authenticated self.allowed = allowed + self.user: Optional[MockK8sUser] if authenticated: self.user = MockK8sUser(username, uid, groups) else: @@ -43,7 +52,12 @@ class MockK8sUser: Represents a user in the mocked Kubernetes environment. """ - def __init__(self, username=None, uid=None, groups=None): + def __init__( + self, + username: Optional[str] = None, + uid: Optional[str] = None, + groups: Optional[list[str]] = None, + ) -> None: """Init function.""" self.username = username self.uid = uid @@ -57,7 +71,12 @@ class MockK8sResponse: """ def __init__( - self, authenticated=None, allowed=None, username=None, uid=None, groups=None + self, + authenticated: Optional[bool] = None, + allowed: Optional[bool] = None, + username: Optional[str] = None, + uid: Optional[str] = None, + groups: Optional[list[str]] = None, ): """Init function.""" self.status = MockK8sResponseStatus( @@ -65,14 +84,14 @@ def __init__( ) -def test_singleton_pattern(): +def test_singleton_pattern() -> None: """Test if K8sClientSingleton is really a singleton.""" k1 = K8sClientSingleton() k2 = K8sClientSingleton() assert k1 is k2 -async def test_auth_dependency_valid_token(mocker: MockerFixture): +async def test_auth_dependency_valid_token(mocker: MockerFixture) -> None: """Tests the auth dependency with a mocked valid-token.""" dependency = K8SAuthDependency() @@ -105,7 +124,7 @@ async def test_auth_dependency_valid_token(mocker: MockerFixture): assert token == "valid-token" -async def test_auth_dependency_invalid_token(mocker: MockerFixture): +async def test_auth_dependency_invalid_token(mocker: MockerFixture) -> None: """Test the auth dependency with a mocked invalid-token.""" dependency = K8SAuthDependency() @@ -137,7 +156,7 @@ async def test_auth_dependency_invalid_token(mocker: MockerFixture): assert exc_info.value.status_code == 403 -async def test_cluster_id_is_used_for_kube_admin(mocker: MockerFixture): +async def test_cluster_id_is_used_for_kube_admin(mocker: MockerFixture) -> None: """Test the cluster id is used as user_id when user is kube:admin.""" dependency = K8SAuthDependency() mock_authz_api = mocker.patch("authentication.k8s.K8sClientSingleton.get_authz_api") @@ -177,7 +196,7 @@ async def test_cluster_id_is_used_for_kube_admin(mocker: MockerFixture): assert token == "valid-token" -def test_auth_dependency_config(mocker: MockerFixture): +def test_auth_dependency_config(mocker: MockerFixture) -> None: """Test the auth dependency can load kubeconfig file.""" mocker.patch.dict(os.environ, {"MY_ENV_VAR": "mocked"}) @@ -191,7 +210,7 @@ def test_auth_dependency_config(mocker: MockerFixture): ), "authz_client is not an instance of AuthorizationV1Api" -def test_get_cluster_id(mocker: MockerFixture): +def test_get_cluster_id(mocker: MockerFixture) -> None: """Test get_cluster_id function.""" mock_get_custom_objects_api = mocker.patch( "authentication.k8s.K8sClientSingleton.get_custom_objects_api" @@ -212,7 +231,7 @@ def test_get_cluster_id(mocker: MockerFixture): K8sClientSingleton._get_cluster_id() # typeerror - cluster_id = None + cluster_id = None # type: ignore mocked_call = mocker.MagicMock() mocked_call.get_cluster_custom_object.return_value = cluster_id mock_get_custom_objects_api.return_value = mocked_call @@ -230,7 +249,7 @@ def test_get_cluster_id(mocker: MockerFixture): K8sClientSingleton._get_cluster_id() -def test_get_cluster_id_in_cluster(mocker: MockerFixture): +def test_get_cluster_id_in_cluster(mocker: MockerFixture) -> None: """Test get_cluster_id function when running inside of cluster.""" mocker.patch("authentication.k8s.RUNNING_IN_CLUSTER", True) mocker.patch("authentication.k8s.K8sClientSingleton.__new__") @@ -242,7 +261,7 @@ def test_get_cluster_id_in_cluster(mocker: MockerFixture): assert K8sClientSingleton.get_cluster_id() == "some-cluster-id" -def test_get_cluster_id_outside_of_cluster(mocker: MockerFixture): +def test_get_cluster_id_outside_of_cluster(mocker: MockerFixture) -> None: """Test get_cluster_id function when running outside of cluster.""" mocker.patch("authentication.k8s.RUNNING_IN_CLUSTER", False) mocker.patch("authentication.k8s.K8sClientSingleton.__new__") diff --git a/tests/unit/authentication/test_noop.py b/tests/unit/authentication/test_noop.py index 1e1460d05..e651ff673 100644 --- a/tests/unit/authentication/test_noop.py +++ b/tests/unit/authentication/test_noop.py @@ -5,7 +5,7 @@ from constants import DEFAULT_USER_NAME, DEFAULT_USER_UID, NO_USER_TOKEN -async def test_noop_auth_dependency(): +async def test_noop_auth_dependency() -> None: """Test the NoopAuthDependency class with default user ID.""" dependency = NoopAuthDependency() @@ -22,7 +22,7 @@ async def test_noop_auth_dependency(): assert user_token == NO_USER_TOKEN -async def test_noop_auth_dependency_custom_user_id(): +async def test_noop_auth_dependency_custom_user_id() -> None: """Test the NoopAuthDependency class.""" dependency = NoopAuthDependency() diff --git a/tests/unit/authentication/test_noop_with_token.py b/tests/unit/authentication/test_noop_with_token.py index 61048a41c..b18ae7a16 100644 --- a/tests/unit/authentication/test_noop_with_token.py +++ b/tests/unit/authentication/test_noop_with_token.py @@ -7,7 +7,7 @@ from constants import DEFAULT_USER_NAME, DEFAULT_USER_UID -async def test_noop_with_token_auth_dependency(): +async def test_noop_with_token_auth_dependency() -> None: """Test the NoopWithTokenAuthDependency class with default user ID.""" dependency = NoopWithTokenAuthDependency() @@ -31,7 +31,7 @@ async def test_noop_with_token_auth_dependency(): assert user_token == "spongebob-token" -async def test_noop_with_token_auth_dependency_custom_user_id(): +async def test_noop_with_token_auth_dependency_custom_user_id() -> None: """Test the NoopWithTokenAuthDependency class with custom user ID.""" dependency = NoopWithTokenAuthDependency() @@ -56,7 +56,7 @@ async def test_noop_with_token_auth_dependency_custom_user_id(): assert user_token == "spongebob-token" -async def test_noop_with_token_auth_dependency_no_token(): +async def test_noop_with_token_auth_dependency_no_token() -> None: """ Test if checks for Authorization header is in place. @@ -85,7 +85,7 @@ async def test_noop_with_token_auth_dependency_no_token(): assert exc_info.value.detail == "No Authorization header found" -async def test_noop_with_token_auth_dependency_no_bearer(): +async def test_noop_with_token_auth_dependency_no_bearer() -> None: """Test the NoopWithTokenAuthDependency class with no token.""" dependency = NoopWithTokenAuthDependency() diff --git a/tests/unit/authentication/test_utils.py b/tests/unit/authentication/test_utils.py index a0d64ac97..ee1d34dc8 100644 --- a/tests/unit/authentication/test_utils.py +++ b/tests/unit/authentication/test_utils.py @@ -6,14 +6,14 @@ from authentication.utils import extract_user_token -def test_extract_user_token(): +def test_extract_user_token() -> None: """Test extracting user token from headers.""" headers = Headers({"Authorization": "Bearer abcdef123"}) token = extract_user_token(headers) assert token == "abcdef123" -def test_extract_user_token_no_header(): +def test_extract_user_token_no_header() -> None: """Test extracting user token when no Authorization header is present.""" headers = Headers({}) try: @@ -23,7 +23,7 @@ def test_extract_user_token_no_header(): assert exc.detail == "No Authorization header found" -def test_extract_user_token_invalid_format(): +def test_extract_user_token_invalid_format() -> None: """Test extracting user token with invalid Authorization header format.""" headers = Headers({"Authorization": "InvalidFormat"}) try: diff --git a/tests/unit/authorization/test_resolvers.py b/tests/unit/authorization/test_resolvers.py index f05b22e05..2df1cc27d 100644 --- a/tests/unit/authorization/test_resolvers.py +++ b/tests/unit/authorization/test_resolvers.py @@ -5,6 +5,7 @@ import re from contextlib import nullcontext as does_not_raise +from typing import Any import pytest from authentication.interface import AuthTuple @@ -33,7 +34,7 @@ class TestJwtRolesResolver: """Test cases for JwtRolesResolver.""" @pytest.fixture - async def employee_role_rule(self) -> None: + async def employee_role_rule(self) -> JwtRoleRule: """Role rule for RedHat employees.""" return JwtRoleRule( jsonpath="$.realm_access.roles[*]", @@ -43,12 +44,14 @@ async def employee_role_rule(self) -> None: ) @pytest.fixture - async def employee_resolver(self, employee_role_rule): + async def employee_resolver( + self, employee_role_rule: JwtRoleRule + ) -> JwtRolesResolver: """JwtRolesResolver with a rule for RedHat employees.""" return JwtRolesResolver([employee_role_rule]) @pytest.fixture - async def employee_claims(self): + async def employee_claims(self) -> dict[str, Any]: """JWT claims for a RedHat employee.""" return { "foo": "bar", @@ -66,7 +69,7 @@ async def employee_claims(self): } @pytest.fixture - async def non_employee_claims(self): + async def non_employee_claims(self) -> dict[str, Any]: """JWT claims for a non-RedHat employee.""" return { "exp": 1754489339, @@ -76,14 +79,16 @@ async def non_employee_claims(self): } async def test_resolve_roles_redhat_employee( - self, employee_resolver, employee_claims - ): + self, employee_resolver: JwtRolesResolver, employee_claims: dict[str, Any] + ) -> None: """Test role extraction for RedHat employee JWT.""" assert "employee" in await employee_resolver.resolve_roles( claims_to_auth_tuple(employee_claims) ) - async def test_resolve_roles_no_match(self, employee_resolver, non_employee_claims): + async def test_resolve_roles_no_match( + self, employee_resolver: JwtRolesResolver, non_employee_claims: dict[str, Any] + ) -> None: """Test no roles extracted for non-RedHat employee JWT.""" assert ( len( @@ -94,7 +99,9 @@ async def test_resolve_roles_no_match(self, employee_resolver, non_employee_clai == 0 ) - async def test_negate_operator(self, employee_role_rule, non_employee_claims): + async def test_negate_operator( + self, employee_role_rule: JwtRoleRule, non_employee_claims: dict[str, Any] + ) -> None: """Test role extraction with negated operator.""" negated_rule = employee_role_rule negated_rule.negate = True @@ -106,7 +113,7 @@ async def test_negate_operator(self, employee_role_rule, non_employee_claims): ) @pytest.fixture - async def email_rule_resolver(self): + async def email_rule_resolver(self) -> JwtRolesResolver: """JwtRolesResolver with a rule for email domain.""" return JwtRolesResolver( [ @@ -120,7 +127,7 @@ async def email_rule_resolver(self): ) @pytest.fixture - async def equals_rule_resolver(self): + async def equals_rule_resolver(self) -> JwtRolesResolver: """JwtRolesResolver with a rule for exact email match.""" return JwtRolesResolver( [ @@ -134,15 +141,15 @@ async def equals_rule_resolver(self): ) async def test_resolve_roles_equals_operator( - self, equals_rule_resolver, employee_claims - ): + self, equals_rule_resolver: JwtRolesResolver, employee_claims: dict[str, Any] + ) -> None: """Test role extraction using EQUALS operator.""" assert "foobar" in await equals_rule_resolver.resolve_roles( claims_to_auth_tuple(employee_claims) ) @pytest.fixture - async def in_rule_resolver(self): + async def in_rule_resolver(self) -> JwtRolesResolver: """JwtRolesResolver with a rule for IN operator.""" return JwtRolesResolver( [ @@ -155,23 +162,25 @@ async def in_rule_resolver(self): ] ) - async def test_resolve_roles_in_operator(self, in_rule_resolver, employee_claims): + async def test_resolve_roles_in_operator( + self, in_rule_resolver: JwtRolesResolver, employee_claims: dict[str, Any] + ) -> None: """Test role extraction using IN operator.""" assert "in_role" in await in_rule_resolver.resolve_roles( claims_to_auth_tuple(employee_claims) ) async def test_resolve_roles_match_operator_email_domain( - self, email_rule_resolver, employee_claims - ): + self, email_rule_resolver: JwtRolesResolver, employee_claims: dict[str, Any] + ) -> None: """Test role extraction using MATCH operator with email domain regex.""" assert "redhat_employee" in await email_rule_resolver.resolve_roles( claims_to_auth_tuple(employee_claims) ) async def test_resolve_roles_match_operator_no_match( - self, email_rule_resolver, non_employee_claims - ): + self, email_rule_resolver: JwtRolesResolver, non_employee_claims: dict[str, Any] + ) -> None: """Test role extraction using MATCH operator with no match.""" assert ( len( @@ -182,7 +191,7 @@ async def test_resolve_roles_match_operator_no_match( == 0 ) - async def test_resolve_roles_match_operator_invalid_regex(self): + async def test_resolve_roles_match_operator_invalid_regex(self) -> None: """Test that invalid regex patterns are rejected at rule creation time.""" with pytest.raises( ValueError, match="Invalid regex pattern for MATCH operator" @@ -194,7 +203,7 @@ async def test_resolve_roles_match_operator_invalid_regex(self): roles=["test_role"], ) - async def test_resolve_roles_match_operator_non_string_pattern(self): + async def test_resolve_roles_match_operator_non_string_pattern(self) -> None: """Test that non-string regex patterns are rejected at rule creation time.""" with pytest.raises( ValueError, match="MATCH operator requires a string pattern" @@ -206,7 +215,7 @@ async def test_resolve_roles_match_operator_non_string_pattern(self): roles=["test_role"], ) - async def test_resolve_roles_match_operator_non_string_value(self): + async def test_resolve_roles_match_operator_non_string_value(self) -> None: """Test role extraction using MATCH operator with non-string match value.""" role_rules = [ JwtRoleRule( @@ -228,7 +237,7 @@ async def test_resolve_roles_match_operator_non_string_value(self): roles = await jwt_resolver.resolve_roles(auth) assert len(roles) == 0 # Non-string values don't match regex - async def test_compiled_regex_property(self): + async def test_compiled_regex_property(self) -> None: """Test that compiled regex pattern is properly created for MATCH operator.""" # Test MATCH operator creates compiled regex match_rule = JwtRoleRule( @@ -250,7 +259,9 @@ async def test_compiled_regex_property(self): ) assert equals_rule.compiled_regex is None - async def test_resolve_roles_with_no_user_token(self, employee_resolver): + async def test_resolve_roles_with_no_user_token( + self, employee_resolver: JwtRolesResolver + ) -> None: """Test NO_USER_TOKEN returns empty claims.""" guest_tuple = ( "user", @@ -269,19 +280,19 @@ class TestGenericAccessResolver: """Test cases for GenericAccessResolver.""" @pytest.fixture - def admin_access_rules(self): + def admin_access_rules(self) -> list[AccessRule]: """Access rules with admin role for testing.""" return [AccessRule(role="superuser", actions=[Action.ADMIN])] @pytest.fixture - def multi_role_access_rules(self): + def multi_role_access_rules(self) -> list[AccessRule]: """Access rules with multiple roles for testing.""" return [ AccessRule(role="user", actions=[Action.QUERY, Action.GET_MODELS]), AccessRule(role="moderator", actions=[Action.FEEDBACK]), ] - async def test_check_access_with_valid_role(self): + async def test_check_access_with_valid_role(self) -> None: """Test access check with valid role.""" access_rules = [ AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) @@ -296,7 +307,7 @@ async def test_check_access_with_valid_role(self): has_access = resolver.check_access(Action.FEEDBACK, frozenset(["employee"])) assert has_access is False - async def test_check_access_with_invalid_role(self): + async def test_check_access_with_invalid_role(self) -> None: """Test access check with invalid role.""" access_rules = [ AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) @@ -306,7 +317,7 @@ async def test_check_access_with_invalid_role(self): has_access = resolver.check_access(Action.QUERY, {"visitor"}) assert has_access is False - async def test_check_access_with_no_roles(self): + async def test_check_access_with_no_roles(self) -> None: """Test access check with no roles.""" access_rules = [ AccessRule(role="employee", actions=[Action.QUERY, Action.GET_MODELS]) @@ -316,19 +327,23 @@ async def test_check_access_with_no_roles(self): has_access = resolver.check_access(Action.QUERY, set()) assert has_access is False - def test_admin_action_with_other_actions_raises_error(self): + def test_admin_action_with_other_actions_raises_error(self) -> None: """Test admin action with others raises ValueError.""" with pytest.raises(ValueError): GenericAccessResolver( [AccessRule(role="superuser", actions=[Action.ADMIN, Action.QUERY])] ) - def test_admin_role_allows_all_actions(self, admin_access_rules): + def test_admin_role_allows_all_actions( + self, admin_access_rules: list[AccessRule] + ) -> None: """Test admin action allows all actions via recursive check.""" resolver = GenericAccessResolver(admin_access_rules) assert resolver.check_access(Action.QUERY, {"superuser"}) is True - def test_admin_get_actions_excludes_admin_action(self, admin_access_rules): + def test_admin_get_actions_excludes_admin_action( + self, admin_access_rules: list[AccessRule] + ) -> None: """Test get actions on a role with admin returns everything except ADMIN.""" resolver = GenericAccessResolver(admin_access_rules) actions = resolver.get_actions({"superuser"}) @@ -336,7 +351,9 @@ def test_admin_get_actions_excludes_admin_action(self, admin_access_rules): assert Action.QUERY in actions assert len(actions) == len(set(Action)) - 1 - def test_get_actions_for_regular_users(self, multi_role_access_rules): + def test_get_actions_for_regular_users( + self, multi_role_access_rules: list[AccessRule] + ) -> None: """Test non-admin user gets only their specific actions.""" resolver = GenericAccessResolver(multi_role_access_rules) actions = resolver.get_actions({"user", "moderator"}) diff --git a/tests/unit/test_configuration_unknown_fields.py b/tests/unit/test_configuration_unknown_fields.py index b97c3f0d3..1fadca6fe 100644 --- a/tests/unit/test_configuration_unknown_fields.py +++ b/tests/unit/test_configuration_unknown_fields.py @@ -6,7 +6,7 @@ from models.config import ServiceConfiguration -def test_configuration_rejects_unknown_fields(): +def test_configuration_rejects_unknown_fields() -> None: """Test that configuration models reject unknown fields.""" with pytest.raises(ValidationError, match="Extra inputs are not permitted"): ServiceConfiguration(host="localhost", port=8080, unknown_field="should_fail") diff --git a/tests/unit/utils/test_llama_stack_version.py b/tests/unit/utils/test_llama_stack_version.py index a6fa9a7d1..43790b801 100644 --- a/tests/unit/utils/test_llama_stack_version.py +++ b/tests/unit/utils/test_llama_stack_version.py @@ -1,8 +1,12 @@ """Unit tests for utility function to check Llama Stack version.""" +from typing import Any + import pytest + from semver import Version from pytest_mock import MockerFixture +from pytest_subtests import SubTests from llama_stack_client.types import VersionInfo @@ -20,7 +24,7 @@ @pytest.mark.asyncio async def test_check_llama_stack_version_minimal_supported_version( mocker: MockerFixture, -): +) -> None: """Test the check_llama_stack_version function.""" # mock the Llama Stack client @@ -36,7 +40,7 @@ async def test_check_llama_stack_version_minimal_supported_version( @pytest.mark.asyncio async def test_check_llama_stack_version_maximal_supported_version( mocker: MockerFixture, -): +) -> None: """Test the check_llama_stack_version function.""" # mock the Llama Stack client @@ -50,7 +54,9 @@ async def test_check_llama_stack_version_maximal_supported_version( @pytest.mark.asyncio -async def test_check_llama_stack_version_too_small_version(mocker: MockerFixture): +async def test_check_llama_stack_version_too_small_version( + mocker: MockerFixture, +) -> None: """Test the check_llama_stack_version function.""" # mock the Llama Stack client @@ -68,7 +74,7 @@ async def test_check_llama_stack_version_too_small_version(mocker: MockerFixture await check_llama_stack_version(mock_client) -async def _check_version_must_fail(mock_client, bigger_version): +async def _check_version_must_fail(mock_client: Any, bigger_version: Version) -> None: mock_client.inspect.version.return_value = VersionInfo(version=str(bigger_version)) expected_exception_msg = ( @@ -81,7 +87,9 @@ async def _check_version_must_fail(mock_client, bigger_version): @pytest.mark.asyncio -async def test_check_llama_stack_version_too_big_version(mocker, subtests): +async def test_check_llama_stack_version_too_big_version( + mocker: MockerFixture, subtests: SubTests +) -> None: """Test the check_llama_stack_version function.""" # mock the Llama Stack client diff --git a/tests/unit/utils/test_transcripts.py b/tests/unit/utils/test_transcripts.py index 918bfe00c..83fc2ecf9 100644 --- a/tests/unit/utils/test_transcripts.py +++ b/tests/unit/utils/test_transcripts.py @@ -4,7 +4,7 @@ from pytest_mock import MockerFixture from configuration import AppConfig -from models.requests import QueryRequest +from models.requests import Attachment, QueryRequest from utils.transcripts import ( construct_transcripts_path, @@ -13,7 +13,7 @@ from utils.types import ToolCallSummary, TurnSummary -def test_construct_transcripts_path(mocker: MockerFixture): +def test_construct_transcripts_path(mocker: MockerFixture) -> None: """Test the construct_transcripts_path function.""" config_dict = { @@ -52,7 +52,7 @@ def test_construct_transcripts_path(mocker: MockerFixture): ), "Path should be constructed correctly" -def test_store_transcript(mocker: MockerFixture): +def test_store_transcript(mocker: MockerFixture) -> None: """Test the store_transcript function.""" mocker.patch("builtins.open", mocker.mock_open()) @@ -83,9 +83,9 @@ def test_store_transcript(mocker: MockerFixture): ], ) query_is_valid = True - rag_chunks = [] + rag_chunks: list[dict] = [] truncated = False - attachments = [] + attachments: list[Attachment] = [] store_transcript( user_id,