From 3f60ccfeffa4e0cd4d24569e23e2cf7d343e6dfa Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Tue, 30 Jan 2024 00:13:53 +0100 Subject: [PATCH 1/9] 0.1.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4a0696628..07483ea5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.1" +version = "0.1.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From 86cfa07d061bfc58bbd94b59e3772d0c3719d5cd Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Fri, 9 Feb 2024 00:15:53 +0100 Subject: [PATCH 2/9] Bumb version v0.2.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 07483ea5f..2000c17de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.1.0" +version = "0.2.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From c5da559c73af91b6922ff4165bb4cdfa54e8a646 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Sun, 18 Feb 2024 19:04:30 +0100 Subject: [PATCH 3/9] 0.3.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2000c17de..ae1f172e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.2.0" +version = "0.3.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From 02c77eaaeef48cfff1cc2db7f17ac9d9c4b0cfba Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Wed, 21 Feb 2024 00:57:36 +0100 Subject: [PATCH 4/9] 0.4.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ae1f172e1..f570a2cf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.3.0" +version = "0.4.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From f28e6e5e6642ca49ffe137546f1bd60620dd5ec7 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Tue, 27 Feb 2024 23:52:37 +0100 Subject: [PATCH 5/9] 0.5.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index eb22f0eab..82d649087 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.4.0" +version = "0.5.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From 139b2c68a6418574c6268dd3d30c8062ba0c6943 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Mon, 4 Mar 2024 23:03:23 +0100 Subject: [PATCH 6/9] 0.6.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 82d649087..33deb8db1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.5.0" +version = "0.6.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From 2bda6b488babddaf477715a2513c56e1989253b0 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Tue, 19 Mar 2024 23:45:49 +0100 Subject: [PATCH 7/9] 0.7.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 33deb8db1..28c0ffa88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.6.0" +version = "0.7.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From 114dc475a59c0101d0e844ab3122532ac6be90c0 Mon Sep 17 00:00:00 2001 From: Andrey Tikhonov <17@itishka.org> Date: Mon, 25 Mar 2024 00:55:18 +0100 Subject: [PATCH 8/9] 0.8.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 28c0ffa88..3cdc91c8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ where = ["src"] [project] name = "dishka" -version = "0.7.0" +version = "0.8.0" readme = "README.md" authors = [ { name = "Andrey Tikhonov", email = "17@itishka.org" }, From 5da13ceaf7a6f1f37a50c93f766d4330688fc573 Mon Sep 17 00:00:00 2001 From: Sergey Malanin Date: Wed, 27 Mar 2024 18:46:47 +0000 Subject: [PATCH 9/9] added resolve_all --- src/dishka/async_container.py | 43 +++++++++++++- src/dishka/container.py | 43 +++++++++++++- tests/unit/container/test_components.py | 71 ++++++++++++++++++++++- tests/unit/container/test_context_vars.py | 50 ++++++++++++++++ tests/unit/container/test_resolve.py | 65 +++++++++++++++++++++ 5 files changed, 269 insertions(+), 3 deletions(-) diff --git a/src/dishka/async_container.py b/src/dishka/async_container.py index 9e2af074a..5de8f31c4 100644 --- a/src/dishka/async_container.py +++ b/src/dishka/async_container.py @@ -1,6 +1,7 @@ from asyncio import Lock from collections.abc import Callable -from typing import Any, Optional, TypeVar +from contextlib import suppress +from typing import Any, Iterable, Literal, Optional, TypeVar, overload from dishka.entities.component import DEFAULT_COMPONENT, Component from dishka.entities.key import DependencyKey @@ -9,6 +10,7 @@ from .dependency_source import FactoryType from .exceptions import ( ExitError, + NoContextValueError, NoFactoryError, ) from .provider import BaseProvider @@ -109,6 +111,45 @@ async def get( async with lock: return await self._get_unlocked(key) + @overload + async def resolve_all(self, components: None = None) -> None: ... + @overload + async def resolve_all(self, components: Literal[True]) -> None: ... + @overload + async def resolve_all(self, components: Iterable[Component]) -> None: ... + + async def resolve_all(self, components: Any = None) -> None: + """ + Resolve all container dependencies in the current scope for the given + components. + + Examples: + >>> container.resolve_all() + Resolve all dependencies for the default component. + + >>> container.resolve_all(True) + Resolve all dependencies for all components. + + >>> container.resolve_all(['component1', 'component2']) + Resolve dependencies for 'component1' and 'component2'. + """ + if not components: + + def component_check(k: DependencyKey) -> bool: + return k.component == DEFAULT_COMPONENT + elif components is True: + + def component_check(k: DependencyKey) -> bool: + return True + else: + + def component_check(k: DependencyKey) -> bool: + return k.component in components + + for key in filter(component_check, self.registry.factories): + with suppress(NoContextValueError): + await self._get_unlocked(key) + async def _get_unlocked(self, key: DependencyKey) -> Any: if key in self.context: return self.context[key] diff --git a/src/dishka/container.py b/src/dishka/container.py index 2952673d7..ac4375330 100644 --- a/src/dishka/container.py +++ b/src/dishka/container.py @@ -1,6 +1,7 @@ from collections.abc import Callable +from contextlib import suppress from threading import Lock -from typing import Any, Optional, TypeVar +from typing import Any, Iterable, Literal, Optional, TypeVar, overload from dishka.entities.component import DEFAULT_COMPONENT, Component from dishka.entities.key import DependencyKey @@ -9,6 +10,7 @@ from .dependency_source import FactoryType from .exceptions import ( ExitError, + NoContextValueError, NoFactoryError, ) from .provider import BaseProvider @@ -107,6 +109,45 @@ def get( with lock: return self._get_unlocked(key) + @overload + def resolve_all(self, components: None = None) -> None: ... + @overload + def resolve_all(self, components: Literal[True]) -> None: ... + @overload + def resolve_all(self, components: Iterable[Component]) -> None: ... + + def resolve_all(self, components: Any = None) -> None: + """ + Resolve all container dependencies in the current scope for the given + components. + + Examples: + >>> container.resolve_all() + Resolve all dependencies for the default component. + + >>> container.resolve_all(True) + Resolve all dependencies for all components. + + >>> container.resolve_all(['component1', 'component2']) + Resolve dependencies for 'component1' and 'component2'. + """ + if not components: + + def component_check(k: DependencyKey) -> bool: + return k.component == DEFAULT_COMPONENT + elif components is True: + + def component_check(k: DependencyKey) -> bool: + return True + else: + + def component_check(k: DependencyKey) -> bool: + return k.component in components + + for key in filter(component_check, self.registry.factories): + with suppress(NoContextValueError): + self._get_unlocked(key) + def _get_unlocked(self, key: DependencyKey) -> Any: if key in self.context: return self.context[key] diff --git a/tests/unit/container/test_components.py b/tests/unit/container/test_components.py index 492c7a079..4f1407e46 100644 --- a/tests/unit/container/test_components.py +++ b/tests/unit/container/test_components.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Literal import pytest @@ -55,6 +55,24 @@ def foo(self, a: Annotated[int, FromComponent()]) -> float: return a + 1 +class YProvider(Provider): + scope = Scope.APP + component = "Y" + + @provide + def foo(self) -> float: + return 42 + + +class ZProvider(Provider): + scope = Scope.APP + component = "Z" + + @provide + def foo(self) -> bool: + return True + + def test_from_component(): container = make_container(MainProvider(20), XProvider()) assert container.get(complex) == 210 @@ -63,6 +81,31 @@ def test_from_component(): container.get(float) +@pytest.mark.parametrize( + ("component", "expected_count"), + [ + (None, 4), + (("",), 4), + (True, 6), + (("X",), 3), + (("X", ""), 4), + (("X", "Y"), 4), + (("X", "Y", ""), 5), + (("X", "Y", "Z"), 5), + (("X", "Y", "Z", ""), 6), + ], +) +def test_from_component_resolve_all( + component: Literal[True] | tuple[Component] | None, expected_count: int +): + container = make_container( + MainProvider(20), XProvider(), YProvider(), ZProvider() + ) + assert len(container.context) == 1 + container.resolve_all(component) + assert len(container.context) == expected_count + + @pytest.mark.asyncio() async def test_from_component_async(): container = make_async_container(MainProvider(20), XProvider()) @@ -72,6 +115,32 @@ async def test_from_component_async(): await container.get(float) +@pytest.mark.parametrize( + ("component", "expected_count"), + [ + (None, 4), + (("",), 4), + (True, 6), + (("X",), 3), + (("X", ""), 4), + (("X", "Y"), 4), + (("X", "Y", ""), 5), + (("X", "Y", "Z"), 5), + (("X", "Y", "Z", ""), 6), + ], +) +@pytest.mark.asyncio +async def test_from_component_resolve_all_async( + component: Literal[True] | tuple[Component] | None, expected_count: int +): + container = make_async_container( + MainProvider(20), XProvider(), YProvider(), ZProvider() + ) + assert len(container.context) == 1 + await container.resolve_all(component) + assert len(container.context) == expected_count + + class SingleProvider(Provider): scope = Scope.APP diff --git a/tests/unit/container/test_context_vars.py b/tests/unit/container/test_context_vars.py index 1cac58779..12be5f286 100644 --- a/tests/unit/container/test_context_vars.py +++ b/tests/unit/container/test_context_vars.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from dishka import ( @@ -9,6 +11,7 @@ ) from dishka.dependency_source import from_context from dishka.exceptions import NoContextValueError +from ..sample_providers import ClassA def test_simple(): @@ -18,6 +21,35 @@ def test_simple(): assert container.get(int) == 1 +class AProvider(Provider): + scope = Scope.APP + a = from_context(provides=int) + b = from_context(provides=str) + + @provide + def foo(self, a: int) -> ClassA: + return ClassA(a) + + @provide + def bar(self, a: str) -> bool: + return bool(a) + + +@pytest.mark.parametrize( + ("context", "expected_count"), + [ + ({}, 1), + ({int: 1}, 3), + ({int: 1, str: "1"}, 5), + ], +) +def test_simple_resolve_all(context: dict[type, Any], expected_count: int): + provider = AProvider() + container = make_container(provider, context=context) + container.resolve_all() + assert len(container.context) == expected_count + + @pytest.mark.asyncio async def test_simple_async(): provider = Provider() @@ -26,6 +58,24 @@ async def test_simple_async(): assert await container.get(int) == 1 +@pytest.mark.parametrize( + ("context", "expected_count"), + [ + ({}, 1), + ({int: 1}, 3), + ({int: 1, str: "1"}, 5), + ], +) +@pytest.mark.asyncio +async def test_simple_resolve_all_async( + context: dict[type, Any], expected_count: int +): + provider = AProvider() + container = make_async_container(provider, context=context) + await container.resolve_all() + assert len(container.context) == expected_count + + def test_not_found(): provider = Provider() provider.from_context(provides=int, scope=Scope.APP) diff --git a/tests/unit/container/test_resolve.py b/tests/unit/container/test_resolve.py index 3954dfa0f..57f43528b 100644 --- a/tests/unit/container/test_resolve.py +++ b/tests/unit/container/test_resolve.py @@ -1,3 +1,5 @@ +from typing import Any, Callable + import pytest from dishka import ( @@ -129,3 +131,66 @@ def test_external_method(method): container = make_container(provider) assert container.get(ClassA) is A_VALUE + + +@pytest.mark.parametrize( + ("factory", "cache", "expected_count"), + [ + (ClassA, True, 3), + (ClassA, False, 2), + (sync_func_a, True, 3), + (sync_func_a, False, 2), + (sync_iter_a, True, 3), + (sync_iter_a, False, 2), + (sync_gen_a, True, 3), + (sync_gen_a, False, 2), + ], +) +def test_sync_resolve_all( + factory: Callable[..., Any], cache: bool, expected_count: int +): + class MyProvider(Provider): + a = provide(factory, scope=Scope.APP, cache=cache) + + @provide(scope=Scope.APP) + def get_int(self) -> int: + return 100 + + container = make_container(MyProvider()) + assert container.registry.scope is Scope.APP + assert len(container.context) == 1 + container.resolve_all() + assert len(container.context) == expected_count + container.close() + + +@pytest.mark.parametrize( + ("factory", "cache", "expected_count"), + [ + (ClassA, True, 3), + (ClassA, False, 2), + (async_func_a, True, 3), + (async_func_a, False, 2), + (async_iter_a, True, 3), + (async_iter_a, False, 2), + (async_gen_a, True, 3), + (async_gen_a, False, 2), + ], +) +@pytest.mark.asyncio +async def test_async_resolve_all( + factory: Callable[..., Any], cache: bool, expected_count: int +): + class MyProvider(Provider): + a = provide(factory, scope=Scope.APP, cache=cache) + + @provide(scope=Scope.APP) + def get_int(self) -> int: + return 100 + + container = make_async_container(MyProvider()) + assert container.registry.scope is Scope.APP + assert len(container.context) == 1 + await container.resolve_all() + assert len(container.context) == expected_count + await container.close()