Skip to content

Commit 02fc91b

Browse files
authored
Added BasePlugin.add_hook helper (#173)
Will leave undocumented because I'm not 100% this is the best approach.
1 parent 573af38 commit 02fc91b

File tree

5 files changed

+88
-54
lines changed

5 files changed

+88
-54
lines changed

Makefile

+3
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@ _release:
3535
scripts/make_release
3636

3737
release: test _release
38+
39+
changelog:
40+
gitchangelog

aiocache/plugins.py

+29-24
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,20 @@
77

88

99
class BasePlugin:
10-
pass
1110

11+
@classmethod
12+
def add_hook(cls, fn, hooks):
13+
for hook in hooks:
14+
setattr(cls, hook, fn)
1215

13-
async def do_nothing(self, client, *args, **kwargs):
14-
pass
16+
async def do_nothing(self, *args, **kwargs):
17+
pass
1518

16-
for method in API.CMDS:
17-
setattr(BasePlugin, "pre_{}".format(method.__name__), classmethod(do_nothing))
18-
setattr(BasePlugin, "post_{}".format(method.__name__), classmethod(do_nothing))
19+
20+
BasePlugin.add_hook(
21+
BasePlugin.do_nothing, ["pre_{}".format(method.__name__) for method in API.CMDS])
22+
BasePlugin.add_hook(
23+
BasePlugin.do_nothing, ["post_{}".format(method.__name__) for method in API.CMDS])
1924

2025

2126
class TimingPlugin(BasePlugin):
@@ -25,31 +30,31 @@ class TimingPlugin(BasePlugin):
2530
access the average time of the operation get, you can do ``cache.profiling['get_avg']``
2631
"""
2732

33+
@classmethod
34+
def save_time(cls, method):
2835

29-
def save_time(method):
30-
31-
async def do_save_time(self, client, *args, took=0, **kwargs):
32-
if not hasattr(client, "profiling"):
33-
client.profiling = {}
36+
async def do_save_time(self, client, *args, took=0, **kwargs):
37+
if not hasattr(client, "profiling"):
38+
client.profiling = {}
3439

35-
previous_total = client.profiling.get("{}_total".format(method), 0)
36-
previous_avg = client.profiling.get("{}_avg".format(method), 0)
37-
previous_max = client.profiling.get("{}_max".format(method), 0)
38-
previous_min = client.profiling.get("{}_min".format(method))
40+
previous_total = client.profiling.get("{}_total".format(method), 0)
41+
previous_avg = client.profiling.get("{}_avg".format(method), 0)
42+
previous_max = client.profiling.get("{}_max".format(method), 0)
43+
previous_min = client.profiling.get("{}_min".format(method))
3944

40-
client.profiling["{}_total".format(method)] = previous_total + 1
41-
client.profiling["{}_avg".format(method)] = \
42-
previous_avg + (took - previous_avg) / (previous_total + 1)
43-
client.profiling["{}_max".format(method)] = max(took, previous_max)
44-
client.profiling["{}_min".format(method)] = \
45-
min(took, previous_min) if previous_min else took
45+
client.profiling["{}_total".format(method)] = previous_total + 1
46+
client.profiling["{}_avg".format(method)] = \
47+
previous_avg + (took - previous_avg) / (previous_total + 1)
48+
client.profiling["{}_max".format(method)] = max(took, previous_max)
49+
client.profiling["{}_min".format(method)] = \
50+
min(took, previous_min) if previous_min else took
4651

47-
return do_save_time
52+
return do_save_time
4853

4954

5055
for method in API.CMDS:
51-
setattr(
52-
TimingPlugin, "post_{}".format(method.__name__), classmethod(save_time(method.__name__)))
56+
TimingPlugin.add_hook(
57+
TimingPlugin.save_time(method.__name__), ["post_{}".format(method.__name__)])
5358

5459

5560
class HitMissRatioPlugin(BasePlugin):

docs/plugins.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Plugins can be used to change the behavior of the cache. By default all caches a
1212

1313
You can define your custom plugin by inheriting from `BasePlugin`_ and overriding the needed methods (the overrides NEED to be async). All commands have a ``pre`` and a ``post`` hooks.
1414

15-
An complete example of using the plugins:
15+
A complete example of using the plugins:
1616

1717
.. literalinclude:: ../examples/plugins.py
1818
:language: python

examples/plugins.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
11
import asyncio
22
import random
3+
import logging
34

4-
from aiocache import RedisCache
5-
from aiocache.plugins import HitMissRatioPlugin, TimingPlugin
5+
from aiocache import SimpleMemoryCache
6+
from aiocache.plugins import HitMissRatioPlugin, TimingPlugin, BasePlugin
67

78

8-
cache = RedisCache(
9-
endpoint="127.0.0.1",
10-
plugins=[HitMissRatioPlugin(), TimingPlugin()],
11-
port=6379,
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class MyCustomPlugin(BasePlugin):
13+
14+
async def pre_set(self, *args, **kwargs):
15+
logger.info("I'm the pre_set hook being called with %s %s" % (args, kwargs))
16+
17+
async def post_set(self, *args, **kwargs):
18+
logger.info("I'm the post_set hook being called with %s %s" % (args, kwargs))
19+
20+
21+
cache = SimpleMemoryCache(
22+
plugins=[HitMissRatioPlugin(), TimingPlugin(), MyCustomPlugin()],
1223
namespace="main")
1324

1425

15-
async def redis():
26+
async def run():
1627
await cache.set("a", 1)
1728
await cache.set("b", 2)
1829
await cache.set("c", 3)
@@ -35,14 +46,14 @@ async def redis():
3546
print(cache.profiling)
3647

3748

38-
def test_redis():
49+
def test_run():
3950
loop = asyncio.get_event_loop()
40-
loop.run_until_complete(redis())
51+
loop.run_until_complete(run())
4152
loop.run_until_complete(cache.delete("a"))
4253
loop.run_until_complete(cache.delete("b"))
4354
loop.run_until_complete(cache.delete("c"))
4455
loop.run_until_complete(cache.delete("d"))
4556

4657

4758
if __name__ == "__main__":
48-
test_redis()
59+
test_run()

tests/ut/test_plugins.py

+34-19
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,52 @@
11
import pytest
2-
import inspect
32

43
from unittest.mock import MagicMock
54

6-
from aiocache.plugins import BasePlugin, HitMissRatioPlugin, save_time, do_nothing
5+
from aiocache.plugins import BasePlugin, TimingPlugin, HitMissRatioPlugin
76
from aiocache.cache import API, BaseCache
87

98

109
class TestBasePlugin:
1110

12-
def test_interface_methods(self):
11+
@pytest.mark.asyncio
12+
async def test_interface_methods(self):
1313
for method in API.CMDS:
14-
assert hasattr(BasePlugin, "pre_{}".format(method.__name__)) and \
15-
inspect.iscoroutinefunction(getattr(BasePlugin, "pre_{}".format(method.__name__)))
16-
assert hasattr(BasePlugin, "post_{}".format(method.__name__)) and \
17-
inspect.iscoroutinefunction(getattr(BasePlugin, "pre_{}".format(method.__name__)))
14+
assert await getattr(BasePlugin, "pre_{}".format(method.__name__))(MagicMock()) is None
15+
assert await getattr(BasePlugin, "post_{}".format(method.__name__))(MagicMock()) is None
1816

17+
@pytest.mark.asyncio
18+
async def test_do_nothing(self):
19+
assert await BasePlugin().do_nothing() is None
1920

20-
@pytest.mark.asyncio
21-
async def test_do_nothing():
22-
assert await do_nothing(MagicMock(), MagicMock()) is None
2321

22+
class TestTimingPlugin:
2423

25-
@pytest.mark.asyncio
26-
async def test_save_time(mock_cache):
27-
do_save_time = save_time('get')
28-
await do_save_time('self', mock_cache, took=1)
29-
await do_save_time('self', mock_cache, took=2)
24+
@pytest.mark.asyncio
25+
async def test_save_time(mock_cache):
26+
do_save_time = TimingPlugin().save_time('get')
27+
await do_save_time('self', mock_cache, took=1)
28+
await do_save_time('self', mock_cache, took=2)
3029

31-
assert mock_cache.profiling["get_total"] == 2
32-
assert mock_cache.profiling["get_max"] == 2
33-
assert mock_cache.profiling["get_min"] == 1
34-
assert mock_cache.profiling["get_avg"] == 1.5
30+
assert mock_cache.profiling["get_total"] == 2
31+
assert mock_cache.profiling["get_max"] == 2
32+
assert mock_cache.profiling["get_min"] == 1
33+
assert mock_cache.profiling["get_avg"] == 1.5
34+
35+
@pytest.mark.asyncio
36+
async def test_save_time_post_set(mock_cache):
37+
await TimingPlugin().post_set(mock_cache, took=1)
38+
await TimingPlugin().post_set(mock_cache, took=2)
39+
40+
assert mock_cache.profiling["set_total"] == 2
41+
assert mock_cache.profiling["set_max"] == 2
42+
assert mock_cache.profiling["set_min"] == 1
43+
assert mock_cache.profiling["set_avg"] == 1.5
44+
45+
@pytest.mark.asyncio
46+
async def test_interface_methods(self):
47+
for method in API.CMDS:
48+
assert hasattr(TimingPlugin, "pre_{}".format(method.__name__))
49+
assert hasattr(TimingPlugin, "post_{}".format(method.__name__))
3550

3651

3752
class TestHitMissRatioPlugin:

0 commit comments

Comments
 (0)