Skip to content

Commit 548b5e7

Browse files
nightblureivan
andauthored
Overriding with context manager (one and multiple providers) (#53)
* Implement overriding with context manager for one provider * Implement batch overriding with context manager for container Co-authored-by: ivan <[email protected]>
1 parent 18a1857 commit 548b5e7

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

tests/providers/test_providers_overriding.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,87 @@
11
import datetime
22

3+
import pytest
4+
35
from tests import container
46

57

8+
async def test_batch_providers_overriding() -> None:
9+
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
10+
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
11+
async_factory_mock = datetime.datetime.fromisoformat("2025-01-01")
12+
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
13+
singleton_mock = container.SingletonFactory(dep1=False)
14+
15+
providers_for_overriding = {
16+
"async_resource": async_resource_mock,
17+
"sync_resource": sync_resource_mock,
18+
"simple_factory": simple_factory_mock,
19+
"singleton": singleton_mock,
20+
"async_factory": async_factory_mock,
21+
}
22+
23+
with container.DIContainer.override_providers(providers_for_overriding):
24+
await container.DIContainer.simple_factory()
25+
dependent_factory = await container.DIContainer.dependent_factory()
26+
singleton = await container.DIContainer.singleton()
27+
async_factory = await container.DIContainer.async_factory()
28+
29+
assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
30+
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
31+
assert dependent_factory.sync_resource == sync_resource_mock
32+
assert dependent_factory.async_resource == async_resource_mock
33+
assert singleton is singleton_mock
34+
assert async_factory is async_factory_mock
35+
36+
assert (await container.DIContainer.async_resource()) != async_resource_mock
37+
38+
39+
async def test_batch_providers_overriding_sync_resolve() -> None:
40+
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
41+
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
42+
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
43+
singleton_mock = container.SingletonFactory(dep1=False)
44+
45+
providers_for_overriding = {
46+
"async_resource": async_resource_mock,
47+
"sync_resource": sync_resource_mock,
48+
"simple_factory": simple_factory_mock,
49+
"singleton": singleton_mock,
50+
}
51+
52+
with container.DIContainer.override_providers(providers_for_overriding):
53+
container.DIContainer.simple_factory.sync_resolve()
54+
await container.DIContainer.async_resource.async_resolve()
55+
dependent_factory = container.DIContainer.dependent_factory.sync_resolve()
56+
singleton = container.DIContainer.singleton.sync_resolve()
57+
58+
assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
59+
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
60+
assert dependent_factory.sync_resource == sync_resource_mock
61+
assert dependent_factory.async_resource == async_resource_mock
62+
assert singleton is singleton_mock
63+
64+
assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock
65+
66+
67+
def test_providers_overriding_with_context_manager() -> None:
68+
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
69+
70+
with container.DIContainer.simple_factory.override_context(simple_factory_mock):
71+
assert container.DIContainer.simple_factory.sync_resolve() is simple_factory_mock
72+
73+
assert container.DIContainer.simple_factory.sync_resolve() is not simple_factory_mock
74+
75+
76+
def test_providers_overriding_fail_with_unknown_provider() -> None:
77+
unknown_provider_name = "unknown_provider_name"
78+
match = f"Provider with name {unknown_provider_name!r} not found"
79+
providers_for_overriding = {unknown_provider_name: None}
80+
81+
with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding):
82+
... # pragma: no cover
83+
84+
685
async def test_providers_overriding() -> None:
786
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
887
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")

that_depends/container.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import typing
3+
from contextlib import contextmanager
34

45
from that_depends.providers import AbstractProvider, AbstractResource, Singleton
56

@@ -92,3 +93,26 @@ async def resolve(cls, object_to_resolve: type[T] | typing.Callable[..., T]) ->
9293
kwargs[field_name] = await providers[field_name].async_resolve()
9394

9495
return object_to_resolve(**kwargs)
96+
97+
@classmethod
98+
@contextmanager
99+
def override_providers(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]:
100+
current_providers = cls.get_providers()
101+
current_provider_names = set(current_providers.keys())
102+
given_provider_names = set(providers_for_overriding.keys())
103+
104+
for given_name in given_provider_names:
105+
if given_name not in current_provider_names:
106+
msg = f"Provider with name {given_name!r} not found"
107+
raise RuntimeError(msg)
108+
109+
for provider_name, mock in providers_for_overriding.items():
110+
provider = current_providers[provider_name]
111+
provider.override(mock)
112+
113+
try:
114+
yield
115+
finally:
116+
for provider_name in providers_for_overriding:
117+
provider = current_providers[provider_name]
118+
provider.reset_override()

that_depends/providers/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import typing
3+
from contextlib import contextmanager
34

45

56
T = typing.TypeVar("T")
@@ -24,6 +25,14 @@ async def __call__(self) -> T_co:
2425
def override(self, mock: object) -> None:
2526
self._override = mock
2627

28+
@contextmanager
29+
def override_context(self, mock: object) -> typing.Iterator[None]:
30+
self.override(mock)
31+
try:
32+
yield
33+
finally:
34+
self.reset_override()
35+
2736
def reset_override(self) -> None:
2837
self._override = None
2938

0 commit comments

Comments
 (0)