diff --git a/backend/apps/api/internal/extensions/__init__.py b/backend/apps/api/internal/extensions/__init__.py deleted file mode 100644 index 4b350753ac..0000000000 --- a/backend/apps/api/internal/extensions/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Strawberry extensions.""" diff --git a/backend/apps/api/internal/extensions/cache.py b/backend/apps/api/internal/extensions/cache.py deleted file mode 100644 index 1d2396ad12..0000000000 --- a/backend/apps/api/internal/extensions/cache.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Strawberry cache extension.""" - -import hashlib -import json -from functools import lru_cache - -from django.conf import settings -from django.core.cache import cache -from django.core.serializers.json import DjangoJSONEncoder -from strawberry.extensions import SchemaExtension -from strawberry.permission import PermissionExtension -from strawberry.schema import Schema -from strawberry.utils.str_converters import to_camel_case - - -@lru_cache(maxsize=1) -def get_protected_fields(schema: Schema) -> tuple[str, ...]: - """Get protected field names. - - Args: - schema (Schema): The GraphQL schema. - - Returns: - tuple[str, ...]: Tuple of protected field names in camelCase. - - """ - return tuple( - to_camel_case(field.name) - for field in getattr( - getattr(schema.schema_converter.type_map.get("Query"), "definition", None), - "fields", - (), - ) - if any(isinstance(ext, PermissionExtension) for ext in field.extensions) - ) - - -def generate_key(field_name: str, field_args: dict) -> str: - """Generate a unique cache key for a query. - - Args: - field_name (str): The GraphQL field name. - field_args (dict): The field's arguments. - - Returns: - str: The unique cache key. - - """ - key = f"{field_name}:{json.dumps(field_args, cls=DjangoJSONEncoder, sort_keys=True)}" - return f"{settings.GRAPHQL_RESOLVER_CACHE_PREFIX}-{hashlib.sha256(key.encode()).hexdigest()}" - - -def invalidate_cache(field_name: str, field_args: dict) -> bool: - """Invalidate a specific GraphQL query from the resolver cache. - - Args: - field_name: The GraphQL field name (e.g., 'getProgram'). - field_args: The field's arguments as a dict (e.g., {'programKey': 'my-program'}). - - Returns: - True if cache was invalidated, False if key didn't exist. - - """ - return cache.delete(generate_key(field_name, field_args)) - - -def invalidate_program_cache(program_key: str) -> None: - """Invalidate all GraphQL caches related to a program. - - Args: - program_key: The program's key identifier. - - """ - invalidate_cache("getProgram", {"programKey": program_key}) - invalidate_cache("getProgramModules", {"programKey": program_key}) - - -def invalidate_module_cache(module_key: str, program_key: str) -> None: - """Invalidate all GraphQL caches related to a module. - - Args: - module_key: The module's key identifier. - program_key: The program's key identifier. - - """ - invalidate_cache("getModule", {"moduleKey": module_key, "programKey": program_key}) - invalidate_program_cache(program_key) - - -class CacheExtension(SchemaExtension): - """Cache extension.""" - - def resolve(self, _next, root, info, *args, **kwargs): - """Wrap the resolver to provide caching.""" - if ( - info.field_name.startswith("__") - or info.parent_type.name != "Query" - or info.field_name in get_protected_fields(self.execution_context.schema) - ): - return _next(root, info, *args, **kwargs) - - return cache.get_or_set( - generate_key(info.field_name, kwargs), - lambda: _next(root, info, *args, **kwargs), - settings.GRAPHQL_RESOLVER_CACHE_TIME_SECONDS, - ) diff --git a/backend/apps/mentorship/api/internal/mutations/module.py b/backend/apps/mentorship/api/internal/mutations/module.py index 5244ee69cf..00446803bf 100644 --- a/backend/apps/mentorship/api/internal/mutations/module.py +++ b/backend/apps/mentorship/api/internal/mutations/module.py @@ -8,7 +8,6 @@ from django.db import transaction from django.utils import timezone -from apps.api.internal.extensions.cache import invalidate_module_cache, invalidate_program_cache from apps.github.models import User as GithubUser from apps.mentorship.api.internal.nodes.module import ( CreateModuleInput, @@ -122,8 +121,6 @@ def create_module(self, info: strawberry.Info, input_data: CreateModuleInput) -> mentors_to_set.add(creator_as_mentor) module.mentors.set(list(mentors_to_set)) - transaction.on_commit(lambda: invalidate_program_cache(program.key)) - return module @strawberry.mutation(permission_classes=[IsAuthenticated]) @@ -336,7 +333,6 @@ def update_module(self, info: strawberry.Info, input_data: UpdateModuleInput) -> module = Module.objects.select_related("program").get( key=input_data.key, program__key=input_data.program_key ) - old_module_key = module.key except Module.DoesNotExist as e: raise ObjectDoesNotExist(MODULE_NOT_FOUND_MSG) from e @@ -407,13 +403,4 @@ def update_module(self, info: strawberry.Info, input_data: UpdateModuleInput) -> module.program.save(update_fields=["experience_levels"]) - program_key = module.program.key - - def _invalidate(): - invalidate_module_cache(old_module_key, program_key) - if module.key != old_module_key: - invalidate_module_cache(module.key, program_key) - - transaction.on_commit(_invalidate) - return module diff --git a/backend/apps/mentorship/api/internal/mutations/program.py b/backend/apps/mentorship/api/internal/mutations/program.py index 88bba96ad8..0905e28c53 100644 --- a/backend/apps/mentorship/api/internal/mutations/program.py +++ b/backend/apps/mentorship/api/internal/mutations/program.py @@ -6,7 +6,6 @@ from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError from django.db import transaction -from apps.api.internal.extensions.cache import invalidate_program_cache from apps.mentorship.api.internal.mutations.module import resolve_mentors_from_logins from apps.mentorship.api.internal.nodes.enum import ProgramStatusEnum from apps.mentorship.api.internal.nodes.program import ( @@ -77,7 +76,6 @@ def update_program(self, info: strawberry.Info, input_data: UpdateProgramInput) try: program = Program.objects.get(key=input_data.key) - old_key = program.key except Program.DoesNotExist as err: msg = f"Program with key '{input_data.key}' not found." logger.warning(msg, exc_info=True) @@ -135,13 +133,6 @@ def update_program(self, info: strawberry.Info, input_data: UpdateProgramInput) admins_to_set = resolve_mentors_from_logins(input_data.admin_logins) program.admins.set(admins_to_set) - def _invalidate(): - invalidate_program_cache(old_key) - if program.key != old_key: - invalidate_program_cache(program.key) - - transaction.on_commit(_invalidate) - return program @strawberry.mutation(permission_classes=[IsAuthenticated]) @@ -170,8 +161,6 @@ def update_program_status( program.status = input_data.status.value program.save() - transaction.on_commit(lambda: invalidate_program_cache(program.key)) - logger.info("Updated status of program '%s' to '%s'", program.key, program.status) return program diff --git a/backend/settings/base.py b/backend/settings/base.py index 748615dab3..007a94b3b6 100644 --- a/backend/settings/base.py +++ b/backend/settings/base.py @@ -133,8 +133,6 @@ class Base(Configuration): API_PAGE_SIZE = 100 API_CACHE_PREFIX = "api-response" API_CACHE_TIME_SECONDS = 86400 # 24 hours. - GRAPHQL_RESOLVER_CACHE_PREFIX = "graphql-resolver" - GRAPHQL_RESOLVER_CACHE_TIME_SECONDS = 86400 # 24 hours. NINJA_PAGINATION_CLASS = "apps.api.rest.v0.pagination.CustomPagination" NINJA_PAGINATION_PER_PAGE = API_PAGE_SIZE diff --git a/backend/settings/graphql.py b/backend/settings/graphql.py index 39eb24ab67..f24e12c34a 100644 --- a/backend/settings/graphql.py +++ b/backend/settings/graphql.py @@ -4,7 +4,6 @@ from strawberry.extensions import QueryDepthLimiter from strawberry_django.optimizer import DjangoOptimizerExtension -from apps.api.internal.extensions.cache import CacheExtension from apps.api.internal.mutations import ApiMutations from apps.api.internal.queries import ApiKeyQueries from apps.github.api.internal.queries import GithubQuery @@ -46,5 +45,5 @@ class Query( schema = strawberry.Schema( mutation=Mutation, query=Query, - extensions=[CacheExtension, QueryDepthLimiter(max_depth=5), DjangoOptimizerExtension()], + extensions=[QueryDepthLimiter(max_depth=5), DjangoOptimizerExtension()], ) diff --git a/backend/tests/apps/api/internal/extensions/__init__.py b/backend/tests/apps/api/internal/extensions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/backend/tests/apps/api/internal/extensions/cache_test.py b/backend/tests/apps/api/internal/extensions/cache_test.py deleted file mode 100644 index d1d0056129..0000000000 --- a/backend/tests/apps/api/internal/extensions/cache_test.py +++ /dev/null @@ -1,219 +0,0 @@ -"""Tests for CacheExtension.""" - -import datetime -from unittest.mock import MagicMock, patch -from uuid import UUID - -import pytest -from django.conf import settings -from strawberry.permission import PermissionExtension - -from apps.api.internal.extensions.cache import ( - CacheExtension, - generate_key, - get_protected_fields, - invalidate_module_cache, - invalidate_program_cache, -) - - -class TestGenerateKey: - """Test cases for the generate_key function.""" - - def test_creates_deterministic_hash(self): - """Test that generate_key creates a deterministic hash key.""" - key1 = generate_key("chapter", {"key": "germany"}) - key2 = generate_key("chapter", {"key": "germany"}) - - assert key1 == key2 - assert key1.startswith("graphql-") - assert len(key1.split("-")[-1]) == 64 # SHA256 hex digest length - - def test_differs_for_different_field_names(self): - """Test that different field names produce different keys.""" - key1 = generate_key("chapter", {"key": "germany"}) - key2 = generate_key("project", {"key": "germany"}) - - assert key1 != key2 - - def test_differs_for_different_args(self): - """Test that different arguments produce different keys.""" - key1 = generate_key("chapter", {"key": "germany"}) - key2 = generate_key("chapter", {"key": "canada"}) - - assert key1 != key2 - - def test_sorts_args_for_consistency(self): - """Test that argument order doesn't affect the key.""" - key1 = generate_key("chapter", {"a": "1", "b": "2"}) - key2 = generate_key("chapter", {"b": "2", "a": "1"}) - - assert key1 == key2 - - def test_serializes_datetime_and_uuid_args(self): - """Test that generate_key handles datetime and UUID in args without error.""" - dt = datetime.datetime(2025, 1, 15, 12, 0, 0, tzinfo=datetime.UTC) - uid = UUID("550e8400-e29b-41d4-a716-446655440000") - field_args = {"at": dt, "id": uid} - - key1 = generate_key("someField", field_args) - key2 = generate_key("someField", field_args) - - assert key1 == key2 - assert key1.startswith(f"{settings.GRAPHQL_RESOLVER_CACHE_PREFIX}-") - - -class TestGetProtectedFields: - """Test cases for the get_protected_fields function.""" - - @pytest.fixture - def mock_schema(self): - """Return a mock schema with protected and public fields.""" - mock_field_protected = MagicMock() - mock_field_protected.name = "api_keys" - mock_field_protected.extensions = [MagicMock(spec=PermissionExtension)] - - mock_field_public = MagicMock() - mock_field_public.name = "chapters" - mock_field_public.extensions = [] - - mock_query_type = MagicMock() - mock_query_type.definition.fields = [mock_field_protected, mock_field_public] - - mock_schema = MagicMock() - mock_schema.schema_converter.type_map.get.return_value = mock_query_type - return mock_schema - - def test_returns_protected_fields_in_camel_case(self, mock_schema): - """Test that protected fields are returned in camelCase.""" - get_protected_fields.cache_clear() - protected = get_protected_fields(mock_schema) - - assert "apiKeys" in protected - assert "chapters" not in protected - - def test_returns_tuple(self, mock_schema): - """Test that get_protected_fields returns a tuple.""" - get_protected_fields.cache_clear() - protected = get_protected_fields(mock_schema) - - assert isinstance(protected, tuple) - - -class TestResolve: - """Test cases for the resolve method.""" - - @pytest.fixture(autouse=True) - def mock_protected_fields(self): - """Patch get_protected_fields for all tests.""" - with patch( - "apps.api.internal.extensions.cache.get_protected_fields", return_value=("apiKeys",) - ): - yield - - @pytest.fixture - def mock_info(self): - """Return a mock GraphQL resolve info.""" - mock = MagicMock() - mock.field_name = "chapter" - mock.parent_type.name = "Query" - return mock - - @pytest.fixture - def mock_next(self): - """Return a mock next resolver.""" - return MagicMock(return_value={"name": "OWASP"}) - - @pytest.fixture - def extension(self): - """Return a CacheExtension instance.""" - extension = CacheExtension() - extension.execution_context = MagicMock() - return extension - - def test_skips_introspection_queries(self, extension, mock_info, mock_next): - """Test that introspection queries skip caching.""" - mock_info.field_name = "__schema" - - result = extension.resolve(mock_next, None, mock_info) - - mock_next.assert_called_once() - assert result == mock_next.return_value - - def test_skips_non_query_fields(self, extension, mock_info, mock_next): - """Test that non-Query parent types skip caching.""" - mock_info.parent_type.name = "ChapterNode" - - result = extension.resolve(mock_next, None, mock_info) - - mock_next.assert_called_once() - assert result == mock_next.return_value - - def test_skips_protected_fields(self, extension, mock_info, mock_next): - """Test that protected fields skip caching.""" - mock_info.field_name = "apiKeys" - - result = extension.resolve(mock_next, None, mock_info) - - mock_next.assert_called_once() - assert result == mock_next.return_value - - @patch("apps.api.internal.extensions.cache.cache") - def test_returns_cached_result_on_hit(self, mock_cache, extension, mock_info, mock_next): - """Test that cached result is returned on cache hit.""" - cached_result = {"name": "Cached OWASP"} - mock_cache.get_or_set.return_value = cached_result - - result = extension.resolve(mock_next, None, mock_info, key="germany") - - assert result == cached_result - mock_cache.get_or_set.assert_called_once() - mock_next.assert_not_called() - - @patch("apps.api.internal.extensions.cache.cache") - def test_caches_result_on_miss(self, mock_cache, extension, mock_info, mock_next): - """Test that result is cached on cache miss.""" - mock_cache.get_or_set.side_effect = lambda _key, default, _timeout: default() - - extension.resolve(mock_next, None, mock_info, key="germany") - - mock_next.assert_called_once() - mock_cache.get_or_set.assert_called_once() - - -class TestInvalidationHelpers: - """Test cases for invalidation helper functions.""" - - @patch("apps.api.internal.extensions.cache.cache") - @patch("apps.api.internal.extensions.cache.generate_key") - def test_invalidate_program_cache_uses_camel_case_keys(self, mock_generate_key, mock_cache): - """Test that invalidate_program_cache uses correct camelCase keys.""" - mock_generate_key.side_effect = lambda name, _args: f"{name}-hashed" - - invalidate_program_cache("my-program") - - # Verify calls to generate_key use camelCase 'programKey' - assert mock_generate_key.call_count == 2 - mock_generate_key.assert_any_call("getProgram", {"programKey": "my-program"}) - mock_generate_key.assert_any_call("getProgramModules", {"programKey": "my-program"}) - - assert mock_cache.delete.call_count == 2 - mock_cache.delete.assert_any_call("getProgram-hashed") - mock_cache.delete.assert_any_call("getProgramModules-hashed") - - @patch("apps.api.internal.extensions.cache.cache") - @patch("apps.api.internal.extensions.cache.generate_key") - def test_invalidate_module_cache_uses_camel_case_keys(self, mock_generate_key, mock_cache): - """Test that invalidate_module_cache uses correct camelCase keys.""" - mock_generate_key.side_effect = lambda name, _args: f"{name}-hashed" - - invalidate_module_cache("module-1", "program-1") - - assert mock_generate_key.call_count == 3 - mock_generate_key.assert_any_call( - "getModule", {"moduleKey": "module-1", "programKey": "program-1"} - ) - mock_generate_key.assert_any_call("getProgram", {"programKey": "program-1"}) - mock_generate_key.assert_any_call("getProgramModules", {"programKey": "program-1"}) - - assert mock_cache.delete.call_count == 3