|
1 | 1 | import datetime |
2 | 2 |
|
| 3 | +import pytest |
| 4 | + |
3 | 5 | from tests import container |
4 | 6 |
|
5 | 7 |
|
| 8 | +async def test_batch_providers_overriding() -> None: |
| 9 | + async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") |
| 10 | + sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") |
| 11 | + async_factory_mock = datetime.datetime.fromisoformat("2025-01-01") |
| 12 | + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) |
| 13 | + singleton_mock = container.SingletonFactory(dep1=False) |
| 14 | + |
| 15 | + providers_for_overriding = { |
| 16 | + "async_resource": async_resource_mock, |
| 17 | + "sync_resource": sync_resource_mock, |
| 18 | + "simple_factory": simple_factory_mock, |
| 19 | + "singleton": singleton_mock, |
| 20 | + "async_factory": async_factory_mock, |
| 21 | + } |
| 22 | + |
| 23 | + with container.DIContainer.override_providers(providers_for_overriding): |
| 24 | + await container.DIContainer.simple_factory() |
| 25 | + dependent_factory = await container.DIContainer.dependent_factory() |
| 26 | + singleton = await container.DIContainer.singleton() |
| 27 | + async_factory = await container.DIContainer.async_factory() |
| 28 | + |
| 29 | + assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 |
| 30 | + assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 |
| 31 | + assert dependent_factory.sync_resource == sync_resource_mock |
| 32 | + assert dependent_factory.async_resource == async_resource_mock |
| 33 | + assert singleton is singleton_mock |
| 34 | + assert async_factory is async_factory_mock |
| 35 | + |
| 36 | + assert (await container.DIContainer.async_resource()) != async_resource_mock |
| 37 | + |
| 38 | + |
| 39 | +async def test_batch_providers_overriding_sync_resolve() -> None: |
| 40 | + async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") |
| 41 | + sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") |
| 42 | + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) |
| 43 | + singleton_mock = container.SingletonFactory(dep1=False) |
| 44 | + |
| 45 | + providers_for_overriding = { |
| 46 | + "async_resource": async_resource_mock, |
| 47 | + "sync_resource": sync_resource_mock, |
| 48 | + "simple_factory": simple_factory_mock, |
| 49 | + "singleton": singleton_mock, |
| 50 | + } |
| 51 | + |
| 52 | + with container.DIContainer.override_providers(providers_for_overriding): |
| 53 | + container.DIContainer.simple_factory.sync_resolve() |
| 54 | + await container.DIContainer.async_resource.async_resolve() |
| 55 | + dependent_factory = container.DIContainer.dependent_factory.sync_resolve() |
| 56 | + singleton = container.DIContainer.singleton.sync_resolve() |
| 57 | + |
| 58 | + assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 |
| 59 | + assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 |
| 60 | + assert dependent_factory.sync_resource == sync_resource_mock |
| 61 | + assert dependent_factory.async_resource == async_resource_mock |
| 62 | + assert singleton is singleton_mock |
| 63 | + |
| 64 | + assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock |
| 65 | + |
| 66 | + |
| 67 | +def test_providers_overriding_with_context_manager() -> None: |
| 68 | + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) |
| 69 | + |
| 70 | + with container.DIContainer.simple_factory.override_context(simple_factory_mock): |
| 71 | + assert container.DIContainer.simple_factory.sync_resolve() is simple_factory_mock |
| 72 | + |
| 73 | + assert container.DIContainer.simple_factory.sync_resolve() is not simple_factory_mock |
| 74 | + |
| 75 | + |
| 76 | +def test_providers_overriding_fail_with_unknown_provider() -> None: |
| 77 | + unknown_provider_name = "unknown_provider_name" |
| 78 | + match = f"Provider with name {unknown_provider_name!r} not found" |
| 79 | + providers_for_overriding = {unknown_provider_name: None} |
| 80 | + |
| 81 | + with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding): |
| 82 | + ... # pragma: no cover |
| 83 | + |
| 84 | + |
6 | 85 | async def test_providers_overriding() -> None: |
7 | 86 | async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") |
8 | 87 | sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") |
|
0 commit comments