Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# Created by https://www.toptal.com/developers/gitignore/api/python,linux,visualstudiocode
# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux,visualstudiocode

### vscode ###
.vscode

### Linux ###
*~

Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[pytest]
asyncio_mode = auto
async_mongodb_fixture_dir =
tests/unit/fixtures
; This tests specifying exactly which fixtures to load
Expand Down
125 changes: 76 additions & 49 deletions pytest_async_mongodb/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import json
import codecs
import mongomock
import pytest
import yaml
import pytest_asyncio

_cache = {}

Expand Down Expand Up @@ -37,11 +37,13 @@ async def wrapped(*args, **kwargs):
return wrapped


def wrapp_methods(cls):
for method_name in cls.ASYNC_METHODS:
method = getattr(cls, method_name)
setattr(cls, method_name, async_decorator(method))
return cls
def async_wrap(obj):
# wrap all the public interfaces except the one has been re-defined in obj
for item in dir(obj._base_sync_obj):
if not item.startswith("_"):
member = getattr(obj._base_sync_obj, item)
if callable(member) and item not in dir(obj):
setattr(obj, item, async_decorator(member))


class AsyncCursor(mongomock.collection.Cursor):
Expand All @@ -54,48 +56,64 @@ async def __anext__(self):
except StopIteration:
raise StopAsyncIteration()

async def to_list(self, length=None):
the_list = []
try:
while length is None or len(the_list) < length:
the_list.append(next(self))
finally:
return the_list


class AsyncCommandCursor(mongomock.command_cursor.CommandCursor):
def __aiter__(self):
return self

async def __anext__(self):
try:
return next(self)
except StopIteration:
raise StopAsyncIteration()

async def to_list(self, length=None):
the_list = []
try:
while length is None or len(the_list) < length:
the_list.append(next(self))
finally:
return the_list

@wrapp_methods
class AsyncCollection(mongomock.Collection):

ASYNC_METHODS = [
"find_one",
"find_one_and_delete",
"find_one_and_replace",
"find_one_and_update",
"find_and_modify",
"save",
"delete_one",
"delete_many",
"count",
"insert_one",
"insert_many",
"update_one",
"update_many",
"replace_one",
"count_documents",
"estimated_document_count",
"drop",
"create_index",
"ensure_index",
"map_reduce",
]

class AsyncCollection:
def __init__(self, mongomock_collection):
self._base_sync_obj = mongomock_collection
async_wrap(self)

def find(self, *args, **kwargs) -> AsyncCursor:
cursor = super().find(*args, **kwargs)
cursor = self._base_sync_obj.find(*args, **kwargs)
cursor.__class__ = AsyncCursor
return cursor

def aggregate(self, *args, **kwargs) -> AsyncCommandCursor:
cursor = self._base_sync_obj.aggregate(*args, **kwargs)
cursor.__class__ = AsyncCommandCursor
return cursor


class AsyncDatabase:
def __init__(self, mongomock_db):
self._base_sync_obj = mongomock_db
async_wrap(self)

@wrapp_methods
class AsyncDatabase(mongomock.Database):
def __getattr__(self, attr):
return self[attr]

ASYNC_METHODS = ["list_collection_names"]
def __getitem__(self, db_name):
return self.get_collection(db_name)

def get_collection(self, *args, **kwargs) -> AsyncCollection:
collection = super().get_collection(*args, **kwargs)
collection.__class__ = AsyncCollection
return collection
collection = self._base_sync_obj.get_collection(*args, **kwargs)
return AsyncCollection(collection)


class Session:
Expand All @@ -106,29 +124,38 @@ async def __aexit__(self, exc_type, exc, tb):
await asyncio.sleep(0)


class AsyncMockMongoClient(mongomock.MongoClient):
class AsyncMockMongoClient:
def __init__(self, mongomock_client):
self._base_sync_obj = mongomock_client
async_wrap(self)

def __getattr__(self, attr):
return self[attr]

def __getitem__(self, db_name):
return self.get_database(db_name)

def get_database(self, *args, **kwargs) -> AsyncDatabase:
db = super().get_database(*args, **kwargs)
db.__class__ = AsyncDatabase
return db
db = self._base_sync_obj.get_database(*args, **kwargs)
return AsyncDatabase(db)

async def start_session(self, **kwargs):
await asyncio.sleep(0)
return Session()


@pytest.fixture(scope="function")
async def async_mongodb(pytestconfig):
client = AsyncMockMongoClient()
@pytest_asyncio.fixture(scope="function")
async def async_mongodb(event_loop, pytestconfig):
client = AsyncMockMongoClient(mongomock.MongoClient())
db = client["pytest"]
await clean_database(db)
await load_fixtures(db, pytestconfig)
return db


@pytest.fixture(scope="function")
async def async_mongodb_client(pytestconfig):
client = AsyncMockMongoClient()
@pytest_asyncio.fixture(scope="function")
async def async_mongodb_client(event_loop, pytestconfig):
client = AsyncMockMongoClient(mongomock.MongoClient())
db = client["pytest"]
await clean_database(db)
await load_fixtures(db, pytestconfig)
Expand All @@ -138,7 +165,7 @@ async def async_mongodb_client(pytestconfig):
async def clean_database(db):
collections = await db.list_collection_names()
for name in collections:
db.drop_collection(name)
await db.drop_collection(name)


async def load_fixtures(db, config):
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mongomock>=3.22.1
pyyaml>=5.1
pymongo>=3.10
pytest-asyncio>=0.11.0
pytest>=5.4
pymongo>=3.10
pyyaml>=5.1
53 changes: 53 additions & 0 deletions tests/unit/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from bson import ObjectId
from pytest_async_mongodb import plugin
import pytest
from pymongo import InsertOne, DESCENDING

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -165,3 +166,55 @@ async def test_find_sorted_with_filter(async_mongodb):
"winner": "France",
},
]


async def test_bulk_write_and_to_list(async_mongodb):
await async_mongodb.championships.bulk_write(
[
InsertOne({"_id": 1, "a": 22}),
InsertOne({"_id": 2, "a": 22}),
InsertOne({"_id": 3, "a": 33}),
]
)
result = async_mongodb.championships.find({"a": 22})
docs = await result.to_list()
assert len(docs) == 2
assert docs[0]["a"] == 22
assert docs[1]["a"] == 22


async def test_estimated_document_count(async_mongodb):
assert await async_mongodb.championships.estimated_document_count() == 4


async def test_find_one_and_update(async_mongodb):
await async_mongodb.championships.find_one_and_update(
filter={"_id": ObjectId("608b0151a20cf0c679939f59")},
update={"$set": {"year": 2022}},
)
doc = await async_mongodb.championships.find_one(
{"_id": ObjectId("608b0151a20cf0c679939f59")}
)
assert doc["year"] == 2022


async def test_chained_operations(async_mongodb):
docs = (
await async_mongodb.championships.find()
.sort("year", DESCENDING)
.skip(1)
.limit(2)
.to_list()
)
assert len(docs) == 2
assert docs[0]["year"] == 2014
assert docs[1]["year"] == 2010
docs = (
await async_mongodb.championships.find()
.sort("year", DESCENDING)
.skip(3)
.limit(2)
.to_list()
)
assert len(docs) == 1
assert docs[0]["year"] == 2006