From e18402df908d5b0b0e5d2c3a5f60b1f66c36e24e Mon Sep 17 00:00:00 2001 From: gurcuff91 Date: Sun, 7 Apr 2024 17:22:11 -0400 Subject: [PATCH] Implemented Sync support --- mongotoy/__init__.py | 3 +- mongotoy/db.py | 812 +++++++++++++++++++++++++++++-------------- mongotoy/mappers.py | 15 +- mongotoy/sync.py | 71 ++++ mongotoy/types.py | 58 +++- pyproject.toml | 2 +- 6 files changed, 699 insertions(+), 262 deletions(-) create mode 100644 mongotoy/sync.py diff --git a/mongotoy/__init__.py b/mongotoy/__init__.py index fb1eca2..6cf7d51 100644 --- a/mongotoy/__init__.py +++ b/mongotoy/__init__.py @@ -1,4 +1,5 @@ from .fields import field, reference -from .documents import EmbeddedDocument, Document +from .documents import Document, EmbeddedDocument from .db import Engine +from .sync import enable_sync_mode diff --git a/mongotoy/db.py b/mongotoy/db.py index 2b67775..ea71467 100644 --- a/mongotoy/db.py +++ b/mongotoy/db.py @@ -2,7 +2,9 @@ import datetime import functools import inspect +import math import mimetypes +import os import typing import bson @@ -13,7 +15,7 @@ from motor.motor_gridfs import AgnosticGridFSBucket from pymongo.read_concern import ReadConcern -from mongotoy import documents, expressions, references, fields, types +from mongotoy import documents, expressions, references, fields, types, sync from mongotoy.errors import EngineError, NoResultsError, ManyResultsError, SessionError __all__ = ( @@ -83,18 +85,6 @@ def __init__( self._db_client = None self._migration_lock = asyncio.Lock() - async def connect(self, *conn, ping: bool = False): - """ - Connects to the MongoDB server. - - Args: - *conn: Connection arguments for AsyncIOMotorClient. - ping (bool): Whether to ping the server after connecting. 
- """ - self._db_client = AsyncIOMotorClient(*conn) - if ping: - await self._db_client.admin.command({'ping': 1}) - @property def client(self) -> AgnosticClient: """ @@ -124,60 +114,6 @@ def database(self) -> AgnosticDatabase: write_concern=self._write_concern ) - def session(self) -> 'Session': - """ - Creates a new MongoDB session. - - Returns: - Session: A new MongoDB session associated with the engine. - """ - return Session(engine=self) - - def collection(self, document_cls_or_name: typing.Type[T] | str) -> AgnosticCollection: - """ - Retrieves the MongoDB collection. - - Args: - document_cls_or_name (typing.Type[T] | str): The document class or collection name. - - Returns: - AgnosticCollection: The MongoDB collection. - """ - if not isinstance(document_cls_or_name, str): - return self._get_document_collection(document_cls_or_name) - - # noinspection PyTypeChecker - return self.database[document_cls_or_name].with_options( - codec_options=self._codec_options, - read_preference=self._read_preference, - read_concern=self._read_concern, - write_concern=self._write_concern - ) - - # noinspection SpellCheckingInspection - def gridfs( - self, - bucket_name: str = 'fs', - chunk_size_bytes: int = gridfs.DEFAULT_CHUNK_SIZE - ) -> AgnosticGridFSBucket: - """ - Retrieves the GridFS bucket. - - Args: - bucket_name (str): The name of the GridFS bucket. - chunk_size_bytes (int): The chunk size in bytes. - - Returns: - AgnosticGridFSBucket: The GridFS bucket. 
- """ - return AsyncIOMotorGridFSBucket( - database=self.database, - bucket_name=bucket_name, - chunk_size_bytes=chunk_size_bytes, - write_concern=self.database.write_concern, - read_preference=self.database.read_preference - ) - def _get_document_indexes( self, document_cls: typing.Type[documents.BaseDocument], @@ -367,16 +303,30 @@ async def _exec_seeding( # Skip if seeding already applied if skip_exist: + # noinspection PyProtectedMember if await session.objects(Seeding).filter( Seeding.function == func_path - ).count(): + )._count(): do_seeding = False if do_seeding: await func(session) - await session.save(Seeding(function=func_path)) + # noinspection PyProtectedMember + await session._save(Seeding(function=func_path)) - async def migrate( + async def _connect(self, *conn, ping: bool = False): + """ + Connects to the MongoDB server. + + Args: + *conn: Connection arguments for AsyncIOMotorClient. + ping (bool): Whether to ping the server after connecting. + """ + self._db_client = AsyncIOMotorClient(*conn) + if ping: + await self._db_client.admin.command({'ping': 1}) + + async def _migrate( self, document_cls: typing.Type[T], session: 'Session' = None @@ -391,7 +341,7 @@ async def migrate( driver_session = session.driver_session if session else None await self._exec_migration(document_cls, driver_session=driver_session) - async def migrate_all( + async def _migrate_all( self, documents_cls: list[typing.Type[T]], session: 'Session' = None @@ -413,7 +363,7 @@ async def migrate_all( ) for doc_cls in documents_cls if doc_cls.__collection_name__ not in collections ]) - async def seeding( + async def _seeding( self, func: typing.Callable[['Session'], typing.Coroutine[typing.Any, typing.Any, None]], session: 'Session' = None @@ -427,7 +377,7 @@ async def seeding( """ await self._exec_seeding(func, session=session) - async def seeding_all( + async def _seeding_all( self, funcs: list[typing.Callable[['Session'], typing.Coroutine[typing.Any, typing.Any, None]]], session: 
'Session' = None @@ -439,7 +389,8 @@ async def seeding_all( funcs (list[Callable[['Session'], Coroutine[Any, Any, None]]]): List of seeding functions. session (Session, optional): The session object. """ - seeds = await session.objects(Seeding).fetch() + # noinspection PyProtectedMember + seeds = await session.objects(Seeding)._fetch() seeds = [s.function for s in seeds] # noinspection PyUnresolvedReferences await asyncio.gather(*[ @@ -450,8 +401,102 @@ async def seeding_all( ) for func in funcs if f'{func.__module__}.{func.__name__}' not in seeds ]) + def session(self) -> 'Session': + """ + Creates a new MongoDB session. + + Returns: + Session: A new MongoDB session associated with the engine. + """ + return Session(engine=self) + + def collection(self, document_cls_or_name: typing.Type[T] | str) -> AgnosticCollection: + """ + Retrieves the MongoDB collection. + + Args: + document_cls_or_name (typing.Type[T] | str): The document class or collection name. + + Returns: + AgnosticCollection: The MongoDB collection. + """ + if not isinstance(document_cls_or_name, str): + return self._get_document_collection(document_cls_or_name) + + # noinspection PyTypeChecker + return self.database[document_cls_or_name].with_options( + codec_options=self._codec_options, + read_preference=self._read_preference, + read_concern=self._read_concern, + write_concern=self._write_concern + ) + + # noinspection SpellCheckingInspection + def gridfs( + self, + bucket_name: str = 'fs', + chunk_size_bytes: int = gridfs.DEFAULT_CHUNK_SIZE + ) -> AgnosticGridFSBucket: + """ + Retrieves the GridFS bucket. + + Args: + bucket_name (str): The name of the GridFS bucket. + chunk_size_bytes (int): The chunk size in bytes. -class Session(typing.AsyncContextManager): + Returns: + AgnosticGridFSBucket: The GridFS bucket. 
+ """ + return AsyncIOMotorGridFSBucket( + database=self.database, + bucket_name=bucket_name, + chunk_size_bytes=chunk_size_bytes, + write_concern=self.database.write_concern, + read_preference=self.database.read_preference + ) + + @sync.proxy + def connect( + self, + *conn, + ping: bool = False + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._connect(*conn, ping) + + @sync.proxy + async def migrate( + self, + document_cls: typing.Type[T], + session: 'Session' = None + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._migrate(document_cls, session) + + @sync.proxy + def migrate_all( + self, + documents_cls: list[typing.Type[T]], + session: 'Session' = None + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._migrate_all(documents_cls, session) + + @sync.proxy + def seeding( + self, + func: typing.Callable[['Session'], typing.Coroutine[typing.Any, typing.Any, None]], + session: 'Session' = None + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._seeding(func, session) + + @sync.proxy + def seeding_all( + self, + funcs: list[typing.Callable[['Session'], typing.Coroutine[typing.Any, typing.Any, None]]], + session: 'Session' = None + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._seeding_all(funcs, session) + + +class Session(typing.AsyncContextManager, typing.ContextManager): """ Represents a MongoDB session for performing database operations within a transaction-like context. @@ -491,73 +536,45 @@ def driver_session(self) -> AgnosticClientSession: raise SessionError('Session not started') return self._driver_session - async def start(self): - """ - Starts the MongoDB session. - - Raises: - EngineError: If the session is already started. - """ - if not self._driver_session: - self._driver_session = await self.engine.client.start_session() - - async def end(self): - """ - Ends the MongoDB session. 
- - Raises: - EngineError: If the session is not started. - """ - if self.driver_session: - await self.driver_session.end_session() - self._driver_session = None - async def __aenter__(self) -> 'Session': """ Enables the use of the 'async with' statement. """ - await self.start() + await self._start() return self - async def __aexit__(self, __exc_type, __exc_value, __traceback) -> None: + async def __aexit__(self, exc_type, exc_value, traceback) -> None: """ Enables the use of the 'async with' statement. Ends the session upon exiting the context. """ - await self.end() + await self._end() - def transaction(self) -> 'Transaction': - """ - Creates a new MongoDB transaction. + def __enter__(self) -> 'Session': + return sync.run_sync(self.__aenter__)() - Returns: - Transaction: A new MongoDB transaction associated with the engine - """ - return Transaction(session=self) + def __exit__(self, exc_type, exc_value, traceback): + sync.run_sync(self.__aexit__)(exc_type, exc_value, traceback) - def objects(self, document_cls: typing.Type[T], dereference_deep: int = 0) -> 'Objects[T]': + async def _start(self): """ - Returns an object manager for the specified document class. - - Args: - document_cls (typing.Type[T]): The document class. - dereference_deep (int): Depth of dereferencing. + Starts the MongoDB session. - Returns: - Objects[T]: An object manager. + Raises: + EngineError: If the session is already started. """ - return Objects(document_cls, session=self, dereference_deep=dereference_deep) + if not self._driver_session: + self._driver_session = await self.engine.client.start_session() - def fs(self, chunk_size_bytes: int = gridfs.DEFAULT_CHUNK_SIZE) -> 'FsBucket': + async def _end(self): """ - Returns a GridFS bucket manager. - - Args: - chunk_size_bytes (int): The chunk size in bytes. + Ends the MongoDB session. - Returns: - FsBucket: A GridFS bucket manager. + Raises: + EngineError: If the session is not started. 
""" - return FsBucket(self, chunk_size_bytes=chunk_size_bytes) + if self.driver_session: + await self.driver_session.end_session() + self._driver_session = None async def _save_references(self, doc: T): """ @@ -574,12 +591,12 @@ async def _save_references(self, doc: T): if not reference.is_many: obj = [obj] operations.append( - self.save_all(obj, save_references=True) + self._save_all(obj, save_references=True) ) await asyncio.gather(*operations) - async def save(self, doc: T, save_references: bool = False): + async def _save(self, doc: T, save_references: bool = False): """ Saves a document to the database. @@ -592,17 +609,15 @@ async def save(self, doc: T, save_references: bool = False): operations.append(self._save_references(doc)) son = doc.dump_bson() - operations.append( - self.engine.collection(doc.__collection_name__).update_one( - filter=Query.Eq('_id', son.pop('_id')), - update={'$set': son}, - upsert=True, - session=self.driver_session - ) - ) await asyncio.gather(*operations) + await self.engine.collection(doc.__collection_name__).update_one( + filter=Query.Eq('_id', son.pop('_id')), + update={'$set': son}, + upsert=True, + session=self.driver_session + ) - async def save_all(self, docs: list[T], save_references: bool = False): + async def _save_all(self, docs: list[T], save_references: bool = False): """ Saves a list of documents to the database. @@ -610,7 +625,7 @@ async def save_all(self, docs: list[T], save_references: bool = False): docs (list[T]): The list of document objects to save. save_references (bool): Whether to save referenced documents. 
""" - await asyncio.gather(*[self.save(i, save_references) for i in docs if i is not None]) + await asyncio.gather(*[self._save(i, save_references) for i in docs if i is not None]) async def _delete_cascade(self, doc: T): """ @@ -653,15 +668,15 @@ async def _delete_cascade(self, doc: T): break setattr(ref_doc, field_name, value) # Apply update - operations.append(self.save(ref_doc)) + operations.append(self._save(ref_doc)) # Apply delete if do_delete: - operations.append(self.delete(ref_doc, delete_cascade=True)) + operations.append(self._delete(ref_doc, delete_cascade=True)) await asyncio.gather(*operations) - async def delete(self, doc: T, delete_cascade: bool = False): + async def _delete(self, doc: T, delete_cascade: bool = False): """ Deletes a document from the database. @@ -679,7 +694,7 @@ async def delete(self, doc: T, delete_cascade: bool = False): session=self.driver_session ) - async def delete_all(self, docs: list[T], delete_cascade: bool = False): + async def _delete_all(self, docs: list[T], delete_cascade: bool = False): """ Deletes a list of documents from the database. @@ -687,10 +702,84 @@ async def delete_all(self, docs: list[T], delete_cascade: bool = False): docs (list[T]): The list of document objects to delete. delete_cascade (bool): Whether to delete referenced documents. """ - await asyncio.gather(*[self.delete(i, delete_cascade) for i in docs if i is not None]) + await asyncio.gather(*[self._delete(i, delete_cascade) for i in docs if i is not None]) + + def transaction(self) -> 'Transaction': + """ + Creates a new MongoDB transaction. + + Returns: + Transaction: A new MongoDB transaction associated with the engine + """ + return Transaction(session=self) + def objects(self, document_cls: typing.Type[T], dereference_deep: int = 0) -> 'Objects[T]': + """ + Returns an object manager for the specified document class. -class Transaction(typing.AsyncContextManager): + Args: + document_cls (typing.Type[T]): The document class. 
+ dereference_deep (int): Depth of dereferencing. + + Returns: + Objects[T]: An object manager. + """ + return Objects(document_cls, session=self, dereference_deep=dereference_deep) + + def fs(self, chunk_size_bytes: int = gridfs.DEFAULT_CHUNK_SIZE) -> 'FsBucket': + """ + Returns a GridFS bucket manager. + + Args: + chunk_size_bytes (int): The chunk size in bytes. + + Returns: + FsBucket: A GridFS bucket manager. + """ + return FsBucket(self, chunk_size_bytes=chunk_size_bytes) + + @sync.proxy + def start(self) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._start() + + @sync.proxy + def end(self) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._end() + + @sync.proxy + def save( + self, + doc: T, + save_references: bool = False + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._save(doc, save_references) + + @sync.proxy + def save_all( + self, + docs: list[T], + save_references: bool = False + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._save_all(docs, save_references) + + @sync.proxy + def delete( + self, + doc: T, + delete_cascade: bool = False + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._delete(doc, delete_cascade) + + @sync.proxy + def delete_all( + self, + docs: list[T], + delete_cascade: bool = False + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._delete_all(docs, delete_cascade) + + +class Transaction(typing.AsyncContextManager, typing.ContextManager): """ Represents a MongoDB transaction for performing atomic operations within a session or engine context. @@ -719,24 +808,6 @@ def session(self) -> 'Session': """ return self._session - async def commit(self): - """ - Commits changes and closes the MongoDB transaction. - - Raises: - EngineError: If the transaction is not started. 
- """ - await self._session.driver_session.commit_transaction() - - async def abort(self): - """ - Discards changes and closes the MongoDB transaction. - - Raises: - EngineError: If the transaction is not started. - """ - await self._session.driver_session.abort_transaction() - async def __aenter__(self) -> 'Transaction': """ Enables the use of the 'async with' statement. @@ -756,9 +827,41 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> None: traceback: The exception traceback. """ if exc_value: - await self.abort() + await self._abort() else: - await self.commit() + await self._commit() + + def __enter__(self) -> 'Transaction': + return self + + def __exit__(self, exc_type, exc_value, traceback): + sync.run_sync(self.__aexit__)(exc_type, exc_value, traceback) + + async def _commit(self): + """ + Commits changes and closes the MongoDB transaction. + + Raises: + EngineError: If the transaction is not started. + """ + await self._session.driver_session.commit_transaction() + + async def _abort(self): + """ + Discards changes and closes the MongoDB transaction. + + Raises: + EngineError: If the transaction is not started. + """ + await self._session.driver_session.abort_transaction() + + @sync.proxy + def commit(self) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._commit() + + @sync.proxy + def abort(self) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._abort() class Objects(typing.Generic[T]): @@ -783,7 +886,7 @@ def __init__(self, document_cls: typing.Type[T], session: Session, dereference_d self._skip = 0 self._limit = 0 - def __copy_with__(self, **options) -> 'Objects[T]': + def __copy__(self, **options) -> 'Objects[T]': """ Creates a shallow copy of the query set with specified options. @@ -806,77 +909,6 @@ def __copy_with__(self, **options) -> 'Objects[T]': return objs - async def create(self, **data) -> T: - """ - Creates a new document in the database. 
- - Args: - **data: Keyword arguments representing the document data. - - Returns: - T: The newly created document instance. - """ - doc = self._document_cls(**data) - await self._session.save(doc, save_references=True) - return doc - - def filter(self, *queries: expressions.Query | bool, **filters) -> 'Objects[T]': - """ - Adds filter conditions to the query set. - - Args: - *queries (expressions.Query | bool): Variable number of filter expressions. - **filters: Keyword arguments representing additional filter conditions. - - Returns: - Objects[T]: The updated query set with added filter conditions. - """ - _filter = self._filter - for q in queries: - _filter = _filter & q - if filters: - _filter = _filter & expressions.Q(**filters) - return self.__copy_with__(_filter=_filter) - - def sort(self, *sorts: expressions.Sort) -> 'Objects[T]': - """ - Adds sort conditions to the query set. - - Args: - *sorts (expressions.Sort): Variable number of sort expressions. - - Returns: - Objects[T]: The updated query set with added sort conditions. - """ - _sort = self._sort - for sort in sorts: - _sort = _sort | expressions.Sort(sort) - return self.__copy_with__(_sort=_sort) - - def skip(self, skip: int) -> 'Objects[T]': - """ - Sets the number of documents to skip in the result set. - - Args: - skip (int): The number of documents to skip. - - Returns: - Objects[T]: The updated query set with the skip value set. - """ - return self.__copy_with__(_skip=skip) - - def limit(self, limit: int) -> 'Objects[T]': - """ - Sets the maximum number of documents to return. - - Args: - limit (int): The maximum number of documents to return. - - Returns: - Objects[T]: The updated query set with the limit value set. - """ - return self.__copy_with__(_limit=limit) - async def __aiter__(self) -> typing.AsyncGenerator[T, None]: """ Asynchronously iterates over the result set, executing the query. 
@@ -907,7 +939,26 @@ async def __aiter__(self) -> typing.AsyncGenerator[T, None]: async for data in cursor: yield self._document_cls(**data) - async def fetch(self) -> list[T]: + def __iter__(self) -> typing.Generator[T, None, None]: + for doc in sync.as_sync_gen(self.__aiter__()): + yield doc + + async def _create(self, **data) -> T: + """ + Creates a new document in the database. + + Args: + **data: Keyword arguments representing the document data. + + Returns: + T: The newly created document instance. + """ + doc = self._document_cls(**data) + # noinspection PyProtectedMember + await self._session._save(doc, save_references=True) + return doc + + async def _fetch(self) -> list[T]: """ Retrieves all documents in the result set. @@ -916,7 +967,7 @@ async def fetch(self) -> list[T]: """ return [doc async for doc in self] - async def fetch_one(self) -> T: + async def _fetch_one(self) -> T: """ Retrieves a specific document in the result set. @@ -927,7 +978,7 @@ async def fetch_one(self) -> T: NoResultsError: If no results are found. ManyResultsError: If more than one result is found. """ - docs = await self.limit(2).fetch() + docs = await self.limit(2)._fetch() if not docs: raise NoResultsError() if len(docs) > 1: @@ -935,7 +986,7 @@ async def fetch_one(self) -> T: return docs[0] # noinspection PyShadowingBuiltins - async def fetch_by_id(self, value: typing.Any) -> T: + async def _fetch_by_id(self, value: typing.Any) -> T: """ Retrieves a document by its identifier. @@ -948,9 +999,9 @@ async def fetch_by_id(self, value: typing.Any) -> T: # noinspection PyProtectedMember return await self.filter( self._document_cls.id == self._document_cls.id._field.mapper.validate(value) - ).fetch_one() + )._fetch_one() - async def count(self) -> int: + async def _count(self) -> int: """ Counts the number of documents in the result set. 
@@ -962,6 +1013,83 @@ async def count(self) -> int: session=self._session.driver_session ) + def filter(self, *queries: expressions.Query | bool, **filters) -> 'Objects[T]': + """ + Adds filter conditions to the query set. + + Args: + *queries (expressions.Query | bool): Variable number of filter expressions. + **filters: Keyword arguments representing additional filter conditions. + + Returns: + Objects[T]: The updated query set with added filter conditions. + """ + _filter = self._filter + for q in queries: + _filter = _filter & q + if filters: + _filter = _filter & expressions.Q(**filters) + return self.__copy__(_filter=_filter) + + def sort(self, *sorts: expressions.Sort) -> 'Objects[T]': + """ + Adds sort conditions to the query set. + + Args: + *sorts (expressions.Sort): Variable number of sort expressions. + + Returns: + Objects[T]: The updated query set with added sort conditions. + """ + _sort = self._sort + for sort in sorts: + _sort = _sort | expressions.Sort(sort) + return self.__copy__(_sort=_sort) + + def skip(self, skip: int) -> 'Objects[T]': + """ + Sets the number of documents to skip in the result set. + + Args: + skip (int): The number of documents to skip. + + Returns: + Objects[T]: The updated query set with the skip value set. + """ + return self.__copy__(_skip=skip) + + def limit(self, limit: int) -> 'Objects[T]': + """ + Sets the maximum number of documents to return. + + Args: + limit (int): The maximum number of documents to return. + + Returns: + Objects[T]: The updated query set with the limit value set. 
+ """ + return self.__copy__(_limit=limit) + + @sync.proxy + def create(self, **data) -> typing.Coroutine[typing.Any, typing.Any, T] | T: + return self._create(**data) + + @sync.proxy + def fetch(self) -> typing.Coroutine[typing.Any, typing.Any, list[T]] | list[T]: + return self._fetch() + + @sync.proxy + def fetch_one(self) -> typing.Coroutine[typing.Any, typing.Any, T] | T: + return self._fetch_one() + + @sync.proxy + def fetch_by_id(self, value: typing.Any) -> typing.Coroutine[typing.Any, typing.Any, T] | T: + return self._fetch_by_id(value) + + @sync.proxy + def count(self) -> typing.Coroutine[typing.Any, typing.Any, int] | int: + return self._count() + # noinspection PyProtectedMember class FsBucket(Objects['FsObject']): @@ -989,7 +1117,7 @@ def __init__( self._chunk_size_bytes = chunk_size_bytes # noinspection PyMethodMayBeStatic - async def create( + async def _create( self, filename: str, src: typing.IO | bytes, @@ -1029,11 +1157,11 @@ async def create( session=self._session.driver_session ) # Update obj info - obj = await self.fetch_by_id(obj.id) + obj = await self._fetch_by_id(obj.id) return obj - async def exist(self, filename: str) -> bool: + async def _exist(self, filename: str) -> bool: """ Checks if a file exists in the file system bucket. @@ -1043,10 +1171,10 @@ async def exist(self, filename: str) -> bool: Returns: bool: True if the file exists, False otherwise. """ - count = await self.filter(Query.Eq('filename', filename)).count() + count = await self.filter(Query.Eq('filename', filename))._count() return count > 0 - async def revisions(self, filename: str) -> list['FsObject']: + async def _revisions(self, filename: str) -> list['FsObject']: """ Retrieves all revisions of a file from the file system bucket. @@ -1056,7 +1184,31 @@ async def revisions(self, filename: str) -> list['FsObject']: Returns: list[FsObject]: A list of file objects representing revisions. 
""" - return await self.filter(Query.Eq('filename', filename)).fetch() + return await self.filter(Query.Eq('filename', filename))._fetch() + + @sync.proxy + def create( + self, + filename: str, + src: typing.IO | bytes, + metadata: dict = None, + chunk_size_bytes: int = None + ) -> typing.Coroutine[typing.Any, typing.Any, 'FsObject'] | 'FsObject': + return self._create(filename, src, metadata, chunk_size_bytes) + + @sync.proxy + def exist( + self, + filename: str + ) -> typing.Coroutine[typing.Any, typing.Any, bool] | bool: + return self._exist(filename) + + @sync.proxy + def revisions( + self, + filename: str + ) -> typing.Coroutine[typing.Any, typing.Any, list['FsObject']] | list['FsObject']: + return self._revisions(filename) # noinspection PyProtectedMember @@ -1082,7 +1234,21 @@ class FsObject(documents.Document): __collection_name__ = 'fs.files' - async def create_revision(self, fs: FsBucket, src: typing.IO | bytes, metadata: dict = None): + @property + def content_type(self) -> str: + return self.metadata.get('contentType') + + @property + def chunks(self) -> int: + import math + return math.ceil(self.length / self.chunk_size) + + async def _create_revision( + self, + fs: FsBucket, + src: typing.IO | bytes, + metadata: dict = None + ) -> 'FsObject': """ Creates a revision of the file. @@ -1092,7 +1258,7 @@ async def create_revision(self, fs: FsBucket, src: typing.IO | bytes, metadata: metadata (dict, optional): Additional metadata for the new revision. """ - await fs.create( + return await fs._create( self.filename, src=src, metadata=metadata, @@ -1100,7 +1266,7 @@ async def create_revision(self, fs: FsBucket, src: typing.IO | bytes, metadata: ) # noinspection SpellCheckingInspection - async def download(self, fs: FsBucket, dest: typing.IO, revision: int = None): + async def _download(self, fs: FsBucket, dest: typing.IO, revision: int = None): """ Downloads the file from the file system. 
@@ -1124,7 +1290,7 @@ async def download(self, fs: FsBucket, dest: typing.IO, revision: int = None): session=fs._session.driver_session ) - async def stream(self, fs: FsBucket, revision: int = None) -> AsyncIOMotorGridOut: + async def _stream(self, fs: FsBucket, revision: int = None) -> 'FsOutput': """ Streams the file from the file system. @@ -1137,18 +1303,156 @@ async def stream(self, fs: FsBucket, revision: int = None) -> AsyncIOMotorGridOu """ if revision is None: - return await fs._bucket.open_download_stream( + grid_out = await fs._bucket.open_download_stream( file_id=self.id, session=fs._session.driver_session ) - return await fs._bucket.open_download_stream_by_name( - filename=self.filename, - revision=revision, + else: + grid_out = await fs._bucket.open_download_stream_by_name( + filename=self.filename, + revision=revision, + session=fs._session.driver_session + ) + return FsOutput(grid_out) + + async def _delete(self, fs: FsBucket): + await fs._bucket.delete( + file_id=self.id, session=fs._session.driver_session ) + @sync.proxy + def create_revision( + self, + fs: FsBucket, + src: typing.IO | bytes, + metadata: dict = None + ) -> typing.Union[typing.Coroutine[typing.Any, typing.Any, 'FsObject'], 'FsObject']: + return self._create_revision(fs, src, metadata) + + @sync.proxy + def download( + self, + fs: FsBucket, + dest: typing.IO, + revision: int = None + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._download(fs, dest, revision) + + @sync.proxy + def stream( + self, + fs: FsBucket, + revision: int = None + ) -> typing.Union[typing.Coroutine[typing.Any, typing.Any, 'FsOutput'], 'FsOutput']: + return self._stream(fs, revision) + + @sync.proxy + def delete(self, fs: FsBucket) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + return self._delete(fs) + + +class FsOutput: + """ + Represents a file stored in MongoDB GridFS. + + Args: + grid_out (AsyncIOMotorGridOut): The underlying GridFS file object. 
def __init__(self, grid_out: 'AsyncIOMotorGridOut'):
    # Keep a handle to the driver-level GridFS stream; every method of
    # this class is a thin delegation to it.
    self._driver_grid_out = grid_out

async def _read(self, size: int = -1) -> bytes:
    """
    Read up to ``size`` bytes from the file asynchronously.

    Args:
        size (int, optional): Number of bytes to read; a negative value
            (the default) reads until EOF.

    Returns:
        bytes: The data read from the file.
    """
    out = self._driver_grid_out
    return await out.read(size)

# noinspection SpellCheckingInspection
async def _readchunk(self) -> bytes:
    """
    Read a single GridFS chunk from the file asynchronously.

    Returns:
        bytes: The chunk data read from the file.
    """
    out = self._driver_grid_out
    return await out.readchunk()

async def _readline(self, size: int = -1) -> bytes:
    """
    Read one line from the file asynchronously.

    Args:
        size (int, optional): Maximum number of bytes to read; a negative
            value (the default) imposes no limit.

    Returns:
        bytes: The line read from the file.
    """
    out = self._driver_grid_out
    return await out.readline(size)

def seek(self, pos: int, whence: int = os.SEEK_SET) -> int:
    """
    Move the file pointer to a given position.

    Args:
        pos (int): The position to seek to.
        whence (int, optional): Reference point for the seek
            (``os.SEEK_SET`` by default).

    Returns:
        int: The new position of the file pointer.
    """
    return self._driver_grid_out.seek(pos, whence)

def seekable(self) -> bool:
    """
    Report whether the underlying stream supports seeking.

    Returns:
        bool: True if the file is seekable, False otherwise.
    """
    return self._driver_grid_out.seekable()
+ """ + return self._driver_grid_out.tell() + + def close(self): + """Closes the file.""" + self._driver_grid_out.close() + + @sync.proxy + def read( + self, + size: int = -1 + ) -> typing.Coroutine[typing.Any, typing.Any, bytes] | bytes: + return self._read(size) + + # noinspection SpellCheckingInspection + @sync.proxy + def readchunk(self) -> typing.Coroutine[typing.Any, typing.Any, bytes] | bytes: + return self._readchunk() + + @sync.proxy + def readline( + self, + size: int = -1 + ) -> typing.Coroutine[typing.Any, typing.Any, bytes] | bytes: + return self._readline(size) + -# noinspection SpellCheckingInspection class Seeding(documents.Document): """ Represents a seeding operation in the database. diff --git a/mongotoy/mappers.py b/mongotoy/mappers.py index bd2aa21..1b1b176 100644 --- a/mongotoy/mappers.py +++ b/mongotoy/mappers.py @@ -9,7 +9,7 @@ import bson -from mongotoy import cache, expressions, references, types, geodata +from mongotoy import cache, expressions, types, geodata from mongotoy.errors import ValidationError, ErrorWrapper if typing.TYPE_CHECKING: @@ -363,6 +363,19 @@ def validate(self, value) -> typing.Any: if not isinstance(value, self.__bind__): raise TypeError(f'Invalid data type {type(value)}, required is {self.__bind__}') + # Validate extra options + if self._options.extra: + if 'min_items' in self._options.extra: + if len(value) < self._options.extra['min_items']: + raise ValueError( + f'Invalid value len {len(value)}, required min_items={self._options.extra["min_items"]}' + ) + if 'max_items' in self._options.extra: + if len(value) > self._options.extra['max_items']: + raise ValueError( + f'Invalid value len {len(value)}, required max_items={self._options.extra["max_items"]}' + ) + new_value = [] errors = [] for i, val in enumerate(value): diff --git a/mongotoy/sync.py b/mongotoy/sync.py new file mode 100644 index 0000000..0b3142b --- /dev/null +++ b/mongotoy/sync.py @@ -0,0 +1,71 @@ +import asyncio +import functools +from typing import 
Callable, AsyncGenerator, Generator + +# Flag to indicate whether sync mode is enabled +_SYNC_MODE_ENABLED = False + +# Global event loop +_loop = asyncio.get_event_loop() + + +def enable_sync_mode(): + """ + Enable synchronous mode for running asynchronous functions synchronously. + """ + global _SYNC_MODE_ENABLED + _SYNC_MODE_ENABLED = True + + +def run_sync(func: Callable): + """ + Wrapper function to run an asynchronous function synchronously. + + Args: + func (Callable): The asynchronous function to run. + + Returns: + Callable: A wrapped function that runs the asynchronous function synchronously. + """ + + def wrap(*args, **kwargs): + return _loop.run_until_complete(func(*args, **kwargs)) + + return wrap + + +def as_sync_gen(gen: AsyncGenerator) -> Generator: + """ + Convert an asynchronous generator into a synchronous generator. + + Args: + gen (AsyncGenerator): The asynchronous generator. + + Yields: + Any: Items yielded by the asynchronous generator. + """ + while True: + try: + yield _loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + break + + +def proxy(func: Callable): + """ + Decorator to run an asynchronous function synchronously if sync mode is enabled. + + Args: + func (Callable): The asynchronous function to decorate. + + Returns: + Callable: A wrapped function that runs the asynchronous function synchronously if sync mode is enabled. 
+ """ + + @functools.wraps(func) + def wrap(*args, **kwargs): + if _SYNC_MODE_ENABLED: + return run_sync(func)(*args, **kwargs) + return func(*args, **kwargs) + + return wrap diff --git a/mongotoy/types.py b/mongotoy/types.py index 7a9331e..5108b3f 100644 --- a/mongotoy/types.py +++ b/mongotoy/types.py @@ -1,11 +1,10 @@ - import collections import datetime +import os import re import typing import bson -from motor.motor_asyncio import AsyncIOMotorGridOut from mongotoy import geodata @@ -356,13 +355,62 @@ class File(typing.Protocol): chunk_size: int length: int upload_date: datetime.datetime + content_type: str + chunks: int + + def create_revision( + self, + fs: 'db.FsBucket', + src: typing.IO | bytes, + metadata: dict = None + ) -> typing.Union[typing.Coroutine[typing.Any, typing.Any, 'File'], 'File']: + pass + + # noinspection SpellCheckingInspection + def download( + self, + fs: 'db.FsBucket', + dest: typing.IO, + revision: int = None + ) -> typing.Coroutine[typing.Any, typing.Any, None] | None: + pass + + def stream( + self, + fs: 'db.FsBucket', + revision: int = None + ) -> typing.Union[typing.Coroutine[typing.Any, typing.Any, '_FileOut'], '_FileOut']: + pass - async def create_revision(self, fs: 'db.FsBucket', src: typing.IO | bytes, metadata: dict = None): + def delete(self, fs: 'db.FsBucket') -> typing.Coroutine[typing.Any, typing.Any, None] | None: pass + +class _FileOut(typing.Protocol): # noinspection SpellCheckingInspection - async def download(self, fs: 'db.FsBucket', dest: typing.IO, revision: int = None): + """ + This is a facade for type mogotoy.db.FsOutput + """ + + def seek(self, pos: int, whence: int = os.SEEK_SET) -> int: pass - async def stream(self, fs: 'db.FsBucket', revision: int = None) -> AsyncIOMotorGridOut: + def seekable(self) -> bool: pass + + def tell(self) -> int: + pass + + def close(self): + pass + + def read(self, size: int = -1) -> typing.Coroutine[typing.Any, typing.Any, bytes] | bytes: + pass + + # noinspection 
SpellCheckingInspection + def readchunk(self) -> typing.Coroutine[typing.Any, typing.Any, bytes] | bytes: + pass + + def readline(self, size: int = -1) -> typing.Coroutine[typing.Any, typing.Any, bytes] | bytes: + pass + diff --git a/pyproject.toml b/pyproject.toml index 3ad42a2..8edb085 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mongotoy" -version = "0.1.2" +version = "0.1.3" description = "Async ODM for MongoDB" license = "Apache-2.0" authors = ["gurcuff91 "]