diff --git a/.gitignore b/.gitignore index b89c115..9a87d64 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ # Created by https://www.toptal.com/developers/gitignore/api/python,linux,visualstudiocode # Edit at https://www.toptal.com/developers/gitignore?templates=python,linux,visualstudiocode +### vscode ### +.vscode + ### Linux ### *~ diff --git a/pytest.ini b/pytest.ini index 46cdef2..ecec326 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,5 @@ [pytest] +asyncio_mode = auto async_mongodb_fixture_dir = tests/unit/fixtures ; This tests specifying exactly which fixtures to load diff --git a/pytest_async_mongodb/plugin.py b/pytest_async_mongodb/plugin.py index 585b536..b7c82da 100644 --- a/pytest_async_mongodb/plugin.py +++ b/pytest_async_mongodb/plugin.py @@ -5,8 +5,8 @@ import json import codecs import mongomock -import pytest import yaml +import pytest_asyncio _cache = {} @@ -37,11 +37,13 @@ async def wrapped(*args, **kwargs): return wrapped -def wrapp_methods(cls): - for method_name in cls.ASYNC_METHODS: - method = getattr(cls, method_name) - setattr(cls, method_name, async_decorator(method)) - return cls +def async_wrap(obj): + # wrap all the public interfaces except the one has been re-defined in obj + for item in dir(obj._base_sync_obj): + if not item.startswith("_"): + member = getattr(obj._base_sync_obj, item) + if callable(member) and item not in dir(obj): + setattr(obj, item, async_decorator(member)) class AsyncCursor(mongomock.collection.Cursor): @@ -54,48 +56,64 @@ async def __anext__(self): except StopIteration: raise StopAsyncIteration() + async def to_list(self, length=None): + the_list = [] + try: + while length is None or len(the_list) < length: + the_list.append(next(self)) + finally: + return the_list + + +class AsyncCommandCursor(mongomock.command_cursor.CommandCursor): + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self) + except StopIteration: + raise StopAsyncIteration() + + async def to_list(self, length=None): + the_list = [] + try: + while length is None or len(the_list) < length: + the_list.append(next(self)) + finally: + return the_list -@wrapp_methods -class AsyncCollection(mongomock.Collection): - - ASYNC_METHODS = [ - "find_one", - "find_one_and_delete", - "find_one_and_replace", - "find_one_and_update", - "find_and_modify", - "save", - "delete_one", - "delete_many", - "count", - "insert_one", - "insert_many", - "update_one", - "update_many", - "replace_one", - "count_documents", - "estimated_document_count", - "drop", - "create_index", - "ensure_index", - "map_reduce", - ] + +class AsyncCollection: + def __init__(self, mongomock_collection): + self._base_sync_obj = mongomock_collection + async_wrap(self) def find(self, *args, **kwargs) -> AsyncCursor: - cursor = super().find(*args, **kwargs) + cursor = self._base_sync_obj.find(*args, **kwargs) cursor.__class__ = AsyncCursor return cursor + def aggregate(self, *args, **kwargs) -> AsyncCommandCursor: + cursor = self._base_sync_obj.aggregate(*args, **kwargs) + cursor.__class__ = AsyncCommandCursor + return cursor + + +class AsyncDatabase: + def __init__(self, mongomock_db): + self._base_sync_obj = mongomock_db + async_wrap(self) -@wrapp_methods -class AsyncDatabase(mongomock.Database): + def __getattr__(self, attr): + return self[attr] - ASYNC_METHODS = ["list_collection_names"] + def __getitem__(self, db_name): + return self.get_collection(db_name) def get_collection(self, *args, **kwargs) -> AsyncCollection: - collection = super().get_collection(*args, **kwargs) - collection.__class__ = AsyncCollection - return collection + collection = self._base_sync_obj.get_collection(*args, **kwargs) + return AsyncCollection(collection) class Session: @@ -106,29 +124,38 @@ async def __aexit__(self, exc_type, exc, tb): await asyncio.sleep(0) -class AsyncMockMongoClient(mongomock.MongoClient): +class AsyncMockMongoClient: + def __init__(self, mongomock_client): + self._base_sync_obj = mongomock_client + async_wrap(self) + + def __getattr__(self, attr): + return self[attr] + + def __getitem__(self, db_name): + return self.get_database(db_name) + def get_database(self, *args, **kwargs) -> AsyncDatabase: - db = super().get_database(*args, **kwargs) - db.__class__ = AsyncDatabase - return db + db = self._base_sync_obj.get_database(*args, **kwargs) + return AsyncDatabase(db) async def start_session(self, **kwargs): await asyncio.sleep(0) return Session() -@pytest.fixture(scope="function") -async def async_mongodb(pytestconfig): - client = AsyncMockMongoClient() +@pytest_asyncio.fixture(scope="function") +async def async_mongodb(event_loop, pytestconfig): + client = AsyncMockMongoClient(mongomock.MongoClient()) db = client["pytest"] await clean_database(db) await load_fixtures(db, pytestconfig) return db -@pytest.fixture(scope="function") -async def async_mongodb_client(pytestconfig): - client = AsyncMockMongoClient() +@pytest_asyncio.fixture(scope="function") +async def async_mongodb_client(event_loop, pytestconfig): + client = AsyncMockMongoClient(mongomock.MongoClient()) db = client["pytest"] await clean_database(db) await load_fixtures(db, pytestconfig) @@ -138,7 +165,7 @@ async def async_mongodb_client(pytestconfig): async def clean_database(db): collections = await db.list_collection_names() for name in collections: - db.drop_collection(name) + await db.drop_collection(name) async def load_fixtures(db, config): diff --git a/requirements.txt b/requirements.txt index 978fc84..3ce1cbb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ mongomock>=3.22.1 -pyyaml>=5.1 +pymongo>=3.10 pytest-asyncio>=0.11.0 pytest>=5.4 -pymongo>=3.10 \ No newline at end of file +pyyaml>=5.1 diff --git a/tests/unit/test_plugin.py b/tests/unit/test_plugin.py index 7e4613d..bc59197 100644 --- a/tests/unit/test_plugin.py +++ b/tests/unit/test_plugin.py @@ -1,6 +1,7 @@ from bson import ObjectId from pytest_async_mongodb import plugin import pytest +from pymongo import InsertOne, DESCENDING pytestmark = pytest.mark.asyncio @@ -165,3 +166,55 @@ async def test_find_sorted_with_filter(async_mongodb): "winner": "France", }, ] + + +async def test_bulk_write_and_to_list(async_mongodb): + await async_mongodb.championships.bulk_write( + [ + InsertOne({"_id": 1, "a": 22}), + InsertOne({"_id": 2, "a": 22}), + InsertOne({"_id": 3, "a": 33}), + ] + ) + result = async_mongodb.championships.find({"a": 22}) + docs = await result.to_list() + assert len(docs) == 2 + assert docs[0]["a"] == 22 + assert docs[1]["a"] == 22 + + +async def test_estimated_document_count(async_mongodb): + assert await async_mongodb.championships.estimated_document_count() == 4 + + +async def test_find_one_and_update(async_mongodb): + await async_mongodb.championships.find_one_and_update( + filter={"_id": ObjectId("608b0151a20cf0c679939f59")}, + update={"$set": {"year": 2022}}, + ) + doc = await async_mongodb.championships.find_one( + {"_id": ObjectId("608b0151a20cf0c679939f59")} + ) + assert doc["year"] == 2022 + + +async def test_chained_operations(async_mongodb): + docs = ( + await async_mongodb.championships.find() + .sort("year", DESCENDING) + .skip(1) + .limit(2) + .to_list() + ) + assert len(docs) == 2 + assert docs[0]["year"] == 2014 + assert docs[1]["year"] == 2010 + docs = ( + await async_mongodb.championships.find() + .sort("year", DESCENDING) + .skip(3) + .limit(2) + .to_list() + ) + assert len(docs) == 1 + assert docs[0]["year"] == 2006