Skip to content

Commit 8826228

Browse files
add tests for cache extension
1 parent fa7fe1a commit 8826228

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pytest
4+
from strawberry.types.info import Info
5+
6+
from apps.common.extensions import CacheFieldExtension
7+
8+
9+
@pytest.mark.parametrize(
10+
("typename", "key", "kwargs", "prefix", "expected_key"),
11+
[
12+
("UserNode", "name", {}, "p1", "p1:UserNode:name:{}"),
13+
(
14+
"RepositoryNode",
15+
"issues",
16+
{"limit": 10},
17+
"p2",
18+
"""p2:RepositoryNode:issues:{"limit": 10}""",
19+
),
20+
(
21+
"RepositoryNode",
22+
"issues",
23+
{"limit": 10, "state": "open"},
24+
"p3",
25+
"""p3:RepositoryNode:issues:{"limit": 10, "state": "open"}""",
26+
),
27+
(
28+
"RepositoryNode",
29+
"issues",
30+
{"state": "open", "limit": 10},
31+
"p4",
32+
"""p4:RepositoryNode:issues:{"limit": 10, "state": "open"}""",
33+
),
34+
],
35+
)
36+
def test_generate_key(typename, key, kwargs, prefix, expected_key):
37+
"""Test cases for the generate_key method."""
38+
mock_info = MagicMock(spec=Info)
39+
mock_info.path.typename = typename
40+
mock_info.path.key = key
41+
42+
extension = CacheFieldExtension(prefix=prefix)
43+
assert extension.generate_key(mock_info, kwargs) == expected_key
44+
45+
46+
class TestCacheFieldExtensionResolve:
47+
"""Test cases for the resolve method of CacheFieldExtension."""
48+
49+
@pytest.fixture
50+
def mock_info(self):
51+
"""Return a mock Strawberry Info object."""
52+
mock_info = MagicMock(spec=Info)
53+
mock_info.path.typename = "TestType"
54+
mock_info.path.key = "testField"
55+
return mock_info
56+
57+
@patch("apps.common.extensions.cache")
58+
def test_resolve_caches_result_on_miss(self, mock_cache, mock_info):
59+
"""Test that the resolver caches the result on a cache miss."""
60+
mock_cache.get.return_value = None
61+
resolver_result = "some data"
62+
next_ = MagicMock(return_value=resolver_result)
63+
extension = CacheFieldExtension(cache_timeout=60)
64+
65+
result = extension.resolve(next_, source=None, info=mock_info)
66+
67+
assert result == resolver_result
68+
mock_cache.get.assert_called_once()
69+
next_.assert_called_once()
70+
mock_cache.set.assert_called_once()
71+
mock_cache.set.assert_called_with(mock_cache.get.call_args[0][0], resolver_result, 60)
72+
73+
@patch("apps.common.extensions.cache")
74+
def test_resolve_returns_cached_result_on_hit(self, mock_cache, mock_info):
75+
"""Test that the resolver returns the cached result on a cache hit."""
76+
cached_result = "cached data"
77+
mock_cache.get.return_value = cached_result
78+
next_ = MagicMock()
79+
extension = CacheFieldExtension()
80+
81+
result = extension.resolve(next_, source=None, info=mock_info)
82+
83+
assert result == cached_result
84+
mock_cache.get.assert_called_once()
85+
next_.assert_not_called()
86+
mock_cache.set.assert_not_called()
87+
88+
@pytest.mark.parametrize("falsy_result", [None, [], {}, 0, False])
89+
@patch("apps.common.extensions.cache")
90+
def test_resolve_does_not_cache_falsy_result(self, mock_cache, falsy_result, mock_info):
91+
"""Test that the resolver does not cache None or other falsy results."""
92+
mock_cache.get.return_value = None
93+
next_ = MagicMock(return_value=falsy_result)
94+
extension = CacheFieldExtension()
95+
96+
result = extension.resolve(next_, source=None, info=mock_info)
97+
98+
assert result == falsy_result
99+
mock_cache.get.assert_called_once()
100+
next_.assert_called_once()
101+
mock_cache.set.assert_not_called()

0 commit comments

Comments
 (0)