Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions backend/apps/common/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Strawberry extensions."""

import json
from typing import Any

from django.conf import settings
from django.core.cache import cache
from django.core.serializers.json import DjangoJSONEncoder
from strawberry.extensions.field_extension import FieldExtension
from strawberry.types.info import Info


class CacheFieldExtension(FieldExtension):
"""Cache FieldExtension class."""

def __init__(self, cache_timeout: int | None = None, prefix: str | None = None):
"""Initialize the cache extension.

Args:
cache_timeout (int | None): The TTL for cache entries in seconds.
prefix (str | None): A prefix for the cache key.

"""
self.cache_timeout = cache_timeout or settings.GRAPHQL_RESOLVER_CACHE_TIME_SECONDS
self.prefix = prefix or settings.GRAPHQL_RESOLVER_CACHE_PREFIX

def _convert_path_to_str(self, path: Any) -> str:
"""Convert the Strawberry path linked list to a string."""
parts = []
current = path
while current:
parts.append(str(current.key))
current = getattr(current, "prev", None)
return ".".join(reversed(parts))

def generate_key(self, source: Any | None, info: Info, kwargs: dict) -> str:
"""Generate a unique cache key for a field.

Args:
source (Any | None): The source/parent object.
info (Info): The Strawberry execution info.
kwargs (dict): The resolver's arguments.

Returns:
str: The unique cache key.

"""
key_kwargs = kwargs.copy()
if source and (source_id := getattr(source, "id", None)) is not None:
key_kwargs["__source_id__"] = str(source_id)

args_str = json.dumps(key_kwargs, sort_keys=True, cls=DjangoJSONEncoder)

return f"{self.prefix}:{self._convert_path_to_str(info.path)}:{args_str}"

def resolve(self, next_: Any, source: Any, info: Info, **kwargs: Any) -> Any:
"""Wrap the resolver to provide caching."""
cache_key = self.generate_key(source, info, kwargs)

return cache.get_or_set(
cache_key,
lambda: next_(source, info, **kwargs),
timeout=self.cache_timeout,
)
2 changes: 2 additions & 0 deletions backend/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class Base(Configuration):
API_PAGE_SIZE = 100
API_CACHE_PREFIX = "api-response"
API_CACHE_TIME_SECONDS = 86400 # 24 hours.
GRAPHQL_RESOLVER_CACHE_PREFIX = "graphql-resolver"
GRAPHQL_RESOLVER_CACHE_TIME_SECONDS = 86400 # 24 hours.
NINJA_PAGINATION_CLASS = "apps.api.rest.v0.pagination.CustomPagination"
NINJA_PAGINATION_PER_PAGE = API_PAGE_SIZE

Expand Down
155 changes: 155 additions & 0 deletions backend/tests/apps/common/extensions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from unittest.mock import MagicMock, patch

import pytest
from strawberry.types.info import Info

from apps.common.extensions import CacheFieldExtension


class MockPath:
def __init__(self, key, typename, prev=None):
self.key = key
self.prev = prev
self.typename = typename


class MockSource:
def __init__(self, id=None): # noqa: A002
self.id = id


@pytest.mark.parametrize(
("source", "path", "kwargs", "prefix", "expected_key"),
[
(
None,
MockPath(key="repository", typename="Query"),
{"organization_key": "OWASP", "repository_key": "nest"},
"p1",
'p1:repository:{"organization_key": "OWASP", "repository_key": "nest"}',
),
(
MockSource(id=123),
MockPath(
key="organization",
typename="RepositoryNode",
prev=MockPath(key="repository", typename="Query"),
),
{},
"p2",
'p2:repository.organization:{"__source_id__": "123"}',
),
(
MockSource(id=0),
MockPath(
key="organization",
typename="RepositoryNode",
prev=MockPath(key="repository", typename="Query"),
),
{},
"p3",
'p3:repository.organization:{"__source_id__": "0"}',
),
(
MockSource(),
MockPath(
key="organization",
typename="RepositoryNode",
prev=MockPath(key="repository", typename="Query"),
),
{},
"p4",
"p4:repository.organization:{}",
),
(
None,
MockPath(
key="badgeCount",
typename="UserNode",
prev=MockPath(
key="author",
typename="IssueNode",
prev=MockPath(
key=0,
typename=None,
prev=MockPath(
key="issues",
typename="RepositoryNode",
prev=MockPath(key="repository", typename="Query"),
),
),
),
),
{},
"graphql-resolver",
"graphql-resolver:repository.issues.0.author.badgeCount:{}",
),
],
)
def test_generate_key(source, path, kwargs, prefix, expected_key):
"""Test cases for the generate_key method."""
mock_info = MagicMock(spec=Info)
mock_info.path = path

extension = CacheFieldExtension(prefix=prefix)
assert extension.generate_key(source, mock_info, kwargs) == expected_key


class TestCacheFieldExtensionResolve:
"""Test cases for the resolve method of CacheFieldExtension."""

@pytest.fixture
def mock_info(self):
"""Return a mock Strawberry Info object."""
mock_info = MagicMock(spec=Info)
mock_info.path = MockPath(key="testField", typename="TestType", prev=None)
return mock_info

@patch("apps.common.extensions.cache")
def test_resolve_caches_result_on_miss(self, mock_cache, mock_info):
"""Test that get_or_set calls the resolver on a cache miss."""
resolver_result = "some data"
next_ = MagicMock(return_value=resolver_result)
extension = CacheFieldExtension(cache_timeout=60)

def cache_miss_side_effect(key, default_callable, timeout=None): # noqa: ARG001
return default_callable()

mock_cache.get_or_set.side_effect = cache_miss_side_effect

result = extension.resolve(next_, source=None, info=mock_info)

assert result == resolver_result
mock_cache.get_or_set.assert_called_once()
next_.assert_called_once()

@patch("apps.common.extensions.cache")
def test_resolve_returns_cached_result_on_hit(self, mock_cache, mock_info):
"""Test that the resolver returns the cached result on a cache hit."""
cached_result = "cached data"
mock_cache.get_or_set.return_value = cached_result
next_ = MagicMock()
extension = CacheFieldExtension()

result = extension.resolve(next_, source=None, info=mock_info)

assert result == cached_result
mock_cache.get_or_set.assert_called_once()
next_.assert_not_called()

@pytest.mark.parametrize("falsy_result", [None, [], {}, 0, False])
@patch("apps.common.extensions.cache")
def test_resolve_caches_falsy_result(self, mock_cache, falsy_result, mock_info):
"""Test that the resolver caches None and other falsy results."""
next_ = MagicMock(return_value=falsy_result)
extension = CacheFieldExtension()

def cache_miss_side_effect(key, default_callable, timeout=None): # noqa: ARG001
return default_callable()

mock_cache.get_or_set.side_effect = cache_miss_side_effect
result = extension.resolve(next_, source=None, info=mock_info)

assert result == falsy_result
mock_cache.get_or_set.assert_called_once()
next_.assert_called_once()