Skip to content

Commit

Permalink
implemented (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-right authored Dec 23, 2021
1 parent b11ad88 commit 15ce608
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 22 deletions.
2 changes: 1 addition & 1 deletion beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from beanie.odm.utils.general import init_beanie
from beanie.odm.documents import Document

__version__ = "1.8.8"
__version__ = "1.8.9"
__all__ = [
# ODM
"Document",
Expand Down
49 changes: 41 additions & 8 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,16 +927,49 @@ def is_changed(self) -> bool:
return False
return True

def _collect_updates(
self, old_dict: Dict[str, Any], new_dict: Dict[str, Any]
) -> Dict[str, Any]:
"""
Compares old_dict with new_dict and returns field paths that have been updated
Args:
old_dict: dict1
new_dict: dict2
Returns: dictionary with updates
"""
updates = {}

for field_name, field_value in new_dict.items():
if field_value != old_dict.get(field_name):
if not (
isinstance(field_value, dict)
and isinstance(old_dict.get(field_name), dict)
):
updates[field_name] = field_value
else:
if old_dict.get(field_name) is None:
updates[field_name] = field_value
elif isinstance(field_value, dict) and isinstance(
old_dict.get(field_name), dict
):

field_data = self._collect_updates(
old_dict.get(field_name), # type: ignore
field_value,
)

for k, v in field_data.items():
updates[f"{field_name}.{k}"] = v

return updates

@saved_state_needed
def get_changes(self) -> Dict[str, Any]:
# TODO search deeply
changes = {}
if self.is_changed:
current_state = get_dict(self, to_db=True)
for k, v in self._saved_state.items(): # type: ignore
if v != current_state[k]:
changes[k] = current_state[k]
return changes
return self._collect_updates(
self._saved_state, get_dict(self, to_db=True) # type: ignore
)

@saved_state_needed
def rollback(self) -> None:
Expand Down
10 changes: 10 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

Beanie project changes

## [1.8.9] - 2021-12-23

### Improvement

- Deep search of updates for the `save_changes()` method

### Kudos

- Thanks, [Tigran Khazhakyan](https://github.com/tigrankh) for the deep search algo here

## [1.8.8] - 2021-12-17

### Added
Expand Down
6 changes: 6 additions & 0 deletions tests/odm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,15 @@ class InheritedDocumentWithActions(DocumentWithActions):
...


class InternalDoc(BaseModel):
num: int = 100
string: str = "test"


class DocumentWithTurnedOnStateManagement(Document):
num_1: int
num_2: int
internal: InternalDoc

class Settings:
use_state_management = True
Expand Down
50 changes: 38 additions & 12 deletions tests/odm/test_state_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
from tests.odm.models import (
DocumentWithTurnedOnStateManagement,
DocumentWithTurnedOffStateManagement,
InternalDoc,
)


@pytest.fixture
def state():
return {"num_1": 1, "num_2": 2, "_id": ObjectId()}
return {
"num_1": 1,
"num_2": 2,
"_id": ObjectId(),
"internal": InternalDoc(),
}


@pytest.fixture
Expand All @@ -30,14 +36,25 @@ def test_use_state_management_property():


def test_save_state():
doc = DocumentWithTurnedOnStateManagement(num_1=1, num_2=2)
doc = DocumentWithTurnedOnStateManagement(
num_1=1, num_2=2, internal=InternalDoc(num=1, string="s")
)
assert doc.get_saved_state() is None
doc._save_state()
assert doc.get_saved_state() == {"num_1": 1, "num_2": 2}
assert doc.get_saved_state() == {
"num_1": 1,
"num_2": 2,
"internal": {"num": 1, "string": "s"},
}


def test_parse_object_with_saving_state():
obj = {"num_1": 1, "num_2": 2, "_id": ObjectId()}
obj = {
"num_1": 1,
"num_2": 2,
"_id": ObjectId(),
"internal": InternalDoc(),
}
doc = DocumentWithTurnedOnStateManagement._parse_obj_saving_state(obj)
assert doc.get_saved_state() == obj

Expand All @@ -47,7 +64,9 @@ def test_saved_state_needed():
with pytest.raises(StateManagementIsTurnedOff):
doc_1.is_changed

doc_2 = DocumentWithTurnedOnStateManagement(num_1=1, num_2=2)
doc_2 = DocumentWithTurnedOnStateManagement(
num_1=1, num_2=2, internal=InternalDoc()
)
with pytest.raises(StateNotSaved):
doc_2.is_changed

Expand All @@ -59,18 +78,21 @@ def test_if_changed(doc):


def test_get_changes(doc):
doc.num_1 = 100
assert doc.get_changes() == {"num_1": 100}
doc.internal.num = 1000
doc.internal.string = "new_value"
assert doc.get_changes() == {
"internal.num": 1000,
"internal.string": "new_value",
}


async def test_save_changes(saved_doc):
saved_doc.num_1 = 100
saved_doc.internal.num = 10000
await saved_doc.save_changes()

assert saved_doc.get_saved_state()["num_1"] == 100
assert saved_doc.get_saved_state()["internal"]["num"] == 10000

new_doc = await DocumentWithTurnedOnStateManagement.get(saved_doc.id)
assert new_doc.num_1 == 100
assert new_doc.internal.num == 10000


async def test_find_one(saved_doc, state):
Expand All @@ -86,7 +108,11 @@ async def test_find_one(saved_doc, state):
async def test_find_many():
docs = []
for i in range(10):
docs.append(DocumentWithTurnedOnStateManagement(num_1=i, num_2=i + 1))
docs.append(
DocumentWithTurnedOnStateManagement(
num_1=i, num_2=i + 1, internal=InternalDoc()
)
)
await DocumentWithTurnedOnStateManagement.insert_many(docs)

found_docs = await DocumentWithTurnedOnStateManagement.find(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_beanie.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


def test_version():
assert __version__ == "1.8.8"
assert __version__ == "1.8.9"

0 comments on commit 15ce608

Please sign in to comment.