Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/apps/api/internal/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Strawberry extensions."""
106 changes: 106 additions & 0 deletions backend/apps/api/internal/extensions/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Strawberry cache extension."""

import hashlib
import json
from functools import lru_cache

from django.conf import settings
from django.core.cache import cache
from django.core.serializers.json import DjangoJSONEncoder
from strawberry.extensions import SchemaExtension
from strawberry.permission import PermissionExtension
from strawberry.schema import Schema
from strawberry.utils.str_converters import to_camel_case


@lru_cache(maxsize=1)
def get_protected_fields(schema: Schema) -> tuple[str, ...]:
"""Get protected field names.

Args:
schema (Schema): The GraphQL schema.

Returns:
tuple[str, ...]: Tuple of protected field names in camelCase.

"""
return tuple(
to_camel_case(field.name)
for field in getattr(
getattr(schema.schema_converter.type_map.get("Query"), "definition", None),
"fields",
(),
)
if any(isinstance(ext, PermissionExtension) for ext in field.extensions)
)


def generate_key(field_name: str, field_args: dict) -> str:
"""Generate a unique cache key for a query.

Args:
field_name (str): The GraphQL field name.
field_args (dict): The field's arguments.

Returns:
str: The unique cache key.

"""
key = f"{field_name}:{json.dumps(field_args, cls=DjangoJSONEncoder, sort_keys=True)}"
return f"{settings.GRAPHQL_RESOLVER_CACHE_PREFIX}-{hashlib.sha256(key.encode()).hexdigest()}"


def invalidate_cache(field_name: str, field_args: dict) -> bool:
"""Invalidate a specific GraphQL query from the resolver cache.

Args:
field_name: The GraphQL field name (e.g., 'getProgram').
field_args: The field's arguments as a dict (e.g., {'programKey': 'my-program'}).

Returns:
True if cache was invalidated, False if key didn't exist.

"""
return cache.delete(generate_key(field_name, field_args))


def invalidate_program_cache(program_key: str) -> None:
"""Invalidate all GraphQL caches related to a program.

Args:
program_key: The program's key identifier.

"""
invalidate_cache("getProgram", {"programKey": program_key})
invalidate_cache("getProgramModules", {"programKey": program_key})


def invalidate_module_cache(module_key: str, program_key: str) -> None:
"""Invalidate all GraphQL caches related to a module.

Args:
module_key: The module's key identifier.
program_key: The program's key identifier.

"""
invalidate_cache("getModule", {"moduleKey": module_key, "programKey": program_key})
invalidate_program_cache(program_key)


class CacheExtension(SchemaExtension):
"""Cache extension."""

def resolve(self, _next, root, info, *args, **kwargs):
"""Wrap the resolver to provide caching."""
if (
info.field_name.startswith("__")
or info.parent_type.name != "Query"
or info.field_name in get_protected_fields(self.execution_context.schema)
):
return _next(root, info, *args, **kwargs)

return cache.get_or_set(
generate_key(info.field_name, kwargs),
lambda: _next(root, info, *args, **kwargs),
settings.GRAPHQL_RESOLVER_CACHE_TIME_SECONDS,
)
67 changes: 0 additions & 67 deletions backend/apps/common/extensions.py

This file was deleted.

14 changes: 14 additions & 0 deletions backend/apps/mentorship/api/internal/mutations/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.db import transaction
from django.utils import timezone

from apps.api.internal.extensions.cache import invalidate_module_cache, invalidate_program_cache
from apps.github.models import User as GithubUser
from apps.mentorship.api.internal.nodes.module import (
CreateModuleInput,
Expand Down Expand Up @@ -38,6 +39,7 @@ def resolve_mentors_from_logins(logins: list[str]) -> set[Mentor]:
msg = f"GitHub user '{login}' not found."
logger.warning(msg, exc_info=True)
raise ValueError(msg) from e

return mentors


Expand Down Expand Up @@ -119,6 +121,8 @@ def create_module(self, info: strawberry.Info, input_data: CreateModuleInput) ->
mentors_to_set.add(creator_as_mentor)
module.mentors.set(list(mentors_to_set))

transaction.on_commit(lambda: invalidate_program_cache(program.key))

return module

@strawberry.mutation(permission_classes=[IsAuthenticated])
Expand Down Expand Up @@ -330,6 +334,7 @@ def update_module(self, info: strawberry.Info, input_data: UpdateModuleInput) ->
module = Module.objects.select_related("program").get(
key=input_data.key, program__key=input_data.program_key
)
old_module_key = module.key
except Module.DoesNotExist as e:
raise ObjectDoesNotExist(msg=MODULE_NOT_FOUND_MSG) from e

Expand Down Expand Up @@ -400,4 +405,13 @@ def update_module(self, info: strawberry.Info, input_data: UpdateModuleInput) ->

module.program.save(update_fields=["experience_levels"])

program_key = module.program.key

def _invalidate():
invalidate_module_cache(old_module_key, program_key)
if module.key != old_module_key:
invalidate_module_cache(module.key, program_key)

transaction.on_commit(_invalidate)

return module
11 changes: 11 additions & 0 deletions backend/apps/mentorship/api/internal/mutations/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError
from django.db import transaction

from apps.api.internal.extensions.cache import invalidate_program_cache
from apps.mentorship.api.internal.mutations.module import resolve_mentors_from_logins
from apps.mentorship.api.internal.nodes.enum import ProgramStatusEnum
from apps.mentorship.api.internal.nodes.program import (
Expand Down Expand Up @@ -76,6 +77,7 @@ def update_program(self, info: strawberry.Info, input_data: UpdateProgramInput)

try:
program = Program.objects.get(key=input_data.key)
old_key = program.key
except Program.DoesNotExist as err:
msg = f"Program with key '{input_data.key}' not found."
logger.warning(msg, exc_info=True)
Expand Down Expand Up @@ -133,6 +135,13 @@ def update_program(self, info: strawberry.Info, input_data: UpdateProgramInput)
admins_to_set = resolve_mentors_from_logins(input_data.admin_logins)
program.admins.set(admins_to_set)

def _invalidate():
invalidate_program_cache(old_key)
if program.key != old_key:
invalidate_program_cache(program.key)

transaction.on_commit(_invalidate)

return program

@strawberry.mutation(permission_classes=[IsAuthenticated])
Expand Down Expand Up @@ -161,6 +170,8 @@ def update_program_status(
program.status = input_data.status.value
program.save()

transaction.on_commit(lambda: invalidate_program_cache(program.key))

logger.info("Updated status of program '%s' to '%s'", program.key, program.status)

return program
2 changes: 1 addition & 1 deletion backend/settings/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from strawberry.extensions import QueryDepthLimiter
from strawberry_django.optimizer import DjangoOptimizerExtension

from apps.api.internal.extensions.cache import CacheExtension
from apps.api.internal.mutations import ApiMutations
from apps.api.internal.queries import ApiKeyQueries
from apps.common.extensions import CacheExtension
from apps.github.api.internal.queries import GithubQuery
from apps.mentorship.api.internal.mutations import (
ModuleMutation,
Expand Down
Empty file.
Loading
Loading