Skip to content

Commit c23b60c

Browse files
fix: nested node caching by implementing path based key generation
1 parent 33af60c commit c23b60c

File tree

2 files changed

+92
-25
lines changed

2 files changed

+92
-25
lines changed

backend/apps/common/extensions.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Strawberry extensions."""
22

33
import json
4+
from typing import Any
45

56
from django.conf import settings
67
from django.core.cache import cache
78
from django.core.serializers.json import DjangoJSONEncoder
89
from strawberry.extensions.field_extension import FieldExtension
10+
from strawberry.types.info import Info
911

1012

1113
class CacheFieldExtension(FieldExtension):
@@ -22,24 +24,38 @@ def __init__(self, cache_timeout: int | None = None, prefix: str | None = None):
2224
self.cache_timeout = cache_timeout or settings.GRAPHQL_RESOLVER_CACHE_TIME_SECONDS
2325
self.prefix = prefix or settings.GRAPHQL_RESOLVER_CACHE_PREFIX
2426

25-
def generate_key(self, info, kwargs: dict) -> str:
27+
def _convert_path_to_str(self, path: Any) -> str:
28+
"""Convert the Strawberry path linked list to a string."""
29+
parts = []
30+
current = path
31+
while current:
32+
parts.append(str(current.key))
33+
current = getattr(current, "prev", None)
34+
return ".".join(reversed(parts))
35+
36+
def generate_key(self, source: Any | None, info: Info, kwargs: dict) -> str:
2637
"""Generate a unique cache key for a field.
2738
2839
Args:
40+
source (Any | None): The source/parent object.
2941
info (Info): The Strawberry execution info.
3042
kwargs (dict): The resolver's arguments.
3143
3244
Returns:
3345
str: The unique cache key.
3446
3547
"""
36-
args_str = json.dumps(kwargs, sort_keys=True, cls=DjangoJSONEncoder)
48+
key_kwargs = kwargs.copy()
49+
if source and (source_id := getattr(source, "id", None)) is not None:
50+
key_kwargs["__source_id__"] = str(source_id)
51+
52+
args_str = json.dumps(key_kwargs, sort_keys=True, cls=DjangoJSONEncoder)
3753

38-
return f"{self.prefix}:{info.path.typename}:{info.path.key}:{args_str}"
54+
return f"{self.prefix}:{self._convert_path_to_str(info.path)}:{args_str}"
3955

40-
def resolve(self, next_, source, info, **kwargs):
56+
def resolve(self, next_: Any, source: Any, info: Info, **kwargs: Any) -> Any:
4157
"""Wrap the resolver to provide caching."""
42-
cache_key = self.generate_key(info, kwargs)
58+
cache_key = self.generate_key(source, info, kwargs)
4359

4460
return cache.get_or_set(
4561
cache_key,

backend/tests/apps/common/extensions_test.py

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,93 @@
66
from apps.common.extensions import CacheFieldExtension
77

88

9+
class MockPath:
10+
def __init__(self, key, typename, prev=None):
11+
self.key = key
12+
self.prev = prev
13+
self.typename = typename
14+
15+
16+
class MockSource:
17+
def __init__(self, id=None): # noqa: A002
18+
self.id = id
19+
20+
921
@pytest.mark.parametrize(
10-
("typename", "key", "kwargs", "prefix", "expected_key"),
22+
("source", "path", "kwargs", "prefix", "expected_key"),
1123
[
12-
("UserNode", "name", {}, "p1", "p1:UserNode:name:{}"),
1324
(
14-
"RepositoryNode",
15-
"issues",
16-
{"limit": 10},
25+
None,
26+
MockPath(key="repository", typename="Query"),
27+
{"organization_key": "OWASP", "repository_key": "nest"},
28+
"p1",
29+
'p1:repository:{"organization_key": "OWASP", "repository_key": "nest"}',
30+
),
31+
(
32+
MockSource(id=123),
33+
MockPath(
34+
key="organization",
35+
typename="RepositoryNode",
36+
prev=MockPath(key="repository", typename="Query"),
37+
),
38+
{},
1739
"p2",
18-
"""p2:RepositoryNode:issues:{"limit": 10}""",
40+
'p2:repository.organization:{"__source_id__": "123"}',
1941
),
2042
(
21-
"RepositoryNode",
22-
"issues",
23-
{"limit": 10, "state": "open"},
43+
MockSource(id=0),
44+
MockPath(
45+
key="organization",
46+
typename="RepositoryNode",
47+
prev=MockPath(key="repository", typename="Query"),
48+
),
49+
{},
2450
"p3",
25-
"""p3:RepositoryNode:issues:{"limit": 10, "state": "open"}""",
51+
'p3:repository.organization:{"__source_id__": "0"}',
2652
),
2753
(
28-
"RepositoryNode",
29-
"issues",
30-
{"state": "open", "limit": 10},
54+
MockSource(),
55+
MockPath(
56+
key="organization",
57+
typename="RepositoryNode",
58+
prev=MockPath(key="repository", typename="Query"),
59+
),
60+
{},
3161
"p4",
32-
"""p4:RepositoryNode:issues:{"limit": 10, "state": "open"}""",
62+
"p4:repository.organization:{}",
63+
),
64+
(
65+
None,
66+
MockPath(
67+
key="badgeCount",
68+
typename="UserNode",
69+
prev=MockPath(
70+
key="author",
71+
typename="IssueNode",
72+
prev=MockPath(
73+
key=0,
74+
typename=None,
75+
prev=MockPath(
76+
key="issues",
77+
typename="RepositoryNode",
78+
prev=MockPath(key="repository", typename="Query"),
79+
),
80+
),
81+
),
82+
),
83+
{},
84+
"graphql-resolver",
85+
"graphql-resolver:repository.issues.0.author.badgeCount:{}",
3386
),
3487
],
3588
)
36-
def test_generate_key(typename, key, kwargs, prefix, expected_key):
89+
def test_generate_key(source, path, kwargs, prefix, expected_key):
3790
"""Test cases for the generate_key method."""
3891
mock_info = MagicMock(spec=Info)
39-
mock_info.path.typename = typename
40-
mock_info.path.key = key
92+
mock_info.path = path
4193

4294
extension = CacheFieldExtension(prefix=prefix)
43-
assert extension.generate_key(mock_info, kwargs) == expected_key
95+
assert extension.generate_key(source, mock_info, kwargs) == expected_key
4496

4597

4698
class TestCacheFieldExtensionResolve:
@@ -50,8 +102,7 @@ class TestCacheFieldExtensionResolve:
50102
def mock_info(self):
51103
"""Return a mock Strawberry Info object."""
52104
mock_info = MagicMock(spec=Info)
53-
mock_info.path.typename = "TestType"
54-
mock_info.path.key = "testField"
105+
mock_info.path = MockPath(key="testField", typename="TestType", prev=None)
55106
return mock_info
56107

57108
@patch("apps.common.extensions.cache")

0 commit comments

Comments
 (0)