diff --git a/.gitignore b/.gitignore index 94e70a7..9a87d64 100644 --- a/.gitignore +++ b/.gitignore @@ -1,29 +1,58 @@ -# Ignore test-results -test-results.xml +# Created by https://www.toptal.com/developers/gitignore/api/python,linux,visualstudiocode +# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux,visualstudiocode + +### vscode ### +.vscode + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] +*$py.class # C extensions *.so # Distribution / packaging .Python -env/ build/ develop-eggs/ dist/ +downloads/ eggs/ .eggs/ -lib/ -lib64/ parts/ sdist/ var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec # Installer logs pip-log.txt @@ -32,35 +61,125 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage +.coverage.* .cache nosetests.xml coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +pytestdebug.log # Translations *.mo - -# Mr Developer -.mr.developer.cfg -.project -.pydevproject - -# Rope -.ropeproject +*.pot # Django stuff: *.log -*.pot +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy # Sphinx documentation docs/_build/ +doc/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +# .env +.env/ +.venv/ +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pythonenv* + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# operating system-related files +# file properties cache/storage on macOS +*.DS_Store +# thumbnail cache on Windows +Thumbs.db + +# profiling data +.prof + + +### VisualStudioCode ### +.vscode/* +!.vscode/tasks.json +!.vscode/launch.json +*.code-workspace + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide -# IDEs -.idea -.sublime -.netbeans -*.swa -*.swp -*.swo +# End of https://www.toptal.com/developers/gitignore/api/python,linux,visualstudiocode AUTHORS diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..dfa28a8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Gonzalo Verussa + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/pytest.ini b/pytest.ini index 46cdef2..ecec326 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,5 @@ [pytest] +asyncio_mode = auto async_mongodb_fixture_dir = tests/unit/fixtures ; This tests specifying exactly which fixtures to load diff --git a/pytest_async_mongodb/plugin.py b/pytest_async_mongodb/plugin.py index 329a029..b7c82da 100644 --- a/pytest_async_mongodb/plugin.py +++ b/pytest_async_mongodb/plugin.py @@ -1,15 +1,12 @@ +from bson import json_util import asyncio import os import functools import json import codecs -import types - import mongomock -import pytest import yaml -from bson import json_util - +import pytest_asyncio _cache = {} @@ -17,93 +14,106 @@ def pytest_addoption(parser): parser.addini( - name='async_mongodb_fixtures', - help='Load these fixtures for tests', - type='linelist') + name="async_mongodb_fixtures", + help="Load these fixtures for tests", + type="linelist", + ) parser.addini( - name='async_mongodb_fixture_dir', - help='Try loading fixtures from this directory', - default=os.getcwd()) + name="async_mongodb_fixture_dir", + help="Try loading fixtures from this directory", + default=os.getcwd(), + ) parser.addoption( - '--async_mongodb-fixture-dir', - help='Try loading fixtures from this directory') + "--async_mongodb-fixture-dir", help="Try loading fixtures from this directory" + ) -def wrapper(func): - @functools.wraps(func) +def async_decorator(func): async def wrapped(*args, **kwargs): - coro_func = asyncio.coroutine(func) - return await coro_func(*args, **kwargs) + return func(*args, **kwargs) + return wrapped -class AsyncClassMethod(object): - - ASYNC_METHODS = [] - - def __getattribute__(self, name): - attr = super(AsyncClassMethod, self).__getattribute__(name) - if type(attr) == types.MethodType and name in self.ASYNC_METHODS: - attr = wrapper(attr) - return attr - - -class AsyncCollection(AsyncClassMethod, mongomock.Collection): - - ASYNC_METHODS = [ - 'find', - 'find_one', - 'find_one_and_delete', - 'find_one_and_replace', - 'find_one_and_update', - 'find_and_modify', - 'save', - 'delete_one', - 'delete_many', - 'count', - 'insert_one', - 'insert_many', - 'update_one', - 'update_many', - 'replace_one', - 'count_documents', - 'estimated_document_count', - 'drop', - 'create_index', - 'ensure_index', - 'map_reduce', - ] - - async def find_one(self, filter=None, *args, **kwargs): - import collections - # Allow calling find_one with a non-dict argument that gets used as - # the id for the query. - if filter is None: - filter = {} - if not isinstance(filter, collections.Mapping): - filter = {'_id': filter} - - cursor = await self.find(filter, *args, **kwargs) +def async_wrap(obj): + # wrap all the public interfaces except the one has been re-defined in obj + for item in dir(obj._base_sync_obj): + if not item.startswith("_"): + member = getattr(obj._base_sync_obj, item) + if callable(member) and item not in dir(obj): + setattr(obj, item, async_decorator(member)) + + +class AsyncCursor(mongomock.collection.Cursor): + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self) + except StopIteration: + raise StopAsyncIteration() + + async def to_list(self, length=None): + the_list = [] + try: + while length is None or len(the_list) < length: + the_list.append(next(self)) + finally: + return the_list + + +class AsyncCommandCursor(mongomock.command_cursor.CommandCursor): + def __aiter__(self): + return self + + async def __anext__(self): try: - return next(cursor) + return next(self) except StopIteration: - return None + raise StopAsyncIteration() + + async def to_list(self, length=None): + the_list = [] + try: + while length is None or len(the_list) < length: + the_list.append(next(self)) + finally: + return the_list + +class AsyncCollection: + def __init__(self, mongomock_collection): + self._base_sync_obj = mongomock_collection + async_wrap(self) -class AsyncDatabase(AsyncClassMethod, mongomock.Database): + def find(self, *args, **kwargs) -> AsyncCursor: + cursor = self._base_sync_obj.find(*args, **kwargs) + cursor.__class__ = AsyncCursor + return cursor - ASYNC_METHODS = [ - 'collection_names' - ] + def aggregate(self, *args, **kwargs) -> AsyncCommandCursor: + cursor = self._base_sync_obj.aggregate(*args, **kwargs) + cursor.__class__ = AsyncCommandCursor + return cursor - def get_collection(self, name, codec_options=None, read_preference=None, - write_concern=None): - collection = self._collections.get(name) - if collection is None: - collection = self._collections[name] = AsyncCollection(self, name) - return collection + +class AsyncDatabase: + def __init__(self, mongomock_db): + self._base_sync_obj = mongomock_db + async_wrap(self) + + def __getattr__(self, attr): + return self[attr] + + def __getitem__(self, db_name): + return self.get_collection(db_name) + + def get_collection(self, *args, **kwargs) -> AsyncCollection: + collection = self._base_sync_obj.get_collection(*args, **kwargs) + return AsyncCollection(collection) class Session: @@ -114,54 +124,60 @@ async def __aexit__(self, exc_type, exc, tb): await asyncio.sleep(0) -class AsyncMockMongoClient(mongomock.MongoClient): +class AsyncMockMongoClient: + def __init__(self, mongomock_client): + self._base_sync_obj = mongomock_client + async_wrap(self) + + def __getattr__(self, attr): + return self[attr] + + def __getitem__(self, db_name): + return self.get_database(db_name) - def get_database(self, name, codec_options=None, read_preference=None, - write_concern=None): - db = self._databases.get(name) - if db is None: - db = self._databases[name] = AsyncDatabase(self, name) - return db + def get_database(self, *args, **kwargs) -> AsyncDatabase: + db = self._base_sync_obj.get_database(*args, **kwargs) + return AsyncDatabase(db) async def start_session(self, **kwargs): await asyncio.sleep(0) return Session() -@pytest.fixture(scope='function') -async def async_mongodb(pytestconfig): - client = AsyncMockMongoClient() - db = client['pytest'] +@pytest_asyncio.fixture(scope="function") +async def async_mongodb(event_loop, pytestconfig): + client = AsyncMockMongoClient(mongomock.MongoClient()) + db = client["pytest"] await clean_database(db) await load_fixtures(db, pytestconfig) return db -@pytest.fixture(scope='function') -async def async_mongodb_client(pytestconfig): - client = AsyncMockMongoClient() - db = client['pytest'] +@pytest_asyncio.fixture(scope="function") +async def async_mongodb_client(event_loop, pytestconfig): + client = AsyncMockMongoClient(mongomock.MongoClient()) + db = client["pytest"] await clean_database(db) await load_fixtures(db, pytestconfig) return client async def clean_database(db): - collections = await db.collection_names(include_system_collections=False) + collections = await db.list_collection_names() for name in collections: - db.drop_collection(name) + await db.drop_collection(name) async def load_fixtures(db, config): - option_dir = config.getoption('async_mongodb_fixture_dir') - ini_dir = config.getini('async_mongodb_fixture_dir') - fixtures = config.getini('async_mongodb_fixtures') + option_dir = config.getoption("async_mongodb_fixture_dir") + ini_dir = config.getini("async_mongodb_fixture_dir") + fixtures = config.getini("async_mongodb_fixtures") basedir = option_dir or ini_dir for file_name in os.listdir(basedir): collection, ext = os.path.splitext(os.path.basename(file_name)) - file_format = ext.strip('.') - supported = file_format in ('json', 'yaml') + file_format = ext.strip(".") + supported = file_format in ("json", "yaml") selected = fixtures and collection in fixtures if selected and supported: path = os.path.join(basedir, file_name) @@ -169,16 +185,16 @@ async def load_fixtures(db, config): async def load_fixture(db, collection, path, file_format): - if file_format == 'json': + if file_format == "json": loader = functools.partial(json.load, object_hook=json_util.object_hook) - elif file_format == 'yaml': - loader = yaml.load + elif file_format == "yaml": + loader = functools.partial(yaml.load, Loader=yaml.FullLoader) else: return try: docs = _cache[path] except KeyError: - with codecs.open(path, encoding='utf-8') as fp: + with codecs.open(path, encoding="utf-8") as fp: _cache[path] = docs = loader(fp) for document in docs: diff --git a/requirements.txt b/requirements.txt index cb7c20d..3ce1cbb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -mongomock -pyyaml -pytest-asyncio -pytest>=2.5.2 +mongomock>=3.22.1 +pymongo>=3.10 +pytest-asyncio>=0.11.0 +pytest>=5.4 +pyyaml>=5.1 diff --git a/setup.cfg b/setup.cfg index 757aca4..6f90a62 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,15 +19,16 @@ classifier = Topic :: Database Topic :: Software Development :: Libraries - [files] packages = pytest_async_mongodb - [entry_points] pytest11 = pytest-async-mongodb = pytest_async_mongodb.plugin - [wheel] universal = 1 + +[pbr] +skip_changelog = true +skip_authors = true \ No newline at end of file diff --git a/tests/unit/fixtures/championships.json b/tests/unit/fixtures/championships.json index 695761a..1ad17b9 100644 --- a/tests/unit/fixtures/championships.json +++ b/tests/unit/fixtures/championships.json @@ -1,4 +1,11 @@ [ + + { + "_id": {"$oid": "608b0151a20cf0c679939f59"}, + "year": 2018, + "host": "Russia", + "winner": "France" + }, { "_id": {"$oid": "55d2db06f4811f83a1f27be8"}, "year": 2014, diff --git a/tests/unit/test_plugin.py b/tests/unit/test_plugin.py index b0a5e13..bc59197 100644 --- a/tests/unit/test_plugin.py +++ b/tests/unit/test_plugin.py @@ -1,50 +1,220 @@ -import pytest +from bson import ObjectId from pytest_async_mongodb import plugin +import pytest +from pymongo import InsertOne, DESCENDING + +pytestmark = pytest.mark.asyncio -@pytest.mark.asyncio async def test_load(async_mongodb): - collection_names = await async_mongodb.collection_names() - assert 'players' in collection_names - assert 'championships' in collection_names + collection_names = await async_mongodb.list_collection_names() + assert "players" in collection_names + assert "championships" in collection_names assert len(plugin._cache.keys()) == 2 await check_players(async_mongodb.players) await check_championships(async_mongodb.championships) -@pytest.mark.asyncio async def check_players(players): count = await players.count_documents({}) assert count == 2 - await check_keys_in_docs(players, ['name', 'surname', 'position']) - manuel = await players.find_one({'name': 'Manuel'}) - assert manuel['surname'] == 'Neuer' - assert manuel['position'] == 'keeper' + await check_keys_in_docs(players, ["name", "surname", "position"]) + manuel = await players.find_one({"name": "Manuel"}) + assert manuel["surname"] == "Neuer" + assert manuel["position"] == "keeper" -@pytest.mark.asyncio async def check_championships(championships): count = await championships.count_documents({}) - assert count == 3 - await check_keys_in_docs(championships, ['year', 'host', 'winner']) + assert count == 4 + await check_keys_in_docs(championships, ["year", "host", "winner"]) -@pytest.mark.asyncio async def check_keys_in_docs(collection, keys): - docs = await collection.find() + docs = collection.find() for doc in docs: for key in keys: assert key in doc -@pytest.mark.asyncio async def test_insert(async_mongodb): - await async_mongodb.players.insert_one({ - 'name': 'Bastian', - 'surname': 'Schweinsteiger', - 'position': 'midfield' - }) - count = await async_mongodb.players.count_documents({}) - bastian = await async_mongodb.players.find_one({'name': 'Bastian'}) - assert count == 3 - assert bastian.get('name') == 'Bastian' + count_before = await async_mongodb.players.count_documents({}) + await async_mongodb.players.insert_one( + {"name": "Bastian", "surname": "Schweinsteiger", "position": "midfield"} + ) + count_after = await async_mongodb.players.count_documents({}) + bastian = await async_mongodb.players.find_one({"name": "Bastian"}) + assert count_after == count_before + 1 + assert bastian.get("name") == "Bastian" + + +async def test_find_one(async_mongodb): + doc = await async_mongodb.championships.find_one() + assert doc == { + "_id": ObjectId("608b0151a20cf0c679939f59"), + "year": 2018, + "host": "Russia", + "winner": "France", + } + + +async def test_find(async_mongodb): + docs = async_mongodb.championships.find() + docs_list = [] + async for doc in docs: + docs_list.append(doc) + assert docs_list == [ + { + "_id": ObjectId("608b0151a20cf0c679939f59"), + "year": 2018, + "host": "Russia", + "winner": "France", + }, + { + "_id": ObjectId("55d2db06f4811f83a1f27be8"), + "year": 2014, + "host": "Brazil", + "winner": "Germany", + }, + { + "_id": ObjectId("55d2db19f4811f83a1f27be9"), + "year": 2010, + "host": "South Africa", + "winner": "Spain", + }, + { + "_id": ObjectId("55d2db30f4811f83a1f27bea"), + "year": 2006, + "host": "Germany", + "winner": "France", + }, + ] + + +async def test_find_with_filter(async_mongodb): + docs = async_mongodb.championships.find({"winner": "France"}) + docs_list = [] + async for doc in docs: + docs_list.append(doc) + assert docs_list == [ + { + "_id": ObjectId("608b0151a20cf0c679939f59"), + "year": 2018, + "host": "Russia", + "winner": "France", + }, + { + "_id": ObjectId("55d2db30f4811f83a1f27bea"), + "year": 2006, + "host": "Germany", + "winner": "France", + }, + ] + + +async def test_find_sorted(async_mongodb): + docs = async_mongodb.championships.find(sort=[("year", 1)]) + docs_list = [] + async for doc in docs: + docs_list.append(doc) + assert docs_list == [ + { + "_id": ObjectId("55d2db30f4811f83a1f27bea"), + "year": 2006, + "host": "Germany", + "winner": "France", + }, + { + "_id": ObjectId("55d2db19f4811f83a1f27be9"), + "year": 2010, + "host": "South Africa", + "winner": "Spain", + }, + { + "_id": ObjectId("55d2db06f4811f83a1f27be8"), + "year": 2014, + "host": "Brazil", + "winner": "Germany", + }, + { + "_id": ObjectId("608b0151a20cf0c679939f59"), + "year": 2018, + "host": "Russia", + "winner": "France", + }, + ] + + +async def test_find_sorted_with_filter(async_mongodb): + docs = async_mongodb.championships.find( + filter={"winner": "France"}, sort=[("year", 1)] + ) + docs_list = [] + async for doc in docs: + docs_list.append(doc) + assert docs_list == [ + { + "_id": ObjectId("55d2db30f4811f83a1f27bea"), + "year": 2006, + "host": "Germany", + "winner": "France", + }, + { + "_id": ObjectId("608b0151a20cf0c679939f59"), + "year": 2018, + "host": "Russia", + "winner": "France", + }, + ] + + +async def test_bulk_write_and_to_list(async_mongodb): + await async_mongodb.championships.bulk_write( + [ + InsertOne({"_id": 1, "a": 22}), + InsertOne({"_id": 2, "a": 22}), + InsertOne({"_id": 3, "a": 33}), + ] + ) + result = async_mongodb.championships.find({"a": 22}) + docs = await result.to_list() + assert len(docs) == 2 + assert docs[0]["a"] == 22 + assert docs[1]["a"] == 22 + + +async def test_estimated_document_count(async_mongodb): + assert await async_mongodb.championships.estimated_document_count() == 4 + + +async def test_find_one_and_update(async_mongodb): + await async_mongodb.championships.find_one_and_update( + filter={"_id": ObjectId("608b0151a20cf0c679939f59")}, + update={"$set": {"year": 2022}}, + ) + doc = await async_mongodb.championships.find_one( + {"_id": ObjectId("608b0151a20cf0c679939f59")} + ) + assert doc["year"] == 2022 + + +async def test_chained_operations(async_mongodb): + docs = ( + await async_mongodb.championships.find() + .sort("year", DESCENDING) + .skip(1) + .limit(2) + .to_list() + ) + assert len(docs) == 2 + assert docs[0]["year"] == 2014 + assert docs[1]["year"] == 2010 + docs = ( + await async_mongodb.championships.find() + .sort("year", DESCENDING) + .skip(3) + .limit(2) + .to_list() + ) + assert len(docs) == 1 + assert docs[0]["year"] == 2006