Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added BasePlugin.add_hook helper #173

Merged
merged 1 commit into from
Apr 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ _release:
scripts/make_release

release: test _release

changelog:
gitchangelog
53 changes: 29 additions & 24 deletions aiocache/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@


class BasePlugin:
pass

@classmethod
def add_hook(cls, fn, hooks):
for hook in hooks:
setattr(cls, hook, fn)

async def do_nothing(self, client, *args, **kwargs):
pass
async def do_nothing(self, *args, **kwargs):
pass

for method in API.CMDS:
setattr(BasePlugin, "pre_{}".format(method.__name__), classmethod(do_nothing))
setattr(BasePlugin, "post_{}".format(method.__name__), classmethod(do_nothing))

BasePlugin.add_hook(
BasePlugin.do_nothing, ["pre_{}".format(method.__name__) for method in API.CMDS])
BasePlugin.add_hook(
BasePlugin.do_nothing, ["post_{}".format(method.__name__) for method in API.CMDS])


class TimingPlugin(BasePlugin):
Expand All @@ -25,31 +30,31 @@ class TimingPlugin(BasePlugin):
access the average time of the operation get, you can do ``cache.profiling['get_avg']``
"""

@classmethod
def save_time(cls, method):

def save_time(method):

async def do_save_time(self, client, *args, took=0, **kwargs):
if not hasattr(client, "profiling"):
client.profiling = {}
async def do_save_time(self, client, *args, took=0, **kwargs):
if not hasattr(client, "profiling"):
client.profiling = {}

previous_total = client.profiling.get("{}_total".format(method), 0)
previous_avg = client.profiling.get("{}_avg".format(method), 0)
previous_max = client.profiling.get("{}_max".format(method), 0)
previous_min = client.profiling.get("{}_min".format(method))
previous_total = client.profiling.get("{}_total".format(method), 0)
previous_avg = client.profiling.get("{}_avg".format(method), 0)
previous_max = client.profiling.get("{}_max".format(method), 0)
previous_min = client.profiling.get("{}_min".format(method))

client.profiling["{}_total".format(method)] = previous_total + 1
client.profiling["{}_avg".format(method)] = \
previous_avg + (took - previous_avg) / (previous_total + 1)
client.profiling["{}_max".format(method)] = max(took, previous_max)
client.profiling["{}_min".format(method)] = \
min(took, previous_min) if previous_min else took
client.profiling["{}_total".format(method)] = previous_total + 1
client.profiling["{}_avg".format(method)] = \
previous_avg + (took - previous_avg) / (previous_total + 1)
client.profiling["{}_max".format(method)] = max(took, previous_max)
client.profiling["{}_min".format(method)] = \
min(took, previous_min) if previous_min else took

return do_save_time
return do_save_time


for method in API.CMDS:
setattr(
TimingPlugin, "post_{}".format(method.__name__), classmethod(save_time(method.__name__)))
TimingPlugin.add_hook(
TimingPlugin.save_time(method.__name__), ["post_{}".format(method.__name__)])


class HitMissRatioPlugin(BasePlugin):
Expand Down
2 changes: 1 addition & 1 deletion docs/plugins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Plugins can be used to change the behavior of the cache. By default all caches a

You can define your custom plugin by inheriting from `BasePlugin`_ and overriding the needed methods (the overrides NEED to be async). All commands have a ``pre`` and a ``post`` hooks.

An complete example of using the plugins:
A complete example of using the plugins:

.. literalinclude:: ../examples/plugins.py
:language: python
Expand Down
31 changes: 21 additions & 10 deletions examples/plugins.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
import asyncio
import random
import logging

from aiocache import RedisCache
from aiocache.plugins import HitMissRatioPlugin, TimingPlugin
from aiocache import SimpleMemoryCache
from aiocache.plugins import HitMissRatioPlugin, TimingPlugin, BasePlugin


cache = RedisCache(
endpoint="127.0.0.1",
plugins=[HitMissRatioPlugin(), TimingPlugin()],
port=6379,
logger = logging.getLogger(__name__)


class MyCustomPlugin(BasePlugin):

async def pre_set(self, *args, **kwargs):
logger.info("I'm the pre_set hook being called with %s %s" % (args, kwargs))

async def post_set(self, *args, **kwargs):
logger.info("I'm the post_set hook being called with %s %s" % (args, kwargs))


cache = SimpleMemoryCache(
plugins=[HitMissRatioPlugin(), TimingPlugin(), MyCustomPlugin()],
namespace="main")


async def redis():
async def run():
await cache.set("a", 1)
await cache.set("b", 2)
await cache.set("c", 3)
Expand All @@ -35,14 +46,14 @@ async def redis():
print(cache.profiling)


def test_redis():
def test_run():
loop = asyncio.get_event_loop()
loop.run_until_complete(redis())
loop.run_until_complete(run())
loop.run_until_complete(cache.delete("a"))
loop.run_until_complete(cache.delete("b"))
loop.run_until_complete(cache.delete("c"))
loop.run_until_complete(cache.delete("d"))


if __name__ == "__main__":
test_redis()
test_run()
53 changes: 34 additions & 19 deletions tests/ut/test_plugins.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,52 @@
import pytest
import inspect

from unittest.mock import MagicMock

from aiocache.plugins import BasePlugin, HitMissRatioPlugin, save_time, do_nothing
from aiocache.plugins import BasePlugin, TimingPlugin, HitMissRatioPlugin
from aiocache.cache import API, BaseCache


class TestBasePlugin:

def test_interface_methods(self):
@pytest.mark.asyncio
async def test_interface_methods(self):
for method in API.CMDS:
assert hasattr(BasePlugin, "pre_{}".format(method.__name__)) and \
inspect.iscoroutinefunction(getattr(BasePlugin, "pre_{}".format(method.__name__)))
assert hasattr(BasePlugin, "post_{}".format(method.__name__)) and \
inspect.iscoroutinefunction(getattr(BasePlugin, "pre_{}".format(method.__name__)))
assert await getattr(BasePlugin, "pre_{}".format(method.__name__))(MagicMock()) is None
assert await getattr(BasePlugin, "post_{}".format(method.__name__))(MagicMock()) is None

@pytest.mark.asyncio
async def test_do_nothing(self):
assert await BasePlugin().do_nothing() is None

@pytest.mark.asyncio
async def test_do_nothing():
assert await do_nothing(MagicMock(), MagicMock()) is None

class TestTimingPlugin:

@pytest.mark.asyncio
async def test_save_time(mock_cache):
do_save_time = save_time('get')
await do_save_time('self', mock_cache, took=1)
await do_save_time('self', mock_cache, took=2)
@pytest.mark.asyncio
async def test_save_time(mock_cache):
do_save_time = TimingPlugin().save_time('get')
await do_save_time('self', mock_cache, took=1)
await do_save_time('self', mock_cache, took=2)

assert mock_cache.profiling["get_total"] == 2
assert mock_cache.profiling["get_max"] == 2
assert mock_cache.profiling["get_min"] == 1
assert mock_cache.profiling["get_avg"] == 1.5
assert mock_cache.profiling["get_total"] == 2
assert mock_cache.profiling["get_max"] == 2
assert mock_cache.profiling["get_min"] == 1
assert mock_cache.profiling["get_avg"] == 1.5

@pytest.mark.asyncio
async def test_save_time_post_set(mock_cache):
await TimingPlugin().post_set(mock_cache, took=1)
await TimingPlugin().post_set(mock_cache, took=2)

assert mock_cache.profiling["set_total"] == 2
assert mock_cache.profiling["set_max"] == 2
assert mock_cache.profiling["set_min"] == 1
assert mock_cache.profiling["set_avg"] == 1.5

@pytest.mark.asyncio
async def test_interface_methods(self):
for method in API.CMDS:
assert hasattr(TimingPlugin, "pre_{}".format(method.__name__))
assert hasattr(TimingPlugin, "post_{}".format(method.__name__))


class TestHitMissRatioPlugin:
Expand Down