Skip to content

Commit

Permalink
Added collation to document config and queries. Made some general fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gurcuff91 committed May 24, 2024
1 parent 3cdf4b8 commit 2c54baa
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 72 deletions.
3 changes: 2 additions & 1 deletion docs/docs/documents.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ the database. To specify settings, you use the `document_config` class attribute
from mongotoy import Document
from mongotoy.documents import DocumentConfig


class Person(Document):
document_config = DocumentConfig(capped=True)
document_config = DocumentConfig(capped_collection=True)
````


Expand Down
15 changes: 5 additions & 10 deletions docs/docs/objects.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ See [querying expressions](#querying-expressions)
async with engine.session() as session:

# Get persons older than 21 years
legal_persons = session.objects(Person).filter(Person.age > 21)
older_persons = session.objects(Person).filter(Person.age > 21)

# Get persons from USA
usa_persons = session.objects(Person).filter(address__country__eq='USA')
Expand Down Expand Up @@ -138,9 +138,6 @@ instance. Raise `mongotoy.errors.NoResultError()` if no result found.
- **one_or_none()**: Retrieves a specific document from the result set. It returns a single-parsed document
instance or `None` if no result found.

- **get_by_id(value)**: Retrieves a document by its identifier from the result set. It returns a parsed
document instance corresponding to the provided identifier.

These functions contribute to efficient data retrieval and manipulation by leveraging asynchronous or synchronous
operations, ensuring responsiveness and scalability in handling database queries.

Expand All @@ -152,9 +149,6 @@ async with engine.session() as session:

# Fetching one person
person = await session.objects(Person).one()

# Fetching one person by id
person = await session.objects(Person).get_by_id(1)
````

### Counting documents
Expand Down Expand Up @@ -281,6 +275,7 @@ Query.Gt('age', 21)
To ensure accurate querying expressions, use the `alias` rather than the field `name` for fields with defined aliases.
Otherwise, querying operations might target nonexistent database fields, resulting in inaccuracies.
///
///

### The Q function

Expand All @@ -293,9 +288,9 @@ The function parses each keyword, separating the field name from the operator, w
dynamically constructs a query by combining these conditions using logical `AND` operations. This allows users to build
queries in a more readable and intuitive way, compared to manually constructing query strings.

The `Q` function is particularly useful in scenarios where the query parameters are not known in advance and need to be constructed at
runtime based on user input or other dynamic data sources. It encapsulates the complexity of query construction,
providing a clean and maintainable interface for building queries.
The `Q` function is particularly useful in scenarios where the query parameters are not known in advance and need to be
constructed at runtime based on user input or other dynamic data sources. It encapsulates the complexity of query
construction, providing a clean and maintainable interface for building queries.

````python
from mongotoy.expressions import Q
Expand Down
107 changes: 68 additions & 39 deletions mongotoy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from motor.core import AgnosticClient, AgnosticDatabase, AgnosticCollection, AgnosticClientSession
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket, AsyncIOMotorGridOut
from motor.motor_gridfs import AgnosticGridFSBucket
from pymongo.collation import Collation
from pymongo.read_concern import ReadConcern

from mongotoy import documents, expressions, references, fields, types, sync
Expand Down Expand Up @@ -203,8 +204,14 @@ async def _create_document_indexes(
"""
indexes = self._get_document_indexes(document_cls)
collection = self._get_document_collection(document_cls)

if indexes:
await collection.create_indexes(indexes, session=driver_session)
options = {}
# Add collation to the index
if document_cls.document_config.collation:
options['collation'] = document_cls.document_config.collation

await collection.create_indexes(indexes, session=driver_session, **options)

async def _create_document_collection(
self,
Expand All @@ -221,26 +228,40 @@ async def _create_document_collection(
config = document_cls.document_config
options = {'check_exists': False}

# Configure options for capped collections
if config.capped:
# Configure options for capped collection
if config.capped_collection:
options['capped'] = True
options['size'] = config.capped_size
if config.capped_max:
options['max'] = config.capped_max
options['size'] = config.capped_collection_size
if config.capped_collection_max:
options['max'] = config.capped_collection_max

# Configure options for timeseries collections
# Configure options for timeseries collection
if config.timeseries_field:
timeseries = {
'timeField': config.timeseries_field,
'granularity': config.timeseries_granularity
'timeField': documents.get_document_field(
document_cls,
field_name=config.timeseries_field
).alias,
'granularity': config.timeseries_granularity or 'seconds'
}
if config.timeseries_meta_field:
timeseries['metaField'] = config.timeseries_meta_field
timeseries['metaField'] = documents.get_document_field(
document_cls,
field_name=config.timeseries_meta_field
).alias

options['timeseries'] = timeseries
if config.timeseries_expire_after_seconds:
options['expireAfterSeconds'] = config.timeseries_expire_after_seconds

# Add collation to options
if config.collation:
options['collation'] = config.collation

# Add extra options to a collection
if config.extra_collection_options:
options.update(config.extra_collection_options)

# Create the collection with configured options
await self.database.create_collection(
name=document_cls.__collection_name__,
Expand Down Expand Up @@ -721,18 +742,29 @@ def transaction(self) -> 'Transaction':
"""
return Transaction(session=self)

def objects(self, document_cls: typing.Type[T], dereference_deep: int = 0) -> 'Objects[T]':
def objects(
self,
document_cls: typing.Type[T],
dereference_deep: int = 0,
collation: typing.Optional[Collation] = None
) -> 'Objects[T]':
"""
Returns an object manager for the specified document class.
Args:
document_cls (typing.Type[T]): The document class.
dereference_deep (int): Depth of dereferencing.
collation (Collation, optional): The collation to use when query documents.
Returns:
Objects[T]: An object manager.
"""
return Objects(document_cls, session=self, dereference_deep=dereference_deep)
return Objects(
document_cls,
session=self,
dereference_deep=dereference_deep,
collation=collation
)

def fs(self, chunk_size_bytes: int = gridfs.DEFAULT_CHUNK_SIZE) -> 'FsBucket':
"""
Expand Down Expand Up @@ -896,12 +928,21 @@ class Objects(typing.Generic[T]):
document_cls (typing.Type[T]): The document class associated with the query set.
session (Session): The session object used for database operations.
dereference_deep (int, optional): The depth of dereferencing for referenced documents.
collation (Collation, optional): The collation to use when query documents.
"""

def __init__(self, document_cls: typing.Type[T], session: Session, dereference_deep: int = 0):
def __init__(
self,
document_cls: typing.Type[T],
session: Session,
dereference_deep: int = 0,
collation: typing.Optional[Collation] = None
):
self._document_cls = document_cls
self._session = session
self._dereference_deep = dereference_deep
self._collation = collation
self._collection = session.engine.collection(document_cls)
self._filter = expressions.Query()
self._sort = expressions.Sort()
Expand All @@ -921,7 +962,8 @@ def __copy__(self, **options) -> 'Objects[T]':
objs = Objects(
document_cls=self._document_cls,
session=self._session,
dereference_deep=self._dereference_deep
dereference_deep=self._dereference_deep,
collation=self._collation
)
setattr(objs, '_filter', options.get('_filter', self._filter))
setattr(objs, '_sort', options.get('_sort', self._sort))
Expand All @@ -937,12 +979,14 @@ async def __aiter__(self) -> typing.AsyncGenerator[T, None]:
Yields:
T: The parsed document instances.
"""
# Create pipeline
# noinspection PyTypeChecker
pipeline = references.build_dereference_pipeline(
references=self._document_cls.__references__.values(),
deep=self._dereference_deep
)

# Apply filters, sorting and limits
if self._filter:
pipeline.append({'$match': self._filter})
if self._sort:
Expand All @@ -952,7 +996,15 @@ async def __aiter__(self) -> typing.AsyncGenerator[T, None]:
if self._limit > 0:
pipeline.append({'$limit': self._limit})

cursor = self._collection.aggregate(pipeline, session=self._session.driver_session)
# Aggregation query options
options = {}

# Add collation to aggregation query
collation = self._collation or self._document_cls.document_config.collation
if collation:
options['collation'] = collation

cursor = self._collection.aggregate(pipeline, session=self._session.driver_session, **options)
async for data in cursor:
yield self._document_cls(**data)

Expand Down Expand Up @@ -1002,25 +1054,6 @@ async def _one_or_none(self) -> typing.Optional[T]:
except NoResultError:
pass

# noinspection PyShadowingBuiltins
async def _get_by_id(self, id_value: typing.Any) -> T:
"""
Retrieves a document by its identifier.
Args:
id_value (typing.Any): The identifier value.
Returns:
T: The parsed document instance.
Raises:
NoResultsError: If no results are found.
"""
# noinspection PyProtectedMember
return await self.filter(
self._document_cls.id == self._document_cls.id._field.mapper.validate_value(id_value)
)._one()

async def _count(self) -> int:
"""
Counts the number of documents in the result set.
Expand Down Expand Up @@ -1102,10 +1135,6 @@ def one(self) -> typing.Coroutine[typing.Any, typing.Any, T] | T:
def one_or_none(self) -> typing.Coroutine[typing.Any, typing.Any, typing.Optional[T]] | typing.Optional[T]:
return self._one_or_none()

@sync.proxy
def get_by_id(self, value: typing.Any) -> typing.Coroutine[typing.Any, typing.Any, T] | T:
return self._get_by_id(value)

@sync.proxy
def count(self) -> typing.Coroutine[typing.Any, typing.Any, int] | int:
return self._count()
Expand Down Expand Up @@ -1177,7 +1206,7 @@ async def _create(
session=self._session.driver_session
)
# Update obj info
obj = await self._get_by_id(obj.id)
obj = await self.filter(FsObject.id == obj.id)._one()

return obj

Expand Down
Loading

0 comments on commit 2c54baa

Please sign in to comment.