From e4f855027d0ff32ecbf3fdffaf8fb6c3bba628a5 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Sun, 22 Jun 2025 15:54:36 +0300 Subject: [PATCH 1/9] Run mypy on integrations.base and fix typing issues --- .mypy.ini | 17 +++++++++++++++- src/dishka/integrations/base.py | 35 +++++++++++---------------------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index ac2f644c8..5ee6bf3af 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,6 +1,21 @@ [mypy] files = src/dishka -exclude = ^src/dishka/(_adaptix|integrations)/ +exclude = (?x)( + ^src/dishka/_adaptix/ + |^src/dishka/integrations/( + aiohttp + |aiogram + |celery + |click + |fastapi + |faststream + |grpcio + |litestar + |sanic + |starlette + |taskiq + |telebot + ).py) strict = true strict_bytes = true diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 680d1cab5..f73292460 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -278,7 +278,7 @@ def _get_auto_injected_sync_container_async_gen_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], - func: Callable[P, Iterator[T]], + func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: async def auto_injected_generator( @@ -307,7 +307,7 @@ def _get_auto_injected_sync_container_async_gen( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], - func: Callable[P, Iterator[T]], + func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: if provide_context is not None: @@ -334,9 +334,9 @@ def _get_auto_injected_sync_container_async_func_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], - func: Callable[P, T], + func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, -) -> Callable[P, T]: +) -> Callable[P, Awaitable[T]]: async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) @@ -359,9 +359,9 @@ def _get_auto_injected_sync_container_async_func( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], - func: Callable[P, T], + func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, -) -> Callable[P, T]: +) -> Callable[P, Awaitable[T]]: if provide_context is not None: raise ImproperProvideContextUsageError @@ -464,7 +464,7 @@ def __post_init__(self) -> None: } -def get_func_type(func: Callable) -> FunctionType: +def get_func_type(func: Callable[P, T]) -> FunctionType: if isasyncgenfunction(func): return FunctionType.ASYNC_GENERATOR elif isgeneratorfunction(func): @@ -573,24 +573,11 @@ def wrap_injection( for param in new_params ] - auto_injected_func: Callable[P, T | Awaitable[T]] if additional_params: new_params = _add_params(new_params, additional_params) for param in additional_params: new_annotations[param.name] = param.annotation - if is_async: - func = cast(Callable[P, Awaitable[T]], func) - container_getter = cast( - ContainerGetter[AsyncContainer], - container_getter, - ) - else: - container_getter = cast( - ContainerGetter[Container], - container_getter, - ) - injected_func_type = InjectedFuncType( is_async_container=is_async, manage_scope=manage_scope, @@ -598,7 +585,7 @@ def wrap_injection( ) get_auto_injected_func = _GET_AUTO_INJECTED_FUNC_DICT[injected_func_type] - auto_injected_func = get_auto_injected_func( + auto_injected_func = get_auto_injected_func( # type: ignore[operator] func=func, provide_context=provide_context, dependencies=dependencies, @@ -607,13 +594,13 @@ def wrap_injection( ) auto_injected_func.__dishka_orig_func__ = func - auto_injected_func.__dishka_injected__ = True # type: ignore[attr-defined] + auto_injected_func.__dishka_injected__ = True auto_injected_func.__name__ = func.__name__ auto_injected_func.__qualname__ = func.__qualname__ auto_injected_func.__doc__ = func.__doc__ auto_injected_func.__module__ = func.__module__ auto_injected_func.__annotations__ = new_annotations - auto_injected_func.__signature__ = Signature( # type: ignore[attr-defined] + auto_injected_func.__signature__ = Signature( parameters=new_params, return_annotation=func_signature.return_annotation, ) @@ -627,7 +614,7 @@ def is_dishka_injected(func: Callable[..., Any]) -> bool: def _add_params( params: Sequence[Parameter], additional_params: Sequence[Parameter], -): +) -> list[Parameter]: params_kind_dict: dict[_ParameterKind, list[Parameter]] = {} for param in params: From e30aaab82934ba6f74465bc3ce8dcc4adc3cf85d Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Sun, 22 Jun 2025 17:06:32 +0300 Subject: [PATCH 2/9] DRYfy the _get_auto_injected_* functions --- .ruff.toml | 1 + src/dishka/integrations/base.py | 147 ++++++++++++++++---------------- 2 files changed, 76 insertions(+), 72 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index 13c42ef93..73e807083 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -25,6 +25,7 @@ lint.ignore = [ "PLR0913", "SIM103", "ISC003", + "COM812", # identitcal by code != identical by meaning "SIM114", diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index f73292460..72b4106de 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -54,6 +54,34 @@ ) +async def _maybe_inject_async( + container: AsyncContainer, + dependencies: dict[str, DependencyKey], + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + solved = { + name: await container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + return func(*args, **kwargs, **solved) + + +def _maybe_inject_sync( + container: Container, + dependencies: dict[str, DependencyKey], + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + return func(*args, **kwargs, **solved) + + def _get_auto_injected_async_gen_scoped( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], @@ -73,14 +101,10 @@ async def auto_injected_generator( ) container = container_getter(args, kwargs) async with container(additional_context) as container: - solved = { - name: await container.get( - dep.type_hint, - component=dep.component, - ) - for name, dep in dependencies.items() - } - async for message in func(*args, **kwargs, **solved): + async_gen = await _maybe_inject_async( + container, dependencies, func, *args, **kwargs + ) + async for message in async_gen: yield message return auto_injected_generator @@ -104,14 +128,10 @@ async def auto_injected_generator( kwargs.pop(param.name) container = container_getter(args, kwargs) - solved = { - name: await container.get( - dep.type_hint, - component=dep.component, - ) - for name, dep in dependencies.items() - } - async for message in func(*args, **kwargs, **solved): + async_gen = await _maybe_inject_async( + container, dependencies, func, *args, **kwargs + ) + async for message in async_gen: yield message return auto_injected_generator @@ -133,14 +153,10 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: {} if provide_context is None else provide_context(args, kwargs) ) async with container(additional_context) as container: - solved = { - name: await container.get( - dep.type_hint, - component=dep.component, - ) - for name, dep in dependencies.items() - } - return await func(*args, **kwargs, **solved) + coro = await _maybe_inject_async( + container, dependencies, func, *args, **kwargs + ) + return await coro return auto_injected_func @@ -159,14 +175,11 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) - solved = { - name: await container.get( - dep.type_hint, - component=dep.component, - ) - for name, dep in dependencies.items() - } - return await func(*args, **kwargs, **solved) + + coro = await _maybe_inject_async( + container, dependencies, func, *args, **kwargs + ) + return await coro return auto_injected_func @@ -190,11 +203,10 @@ def auto_injected_generator( ) container = container_getter(args, kwargs) with container(additional_context) as container: - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - yield from func(*args, **kwargs, **solved) + sync_gen = _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) + yield from sync_gen return auto_injected_generator @@ -216,11 +228,11 @@ def auto_injected_generator( container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - yield from func(*args, **kwargs, **solved) + + sync_gen = _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) + yield from sync_gen return auto_injected_generator @@ -241,11 +253,9 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) container = container_getter(args, kwargs) with container(additional_context) as container: - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - return func(*args, **kwargs, **solved) + return _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) return auto_injected_func @@ -265,11 +275,9 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - return func(*args, **kwargs, **solved) + return _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) return auto_injected_func @@ -293,11 +301,10 @@ async def auto_injected_generator( ) container = container_getter(args, kwargs) with container(additional_context) as container: - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - async for message in func(*args, **kwargs, **solved): + async_gen = _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) + async for message in async_gen: yield message return auto_injected_generator @@ -320,11 +327,11 @@ async def auto_injected_generator( container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - async for message in func(*args, **kwargs, **solved): + + async_gen = _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) + async for message in async_gen: yield message return auto_injected_generator @@ -346,11 +353,9 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) container = container_getter(args, kwargs) with container(additional_context) as container: - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - return await func(*args, **kwargs, **solved) + return await _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) return auto_injected_func @@ -370,11 +375,9 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) - solved = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() - } - return await func(*args, **kwargs, **solved) + return await _maybe_inject_sync( + container, dependencies, func, *args, **kwargs + ) return auto_injected_func From f504589b50a9caa55d96128ee4c191f444c74c1e Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Sun, 22 Jun 2025 22:45:29 +0300 Subject: [PATCH 3/9] Allow passing explicitly the dependencies when remove_depends=False --- src/dishka/integrations/base.py | 92 ++++++-- .../base/test_iter_dependencies_to_inject.py | 92 ++++++++ .../test_wrap_injection_remove_depends.py | 212 ++++++++++++++++++ 3 files changed, 373 insertions(+), 23 deletions(-) create mode 100644 tests/integrations/base/test_iter_dependencies_to_inject.py create mode 100644 tests/integrations/base/test_wrap_injection_remove_depends.py diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 72b4106de..589223eed 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -54,36 +54,70 @@ ) +def _iter_dependencies_to_inject( + dependencies: dict[str, DependencyKey], + params: Sequence[Parameter], + *args: Any, + **kwargs: Any, +) -> Iterator[tuple[str, DependencyKey]]: + named_params = {param.name: param for param in params} + named_indexes = {param.name: i for i, param in enumerate(params)} + for name, dep in dependencies.items(): + param = named_params.get(name) + if param is None: + # Inject the dependency if it was removed from the signature + yield name, dep + elif param.kind is Parameter.POSITIONAL_OR_KEYWORD: + # Inject the dependency if not provided explicitly + if named_indexes[name] >= len(args) and name not in kwargs: + yield name, dep + elif param.kind is Parameter.KEYWORD_ONLY: + # Inject the dependency if not provided explicitly + if name not in kwargs: + yield name, dep + else: + raise NotImplementedError( + f"Unsupported parameter kind: {param.kind}" + ) + + async def _maybe_inject_async( container: AsyncContainer, dependencies: dict[str, DependencyKey], + params: Sequence[Parameter], func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> T: - solved = { + resolved_deps = { name: await container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() + for name, dep in _iter_dependencies_to_inject( + dependencies, params, *args, **kwargs + ) } - return func(*args, **kwargs, **solved) + return func(*args, **kwargs, **resolved_deps) def _maybe_inject_sync( container: Container, dependencies: dict[str, DependencyKey], + params: Sequence[Parameter], func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> T: - solved = { + resolved_deps = { name: container.get(dep.type_hint, component=dep.component) - for name, dep in dependencies.items() + for name, dep in _iter_dependencies_to_inject( + dependencies, params, *args, **kwargs + ) } - return func(*args, **kwargs, **solved) + return func(*args, **kwargs, **resolved_deps) def _get_auto_injected_async_gen_scoped( container_getter: ContainerGetter[AsyncContainer], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, AsyncIterator[T]], @@ -102,7 +136,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) async with container(additional_context) as container: async_gen = await _maybe_inject_async( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) async for message in async_gen: yield message @@ -112,6 +146,7 @@ async def auto_injected_generator( def _get_auto_injected_async_gen( container_getter: ContainerGetter[AsyncContainer], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, AsyncIterator[T]], @@ -129,7 +164,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) async_gen = await _maybe_inject_async( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) async for message in async_gen: yield message @@ -139,6 +174,7 @@ async def auto_injected_generator( def _get_auto_injected_async_func_scoped( container_getter: ContainerGetter[AsyncContainer], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -154,7 +190,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) async with container(additional_context) as container: coro = await _maybe_inject_async( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) return await coro @@ -163,6 +199,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_async_func( container_getter: ContainerGetter[AsyncContainer], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -177,7 +214,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) coro = await _maybe_inject_async( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) return await coro @@ -186,6 +223,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_gen_scoped( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Iterator[T]], @@ -204,7 +242,7 @@ def auto_injected_generator( container = container_getter(args, kwargs) with container(additional_context) as container: sync_gen = _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) yield from sync_gen @@ -213,6 +251,7 @@ def auto_injected_generator( def _get_auto_injected_sync_gen( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Iterator[T]], @@ -230,7 +269,7 @@ def auto_injected_generator( kwargs.pop(param.name) sync_gen = _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) yield from sync_gen @@ -239,6 +278,7 @@ def auto_injected_generator( def _get_auto_injected_sync_func_scoped( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, T], @@ -254,7 +294,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) with container(additional_context) as container: return _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) return auto_injected_func @@ -262,6 +302,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_func( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, T], @@ -276,7 +317,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) return _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) return auto_injected_func @@ -284,6 +325,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_gen_scoped( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, AsyncIterator[T]], @@ -302,7 +344,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) with container(additional_context) as container: async_gen = _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) async for message in async_gen: yield message @@ -312,6 +354,7 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_gen( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, AsyncIterator[T]], @@ -329,7 +372,7 @@ async def auto_injected_generator( kwargs.pop(param.name) async_gen = _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) async for message in async_gen: yield message @@ -339,6 +382,7 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_func_scoped( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -354,7 +398,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) with container(additional_context) as container: return await _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) return auto_injected_func @@ -362,6 +406,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_func( container_getter: ContainerGetter[Container], + params: Sequence[Parameter], additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -376,7 +421,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) return await _maybe_inject_sync( - container, dependencies, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs ) return auto_injected_func @@ -576,11 +621,6 @@ def wrap_injection( for param in new_params ] - if additional_params: - new_params = _add_params(new_params, additional_params) - for param in additional_params: - new_annotations[param.name] = param.annotation - injected_func_type = InjectedFuncType( is_async_container=is_async, manage_scope=manage_scope, @@ -592,10 +632,16 @@ def wrap_injection( func=func, provide_context=provide_context, dependencies=dependencies, + params=new_params, additional_params=additional_params, container_getter=container_getter, ) + if additional_params: + new_params = _add_params(new_params, additional_params) + for param in additional_params: + new_annotations[param.name] = param.annotation + auto_injected_func.__dishka_orig_func__ = func auto_injected_func.__dishka_injected__ = True auto_injected_func.__name__ = func.__name__ diff --git a/tests/integrations/base/test_iter_dependencies_to_inject.py b/tests/integrations/base/test_iter_dependencies_to_inject.py new file mode 100644 index 000000000..80a2b2ab2 --- /dev/null +++ b/tests/integrations/base/test_iter_dependencies_to_inject.py @@ -0,0 +1,92 @@ +from inspect import signature +from unittest.mock import Mock + +import pytest + +from dishka import FromDishka +from dishka.integrations.base import _iter_dependencies_to_inject +from tests.integrations.common import AppMock + +Dep1 = FromDishka[Mock] +Dep2 = FromDishka[AppMock] + + +def pos_or_kw(i: int, d1: Dep1, d2: Dep2, j: int = 0): ... + + +def kw_only(i: int, *, d1: Dep1, d2: Dep2, j: int = 0): ... + + +def mixed(i: int, d1: Dep1, *, d2: Dep2, j: int = 0): ... + + +def pos_only(i: int, d1: Dep1, d2: Dep2, /, j: int = 0): ... + + +def pos_only_d1(i: int, d1: Dep1, /, d2: Dep2, j: int = 0): ... + + +def var_args(i: int, d1: Dep1, *d2: Dep2, j: int = 0): ... + + +def var_kwargs(i: int, d1: Dep1, j: int = 0, **d2: Dep2): ... + + +def var_args_kwargs(i: int, *d1: Dep1, j: int = 0, **d2: Dep2): ... + + +def get_injected_names_factory(func): + params = list(signature(func).parameters.values()) + deps = {"d1": Mock(), "d2": AppMock(Mock())} + + def get_injected_names(*args, **kw): + named_deps = _iter_dependencies_to_inject(deps, params, *args, **kw) + return [name for name, _ in named_deps] + + return get_injected_names + + +@pytest.mark.parametrize("func", [pos_or_kw, kw_only, mixed]) +def test_dont_pass_dependencies(func): + get_injected_names = get_injected_names_factory(func) + # Both dependencies injected + assert get_injected_names(1) == ["d1", "d2"] + assert get_injected_names(2, j=9) == ["d1", "d2"] + + +@pytest.mark.parametrize("func", [pos_or_kw, kw_only, mixed]) +def test_pass_dependencies_by_name(func): + get_injected_names = get_injected_names_factory(func) + # d1 passed by name, d2 injected + assert get_injected_names(1, d1=Mock()) == ["d2"] + assert get_injected_names(2, d1=Mock(), j=9) == ["d2"] + # d2 passed by name, d1 injected + assert get_injected_names(3, d2=Mock()) == ["d1"] + assert get_injected_names(4, d2=Mock(), j=9) == ["d1"] + # Both dependencies passed by name, no injection + assert get_injected_names(1, d1=Mock(), d2=Mock()) == [] + assert get_injected_names(2, d1=Mock(), d2=Mock(), j=9) == [] + + +@pytest.mark.parametrize("func", [pos_or_kw, mixed]) +def test_pass_dependencies_by_position(func): + get_injected_names = get_injected_names_factory(func) + # d1 passed positionally, d2 injected + assert get_injected_names(1, Mock()) == ["d2"] + assert get_injected_names(2, Mock(), j=9) == ["d2"] + # d1 passed positionally, d2 passed by name, no injection + assert get_injected_names(2, Mock(), d2=Mock()) == [] + assert get_injected_names(3, Mock(), d2=Mock(), j=9) == [] + if func is pos_or_kw: + # d1 and d2 passed positionally, no injection + assert get_injected_names(3, Mock(), Mock()) == [] + assert get_injected_names(3, Mock(), Mock(), j=9) == [] + + +@pytest.mark.parametrize( + "func", [pos_only, pos_only_d1, var_args, var_kwargs, var_args_kwargs] +) +def test_not_implemented_parameter_kinds(func): + get_injected_names = get_injected_names_factory(func) + with pytest.raises(NotImplementedError): + get_injected_names(1, Mock(), d2=Mock()) diff --git a/tests/integrations/base/test_wrap_injection_remove_depends.py b/tests/integrations/base/test_wrap_injection_remove_depends.py new file mode 100644 index 000000000..57938a557 --- /dev/null +++ b/tests/integrations/base/test_wrap_injection_remove_depends.py @@ -0,0 +1,212 @@ +import asyncio +from collections.abc import Iterable +from inspect import isasyncgen, iscoroutine, isgenerator +from unittest.mock import Mock + +import pytest + +from dishka import FromDishka, make_async_container, make_container +from dishka.integrations.base import wrap_injection +from tests.integrations.common import AppMock + + +def raises_multiple_values(obj): + with pytest.raises(TypeError, match="multiple values for"): # noqa: PT012 + if isgenerator(obj): + list(obj) + elif callable(obj): + obj() + else: + pytest.fail("Object is neither a generator nor callable") + + +async def raises_multiple_values_async(obj): + with pytest.raises(TypeError, match="multiple values for"): # noqa: PT012 + if isasyncgen(obj): + [x async for x in obj] + elif iscoroutine(obj): + await obj + else: + pytest.fail("Object is neither a generator nor callable") + + +def sync_func(i: int, dep: FromDishka[AppMock], j: int = 0): + return dep(i, j) + + +def sync_gen(data: Iterable[int], dep: FromDishka[AppMock], j: int = 0): + for i in data: + yield dep(i, j) + + +async def async_func(i: int, dep: FromDishka[AppMock], j: int = 0): + await asyncio.sleep(0) + return dep(i, j) + + +async def async_gen( + data: Iterable[int], + dep: FromDishka[AppMock], + j: int = 0, +): + for i in data: + await asyncio.sleep(0) + yield dep(i, j) + + +@pytest.mark.parametrize("remove_depends", [True, False]) +def test_sync_func(remove_depends, app_provider): + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=sync_func, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=False, + ) + + wrapped_func(1) + app_provider.app_mock.assert_called_with(1, 0) + wrapped_func(2, j=3) + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + raises_multiple_values(lambda: wrapped_func(1, new_dep)) + raises_multiple_values(lambda: wrapped_func(2, dep=new_dep)) + raises_multiple_values(lambda: wrapped_func(3, new_dep, 9)) + raises_multiple_values(lambda: wrapped_func(4, new_dep, j=9)) + raises_multiple_values(lambda: wrapped_func(5, dep=new_dep, j=9)) + else: + wrapped_func(1, new_dep) + new_dep.assert_called_with(1, 0) + wrapped_func(2, dep=new_dep) + new_dep.assert_called_with(2, 0) + wrapped_func(3, new_dep, 9) + new_dep.assert_called_with(3, 9) + wrapped_func(4, new_dep, j=9) + new_dep.assert_called_with(4, 9) + wrapped_func(5, dep=new_dep, j=9) + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called() + + +@pytest.mark.parametrize("remove_depends", [True, False]) +def test_sync_gen(remove_depends, app_provider): + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=sync_gen, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=False, + ) + + list(wrapped_func([1])) + app_provider.app_mock.assert_called_with(1, 0) + list(wrapped_func([2], j=3)) + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + raises_multiple_values(wrapped_func([1], new_dep)) + raises_multiple_values(wrapped_func([2], dep=new_dep)) + raises_multiple_values(wrapped_func([3], new_dep, 9)) + raises_multiple_values(wrapped_func([4], new_dep, j=9)) + raises_multiple_values(wrapped_func([5], dep=new_dep, j=9)) + else: + list(wrapped_func([1], new_dep)) + new_dep.assert_called_with(1, 0) + list(wrapped_func([2], dep=new_dep)) + new_dep.assert_called_with(2, 0) + list(wrapped_func([3], new_dep, 9)) + new_dep.assert_called_with(3, 9) + list(wrapped_func([4], new_dep, j=9)) + new_dep.assert_called_with(4, 9) + list(wrapped_func([5], dep=new_dep, j=9)) + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_container", [True, False]) +@pytest.mark.parametrize("remove_depends", [True, False]) +async def test_async_func(async_container, remove_depends, app_provider): + if async_container: + container = make_async_container(app_provider) + else: + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=async_func, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=async_container, + ) + + await wrapped_func(1) + app_provider.app_mock.assert_called_with(1, 0) + await wrapped_func(2, j=3) + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + await raises_multiple_values_async(wrapped_func(1, new_dep)) + await raises_multiple_values_async(wrapped_func(2, dep=new_dep)) + await raises_multiple_values_async(wrapped_func(3, new_dep, 9)) + await raises_multiple_values_async(wrapped_func(4, new_dep, j=9)) + await raises_multiple_values_async(wrapped_func(5, dep=new_dep, j=9)) + else: + await wrapped_func(1, new_dep) + new_dep.assert_called_with(1, 0) + await wrapped_func(2, dep=new_dep) + new_dep.assert_called_with(2, 0) + await wrapped_func(3, new_dep, 9) + new_dep.assert_called_with(3, 9) + await wrapped_func(4, new_dep, j=9) + new_dep.assert_called_with(4, 9) + await wrapped_func(5, dep=new_dep, j=9) + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_container", [True, False]) +@pytest.mark.parametrize("remove_depends", [True, False]) +async def test_async_gen(async_container, remove_depends, app_provider): + if async_container: + container = make_async_container(app_provider) + else: + container = make_container(app_provider) + wrapped_func = wrap_injection( + func=async_gen, + container_getter=lambda *_: container, + remove_depends=remove_depends, + is_async=async_container, + ) + + [item async for item in wrapped_func([1])] + app_provider.app_mock.assert_called_with(1, 0) + [item async for item in wrapped_func([2], j=3)] + app_provider.app_mock.assert_called_with(2, 3) + + app_provider.app_mock.reset_mock() + new_dep = AppMock(Mock()) + if remove_depends: + await raises_multiple_values_async(wrapped_func([1], new_dep)) + await raises_multiple_values_async(wrapped_func([2], dep=new_dep)) + await raises_multiple_values_async(wrapped_func([3], new_dep, 9)) + await raises_multiple_values_async(wrapped_func([4], new_dep, j=9)) + await raises_multiple_values_async(wrapped_func([5], dep=new_dep, j=9)) + else: + [item async for item in wrapped_func([1], new_dep)] + new_dep.assert_called_with(1, 0) + [item async for item in wrapped_func([2], dep=new_dep)] + new_dep.assert_called_with(2, 0) + [item async for item in wrapped_func([3], new_dep, 9)] + new_dep.assert_called_with(3, 9) + [item async for item in wrapped_func([4], new_dep, j=9)] + new_dep.assert_called_with(4, 9) + [item async for item in wrapped_func([5], dep=new_dep, j=9)] + new_dep.assert_called_with(5, 9) + app_provider.app_mock.assert_not_called() From 76317d46dbfcb2325b2ba4f72875112372899307 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Tue, 24 Jun 2025 02:09:48 +0300 Subject: [PATCH 4/9] Unignore COM812 --- .ruff.toml | 1 - src/dishka/integrations/base.py | 30 +++++++++++++++--------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/.ruff.toml b/.ruff.toml index 73e807083..13c42ef93 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -25,7 +25,6 @@ lint.ignore = [ "PLR0913", "SIM103", "ISC003", - "COM812", # identitcal by code != identical by meaning "SIM114", diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 589223eed..408018ce3 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -77,7 +77,7 @@ def _iter_dependencies_to_inject( yield name, dep else: raise NotImplementedError( - f"Unsupported parameter kind: {param.kind}" + f"Unsupported parameter kind: {param.kind}", ) @@ -92,7 +92,7 @@ async def _maybe_inject_async( resolved_deps = { name: await container.get(dep.type_hint, component=dep.component) for name, dep in _iter_dependencies_to_inject( - dependencies, params, *args, **kwargs + dependencies, params, *args, **kwargs, ) } return func(*args, **kwargs, **resolved_deps) @@ -109,7 +109,7 @@ def _maybe_inject_sync( resolved_deps = { name: container.get(dep.type_hint, component=dep.component) for name, dep in _iter_dependencies_to_inject( - dependencies, params, *args, **kwargs + dependencies, params, *args, **kwargs, ) } return func(*args, **kwargs, **resolved_deps) @@ -136,7 +136,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) async with container(additional_context) as container: async_gen = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -164,7 +164,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) async_gen = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -190,7 +190,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) async with container(additional_context) as container: coro = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) return await coro @@ -214,7 +214,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) coro = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) return await coro @@ -242,7 +242,7 @@ def auto_injected_generator( container = container_getter(args, kwargs) with container(additional_context) as container: sync_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) yield from sync_gen @@ -269,7 +269,7 @@ def auto_injected_generator( kwargs.pop(param.name) sync_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) yield from sync_gen @@ -294,7 +294,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) with container(additional_context) as container: return _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) return auto_injected_func @@ -317,7 +317,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) return _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) return auto_injected_func @@ -344,7 +344,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) with container(additional_context) as container: async_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -372,7 +372,7 @@ async def auto_injected_generator( kwargs.pop(param.name) async_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -398,7 +398,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) with container(additional_context) as container: return await _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) return auto_injected_func @@ -421,7 +421,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) return await _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs + container, dependencies, params, func, *args, **kwargs, ) return auto_injected_func From 8ea2d9c2956fe1222490ba3be84829de52e2b365 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Tue, 24 Jun 2025 02:09:03 +0300 Subject: [PATCH 5/9] Optimize the remove_depends=True path --- src/dishka/integrations/base.py | 56 ++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 408018ce3..af88a9707 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import ( AsyncIterator, Awaitable, @@ -84,33 +85,46 @@ def _iter_dependencies_to_inject( async def _maybe_inject_async( container: AsyncContainer, dependencies: dict[str, DependencyKey], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> T: - resolved_deps = { - name: await container.get(dep.type_hint, component=dep.component) - for name, dep in _iter_dependencies_to_inject( + if params is None: + named_deps = list(dependencies.items()) + else: + named_deps = list(_iter_dependencies_to_inject( dependencies, params, *args, **kwargs, + )) + + resolved_deps: dict[str, Any] = {} + if named_deps: + names, deps = zip(*named_deps, strict=True) + coros = (container.get(dep.type_hint, component=dep.component) + for dep in deps) + resolved_deps.update( + zip(names, await asyncio.gather(*coros), strict=True), ) - } return func(*args, **kwargs, **resolved_deps) def _maybe_inject_sync( container: Container, dependencies: dict[str, DependencyKey], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> T: - resolved_deps = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in _iter_dependencies_to_inject( + if params is None: + named_deps = iter(dependencies.items()) + else: + named_deps = _iter_dependencies_to_inject( dependencies, params, *args, **kwargs, ) + resolved_deps = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in named_deps } return func(*args, **kwargs, **resolved_deps) @@ -146,7 +160,7 @@ async def auto_injected_generator( def _get_auto_injected_async_gen( container_getter: ContainerGetter[AsyncContainer], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, AsyncIterator[T]], @@ -174,7 +188,7 @@ async def auto_injected_generator( def _get_auto_injected_async_func_scoped( container_getter: ContainerGetter[AsyncContainer], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -199,7 +213,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_async_func( container_getter: ContainerGetter[AsyncContainer], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -223,7 +237,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_gen_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Iterator[T]], @@ -251,7 +265,7 @@ def auto_injected_generator( def _get_auto_injected_sync_gen( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Iterator[T]], @@ -278,7 +292,7 @@ def auto_injected_generator( def _get_auto_injected_sync_func_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, T], @@ -302,7 +316,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_func( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, T], @@ -325,7 +339,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_gen_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, AsyncIterator[T]], @@ -354,7 +368,7 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_gen( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, AsyncIterator[T]], @@ -382,7 +396,7 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_func_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -406,7 +420,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_func( container_getter: ContainerGetter[Container], - params: Sequence[Parameter], + params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], dependencies: dict[str, DependencyKey], func: Callable[P, Awaitable[T]], @@ -632,7 +646,7 @@ def wrap_injection( func=func, provide_context=provide_context, dependencies=dependencies, - params=new_params, + params=new_params if not remove_depends else None, additional_params=additional_params, container_getter=container_getter, ) From 6327e8c5dec5ddb2391e954646bca4a5810bc311 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Tue, 24 Jun 2025 11:17:36 +0300 Subject: [PATCH 6/9] Remove asyncio.gather() --- src/dishka/integrations/base.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index af88a9707..70738e923 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -1,4 +1,3 @@ -import asyncio from collections.abc import ( AsyncIterator, Awaitable, @@ -91,20 +90,16 @@ async def _maybe_inject_async( **kwargs: P.kwargs, ) -> T: if params is None: - named_deps = list(dependencies.items()) + named_deps = iter(dependencies.items()) else: - named_deps = list(_iter_dependencies_to_inject( + named_deps = iter(_iter_dependencies_to_inject( dependencies, params, *args, **kwargs, )) - resolved_deps: dict[str, Any] = {} - if named_deps: - names, deps = zip(*named_deps, strict=True) - coros = (container.get(dep.type_hint, component=dep.component) - for dep in deps) - resolved_deps.update( - zip(names, await asyncio.gather(*coros), strict=True), - ) + resolved_deps = { + name: await container.get(dep.type_hint, component=dep.component) + for name, dep in named_deps + } return func(*args, **kwargs, **resolved_deps) From cd21328d1a255fae9ef7d98170bb89312d0dac36 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Tue, 24 Jun 2025 12:40:30 +0300 Subject: [PATCH 7/9] More optimization and refactoring: replace _iter_dependencies_to_inject with ParameterDependencyResolver --- src/dishka/integrations/base.py | 159 +++++++++--------- ... => test_parameter_dependency_resolver.py} | 11 +- 2 files changed, 83 insertions(+), 87 deletions(-) rename tests/integrations/base/{test_iter_dependencies_to_inject.py => test_parameter_dependency_resolver.py} (89%) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 70738e923..b70aab9bf 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -7,6 +7,7 @@ ) from dataclasses import dataclass from enum import Enum +from functools import partial from inspect import ( Parameter, Signature, @@ -54,81 +55,85 @@ ) -def _iter_dependencies_to_inject( - dependencies: dict[str, DependencyKey], - params: Sequence[Parameter], - *args: Any, - **kwargs: Any, -) -> Iterator[tuple[str, DependencyKey]]: - named_params = {param.name: param for param in params} - named_indexes = {param.name: i for i, param in enumerate(params)} - for name, dep in dependencies.items(): - param = named_params.get(name) - if param is None: - # Inject the dependency if it was removed from the signature - yield name, dep - elif param.kind is Parameter.POSITIONAL_OR_KEYWORD: - # Inject the dependency if not provided explicitly - if named_indexes[name] >= len(args) and name not in kwargs: - yield name, dep - elif param.kind is Parameter.KEYWORD_ONLY: - # Inject the dependency if not provided explicitly - if name not in kwargs: +class ParameterDependencyResolver: + def __init__( + self, + params: Sequence[Parameter], + dependencies: dict[str, DependencyKey], + ): + self._named_deps_predicates = [] + named_params = {param.name: param for param in params} + named_idxs = {param.name: i for i, param in enumerate(params)} + for name, dep in dependencies.items(): + match named_params[name].kind: + case Parameter.POSITIONAL_OR_KEYWORD: + pred = partial(self._has_pos_or_kw, named_idxs[name], name) + case Parameter.KEYWORD_ONLY: + pred = partial(self._has_kw_only, name) + case kind: + raise NotImplementedError( + f"Unsupported parameter kind: {kind}", + ) + self._named_deps_predicates.append((name, dep, pred)) + + def __call__( + self, *args: Any, **kwargs: Any, + ) -> Iterator[tuple[str, DependencyKey]]: + for name, dep, has_param in self._named_deps_predicates: + if not has_param(*args, **kwargs): yield name, dep - else: - raise NotImplementedError( - f"Unsupported parameter kind: {param.kind}", - ) + + @staticmethod + def _has_pos_or_kw(i: int, name: str, *args: Any, **kwargs: Any) -> bool: + return i < len(args) or name in kwargs + + @staticmethod + def _has_kw_only(name: str, *args: Any, **kwargs: Any) -> bool: + return name in kwargs + + +class DependencyResolver: + def __init__(self, dependencies: dict[str, DependencyKey]): + self._named_deps = list(dependencies.items()) + + def __call__( + self, *args: Any, **kwargs: Any, + ) -> Iterator[tuple[str, DependencyKey]]: + return iter(self._named_deps) async def _maybe_inject_async( container: AsyncContainer, - dependencies: dict[str, DependencyKey], - params: Sequence[Parameter] | None, + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> T: - if params is None: - named_deps = iter(dependencies.items()) - else: - named_deps = iter(_iter_dependencies_to_inject( - dependencies, params, *args, **kwargs, - )) - resolved_deps = { name: await container.get(dep.type_hint, component=dep.component) - for name, dep in named_deps + for name, dep in resolver(*args, **kwargs) } return func(*args, **kwargs, **resolved_deps) def _maybe_inject_sync( container: Container, - dependencies: dict[str, DependencyKey], - params: Sequence[Parameter] | None, + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs, ) -> T: - if params is None: - named_deps = iter(dependencies.items()) - else: - named_deps = _iter_dependencies_to_inject( - dependencies, params, *args, **kwargs, - ) resolved_deps = { name: container.get(dep.type_hint, component=dep.component) - for name, dep in named_deps + for name, dep in resolver(*args, **kwargs) } return func(*args, **kwargs, **resolved_deps) def _get_auto_injected_async_gen_scoped( container_getter: ContainerGetter[AsyncContainer], - params: Sequence[Parameter], additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -145,7 +150,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) async with container(additional_context) as container: async_gen = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -155,9 +160,8 @@ async def auto_injected_generator( def _get_auto_injected_async_gen( container_getter: ContainerGetter[AsyncContainer], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -173,7 +177,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) async_gen = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -183,9 +187,8 @@ async def auto_injected_generator( def _get_auto_injected_async_func_scoped( container_getter: ContainerGetter[AsyncContainer], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -199,7 +202,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) async with container(additional_context) as container: coro = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) return await coro @@ -208,9 +211,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_async_func( container_getter: ContainerGetter[AsyncContainer], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -223,7 +225,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) coro = await _maybe_inject_async( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) return await coro @@ -232,9 +234,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_gen_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, Iterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Iterator[T]]: @@ -251,7 +252,7 @@ def auto_injected_generator( container = container_getter(args, kwargs) with container(additional_context) as container: sync_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) yield from sync_gen @@ -260,9 +261,8 @@ def auto_injected_generator( def _get_auto_injected_sync_gen( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, Iterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Iterator[T]]: @@ -278,7 +278,7 @@ def auto_injected_generator( kwargs.pop(param.name) sync_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) yield from sync_gen @@ -287,9 +287,8 @@ def auto_injected_generator( def _get_auto_injected_sync_func_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, T], provide_context: ProvideContext | None = None, ) -> Callable[P, T]: @@ -303,7 +302,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) with container(additional_context) as container: return _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) return auto_injected_func @@ -311,9 +310,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_func( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, T], provide_context: ProvideContext | None = None, ) -> Callable[P, T]: @@ -326,7 +324,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) return _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) return auto_injected_func @@ -334,9 +332,8 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_gen_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -353,7 +350,7 @@ async def auto_injected_generator( container = container_getter(args, kwargs) with container(additional_context) as container: async_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -363,9 +360,8 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_gen( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -381,7 +377,7 @@ async def auto_injected_generator( kwargs.pop(param.name) async_gen = _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) async for message in async_gen: yield message @@ -391,9 +387,8 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_func_scoped( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -407,7 +402,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) with container(additional_context) as container: return await _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) return auto_injected_func @@ -415,9 +410,8 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_func( container_getter: ContainerGetter[Container], - params: Sequence[Parameter] | None, additional_params: Sequence[Parameter], - dependencies: dict[str, DependencyKey], + resolver: DependencyResolver | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -430,7 +424,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: kwargs.pop(param.name) return await _maybe_inject_sync( - container, dependencies, params, func, *args, **kwargs, + container, resolver, func, *args, **kwargs, ) return auto_injected_func @@ -640,8 +634,11 @@ def wrap_injection( auto_injected_func = get_auto_injected_func( # type: ignore[operator] func=func, provide_context=provide_context, - dependencies=dependencies, - params=new_params if not remove_depends else None, + resolver=( + DependencyResolver(dependencies) + if remove_depends + else ParameterDependencyResolver(new_params, dependencies) + ), additional_params=additional_params, container_getter=container_getter, ) diff --git a/tests/integrations/base/test_iter_dependencies_to_inject.py b/tests/integrations/base/test_parameter_dependency_resolver.py similarity index 89% rename from tests/integrations/base/test_iter_dependencies_to_inject.py rename to tests/integrations/base/test_parameter_dependency_resolver.py index 80a2b2ab2..8417fa200 100644 --- a/tests/integrations/base/test_iter_dependencies_to_inject.py +++ b/tests/integrations/base/test_parameter_dependency_resolver.py @@ -4,7 +4,7 @@ import pytest from dishka import FromDishka -from dishka.integrations.base import _iter_dependencies_to_inject +from dishka.integrations.base import ParameterDependencyResolver from tests.integrations.common import AppMock Dep1 = FromDishka[Mock] @@ -38,10 +38,10 @@ def var_args_kwargs(i: int, *d1: Dep1, j: int = 0, **d2: Dep2): ... def get_injected_names_factory(func): params = list(signature(func).parameters.values()) deps = {"d1": Mock(), "d2": AppMock(Mock())} + resolver = ParameterDependencyResolver(params, deps) def get_injected_names(*args, **kw): - named_deps = _iter_dependencies_to_inject(deps, params, *args, **kw) - return [name for name, _ in named_deps] + return [name for name, _ in resolver(*args, **kw)] return get_injected_names @@ -84,9 +84,8 @@ def test_pass_dependencies_by_position(func): @pytest.mark.parametrize( - "func", [pos_only, pos_only_d1, var_args, var_kwargs, var_args_kwargs] + "func", [pos_only, pos_only_d1, var_args, var_kwargs, var_args_kwargs], ) def test_not_implemented_parameter_kinds(func): - get_injected_names = get_injected_names_factory(func) with pytest.raises(NotImplementedError): - get_injected_names(1, Mock(), d2=Mock()) + get_injected_names_factory(func) From c161f7bfb5b6249b54163625757e20d206641703 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Tue, 24 Jun 2025 23:37:34 +0300 Subject: [PATCH 8/9] More remove_depends=True path optimization: just an extra isinstance() call --- src/dishka/integrations/base.py | 233 ++++++++++-------- .../test_parameter_dependency_resolver.py | 3 +- 2 files changed, 127 insertions(+), 109 deletions(-) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index b70aab9bf..fdb07349c 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -61,79 +61,45 @@ def __init__( params: Sequence[Parameter], dependencies: dict[str, DependencyKey], ): + self._selected_named_deps: list[tuple[str, DependencyKey]] = [] self._named_deps_predicates = [] named_params = {param.name: param for param in params} named_idxs = {param.name: i for i, param in enumerate(params)} for name, dep in dependencies.items(): match named_params[name].kind: case Parameter.POSITIONAL_OR_KEYWORD: - pred = partial(self._has_pos_or_kw, named_idxs[name], name) + pred = partial(_has_pos_or_kw, named_idxs[name], name) case Parameter.KEYWORD_ONLY: - pred = partial(self._has_kw_only, name) + pred = partial(_has_kw_only, name) case kind: raise NotImplementedError( f"Unsupported parameter kind: {kind}", ) self._named_deps_predicates.append((name, dep, pred)) - def __call__( - self, *args: Any, **kwargs: Any, - ) -> Iterator[tuple[str, DependencyKey]]: - for name, dep, has_param in self._named_deps_predicates: - if not has_param(*args, **kwargs): - yield name, dep - - @staticmethod - def _has_pos_or_kw(i: int, name: str, *args: Any, **kwargs: Any) -> bool: - return i < len(args) or name in kwargs + def bind(self, *args: Any, **kwargs: Any) -> None: + self._selected_named_deps = [ + (name, dep) + for name, dep, has_param in self._named_deps_predicates + if not has_param(*args, **kwargs) + ] - @staticmethod - def _has_kw_only(name: str, *args: Any, **kwargs: Any) -> bool: - return name in kwargs + def items(self) -> Iterator[tuple[str, DependencyKey]]: + return iter(self._selected_named_deps) -class DependencyResolver: - def __init__(self, dependencies: dict[str, DependencyKey]): - self._named_deps = list(dependencies.items()) +def _has_pos_or_kw(i: int, name: str, *args: Any, **kwargs: Any) -> bool: + return i < len(args) or name in kwargs - def __call__( - self, *args: Any, **kwargs: Any, - ) -> Iterator[tuple[str, DependencyKey]]: - return iter(self._named_deps) - -async def _maybe_inject_async( - container: AsyncContainer, - resolver: DependencyResolver | ParameterDependencyResolver, - func: Callable[P, T], - *args: P.args, - **kwargs: P.kwargs, -) -> T: - resolved_deps = { - name: await container.get(dep.type_hint, component=dep.component) - for name, dep in resolver(*args, **kwargs) - } - return func(*args, **kwargs, **resolved_deps) - - -def _maybe_inject_sync( - container: Container, - resolver: DependencyResolver | ParameterDependencyResolver, - func: Callable[P, T], - *args: P.args, - **kwargs: P.kwargs, -) -> T: - resolved_deps = { - name: container.get(dep.type_hint, component=dep.component) - for name, dep in resolver(*args, **kwargs) - } - return func(*args, **kwargs, **resolved_deps) +def _has_kw_only(name: str, *args: Any, **kwargs: Any) -> bool: + return name in kwargs def _get_auto_injected_async_gen_scoped( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -149,10 +115,16 @@ async def auto_injected_generator( ) container = container_getter(args, kwargs) async with container(additional_context) as container: - async_gen = await _maybe_inject_async( - container, resolver, func, *args, **kwargs, - ) - async for message in async_gen: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: await container.get( + dep.type_hint, + component=dep.component, + ) + for name, dep in dependencies.items() + } + async for message in func(*args, **kwargs, **solved): yield message return auto_injected_generator @@ -161,7 +133,7 @@ async def auto_injected_generator( def _get_auto_injected_async_gen( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -176,10 +148,16 @@ async def auto_injected_generator( kwargs.pop(param.name) container = container_getter(args, kwargs) - async_gen = await _maybe_inject_async( - container, resolver, func, *args, **kwargs, - ) - async for message in async_gen: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: await container.get( + dep.type_hint, + component=dep.component, + ) + for name, dep in dependencies.items() + } + async for message in func(*args, **kwargs, **solved): yield message return auto_injected_generator @@ -188,7 +166,7 @@ async def auto_injected_generator( def _get_auto_injected_async_func_scoped( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -201,10 +179,16 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: {} if provide_context is None else provide_context(args, kwargs) ) async with container(additional_context) as container: - coro = await _maybe_inject_async( - container, resolver, func, *args, **kwargs, - ) - return await coro + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: await container.get( + dep.type_hint, + component=dep.component, + ) + for name, dep in dependencies.items() + } + return await func(*args, **kwargs, **solved) return auto_injected_func @@ -212,7 +196,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_async_func( container_getter: ContainerGetter[AsyncContainer], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -223,11 +207,16 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: container = container_getter(args, kwargs) for param in additional_params: kwargs.pop(param.name) - - coro = await _maybe_inject_async( - container, resolver, func, *args, **kwargs, - ) - return await coro + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: await container.get( + dep.type_hint, + component=dep.component, + ) + for name, dep in dependencies.items() + } + return await func(*args, **kwargs, **solved) return auto_injected_func @@ -235,7 +224,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_gen_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Iterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Iterator[T]]: @@ -251,10 +240,13 @@ def auto_injected_generator( ) container = container_getter(args, kwargs) with container(additional_context) as container: - sync_gen = _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) - yield from sync_gen + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + yield from func(*args, **kwargs, **solved) return auto_injected_generator @@ -262,7 +254,7 @@ def auto_injected_generator( def _get_auto_injected_sync_gen( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Iterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Iterator[T]]: @@ -277,10 +269,13 @@ def auto_injected_generator( for param in additional_params: kwargs.pop(param.name) - sync_gen = _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) - yield from sync_gen + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + yield from func(*args, **kwargs, **solved) return auto_injected_generator @@ -288,7 +283,7 @@ def auto_injected_generator( def _get_auto_injected_sync_func_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, T], provide_context: ProvideContext | None = None, ) -> Callable[P, T]: @@ -301,9 +296,13 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) container = container_getter(args, kwargs) with container(additional_context) as container: - return _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + return func(*args, **kwargs, **solved) return auto_injected_func @@ -311,7 +310,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_func( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, T], provide_context: ProvideContext | None = None, ) -> Callable[P, T]: @@ -323,9 +322,13 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) - return _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + return func(*args, **kwargs, **solved) return auto_injected_func @@ -333,7 +336,7 @@ def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_gen_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -349,10 +352,13 @@ async def auto_injected_generator( ) container = container_getter(args, kwargs) with container(additional_context) as container: - async_gen = _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) - async for message in async_gen: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + async for message in func(*args, **kwargs, **solved): yield message return auto_injected_generator @@ -361,7 +367,7 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_gen( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, AsyncIterator[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, AsyncIterator[T]]: @@ -376,10 +382,13 @@ async def auto_injected_generator( for param in additional_params: kwargs.pop(param.name) - async_gen = _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) - async for message in async_gen: + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + async for message in func(*args, **kwargs, **solved): yield message return auto_injected_generator @@ -388,7 +397,7 @@ async def auto_injected_generator( def _get_auto_injected_sync_container_async_func_scoped( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -401,9 +410,13 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: ) container = container_getter(args, kwargs) with container(additional_context) as container: - return await _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + return await func(*args, **kwargs, **solved) return auto_injected_func @@ -411,7 +424,7 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: def _get_auto_injected_sync_container_async_func( container_getter: ContainerGetter[Container], additional_params: Sequence[Parameter], - resolver: DependencyResolver | ParameterDependencyResolver, + dependencies: dict[str, DependencyKey] | ParameterDependencyResolver, func: Callable[P, Awaitable[T]], provide_context: ProvideContext | None = None, ) -> Callable[P, Awaitable[T]]: @@ -423,9 +436,13 @@ async def auto_injected_func(*args: P.args, **kwargs: P.kwargs) -> T: for param in additional_params: kwargs.pop(param.name) - return await _maybe_inject_sync( - container, resolver, func, *args, **kwargs, - ) + if isinstance(dependencies, ParameterDependencyResolver): + dependencies.bind(*args, **kwargs) + solved = { + name: container.get(dep.type_hint, component=dep.component) + for name, dep in dependencies.items() + } + return await func(*args, **kwargs, **solved) return auto_injected_func @@ -634,8 +651,8 @@ def wrap_injection( auto_injected_func = get_auto_injected_func( # type: ignore[operator] func=func, provide_context=provide_context, - resolver=( - DependencyResolver(dependencies) + dependencies=( + dependencies if remove_depends else ParameterDependencyResolver(new_params, dependencies) ), diff --git a/tests/integrations/base/test_parameter_dependency_resolver.py b/tests/integrations/base/test_parameter_dependency_resolver.py index 8417fa200..bf66861a5 100644 --- a/tests/integrations/base/test_parameter_dependency_resolver.py +++ b/tests/integrations/base/test_parameter_dependency_resolver.py @@ -41,7 +41,8 @@ def get_injected_names_factory(func): resolver = ParameterDependencyResolver(params, deps) def get_injected_names(*args, **kw): - return [name for name, _ in resolver(*args, **kw)] + resolver.bind(*args, **kw) + return [name for name, _ in resolver.items()] return get_injected_names From 1a1eea0b16b791d8066bea665f62813a5cdfa294 Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Tue, 24 Jun 2025 23:56:09 +0300 Subject: [PATCH 9/9] Replace list comprehensions with normal loops --- .../test_wrap_injection_remove_depends.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/integrations/base/test_wrap_injection_remove_depends.py b/tests/integrations/base/test_wrap_injection_remove_depends.py index 57938a557..ecd819445 100644 --- a/tests/integrations/base/test_wrap_injection_remove_depends.py +++ b/tests/integrations/base/test_wrap_injection_remove_depends.py @@ -23,7 +23,8 @@ def raises_multiple_values(obj): async def raises_multiple_values_async(obj): with pytest.raises(TypeError, match="multiple values for"): # noqa: PT012 if isasyncgen(obj): - [x async for x in obj] + async for _ in obj: + pass elif iscoroutine(obj): await obj else: @@ -185,9 +186,11 @@ async def test_async_gen(async_container, remove_depends, app_provider): is_async=async_container, ) - [item async for item in wrapped_func([1])] + async for _ in wrapped_func([1]): + pass app_provider.app_mock.assert_called_with(1, 0) - [item async for item in wrapped_func([2], j=3)] + async for _ in wrapped_func([2], j=3): + pass app_provider.app_mock.assert_called_with(2, 3) app_provider.app_mock.reset_mock() @@ -199,14 +202,19 @@ async def test_async_gen(async_container, remove_depends, app_provider): await raises_multiple_values_async(wrapped_func([4], new_dep, j=9)) await raises_multiple_values_async(wrapped_func([5], dep=new_dep, j=9)) else: - [item async for item in wrapped_func([1], new_dep)] + async for _ in wrapped_func([1], new_dep): + pass new_dep.assert_called_with(1, 0) - [item async for item in wrapped_func([2], dep=new_dep)] + async for _ in wrapped_func([2], dep=new_dep): + pass new_dep.assert_called_with(2, 0) - [item async for item in wrapped_func([3], new_dep, 9)] + async for _ in wrapped_func([3], new_dep, 9): + pass new_dep.assert_called_with(3, 9) - [item async for item in wrapped_func([4], new_dep, j=9)] + async for _ in wrapped_func([4], new_dep, j=9): + pass new_dep.assert_called_with(4, 9) - [item async for item in wrapped_func([5], dep=new_dep, j=9)] + async for _ in wrapped_func([5], dep=new_dep, j=9): + pass new_dep.assert_called_with(5, 9) app_provider.app_mock.assert_not_called()