diff --git a/src/dishka/integrations/litestar.py b/src/dishka/integrations/litestar.py index 12d90ef84..8738ec4da 100644 --- a/src/dishka/integrations/litestar.py +++ b/src/dishka/integrations/litestar.py @@ -12,6 +12,7 @@ from typing import ParamSpec, TypeVar, get_type_hints from litestar import Controller, Litestar, Request, Router, WebSocket +from litestar.config.app import AppConfig from litestar.enums import ScopeType from litestar.handlers import ( BaseRouteHandler, @@ -29,6 +30,12 @@ Send, ) +try: + from litestar.plugins import InitPlugin + HAS_PLUGINS = True +except ImportError: + HAS_PLUGINS = False + from dishka import AsyncContainer, FromDishka, Provider, from_context from dishka import Scope as DIScope from dishka.integrations.base import wrap_injection @@ -165,3 +172,15 @@ def setup_dishka(container: AsyncContainer, app: Litestar) -> None: app.asgi_handler, ) app.state.dishka_container = container + +if HAS_PLUGINS: + __all__ += ["DishkaPlugin"] + + class DishkaPlugin(InitPlugin): + def __init__(self, container: AsyncContainer): + self.container = container + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + app_config.state.dishka_container = self.container + app_config.middleware.append(make_add_request_container_middleware) + return app_config diff --git a/tests/integrations/litestar/test_litestar.py b/tests/integrations/litestar/test_litestar.py index abc9a2816..13d5ec123 100644 --- a/tests/integrations/litestar/test_litestar.py +++ b/tests/integrations/litestar/test_litestar.py @@ -1,3 +1,4 @@ +import importlib.util from contextlib import asynccontextmanager from unittest.mock import Mock @@ -8,6 +9,12 @@ from litestar.contrib.htmx.request import HTMXRequest from litestar.testing import TestClient +try: + importlib.util.find_spec("litestar.plugins.InitPlugin") + HAS_PLUGINS = True +except ImportError: + HAS_PLUGINS = False + from dishka import make_async_container from dishka.integrations.litestar import ( DishkaRouter, @@ -39,6 +46,22 @@ async def dishka_app( yield TestClient(app) await container.close() +@asynccontextmanager +async def dishka_plugin_app( + view, + provider, + request_class: type[Request] = Request, +) -> TestClient: + from dishka.integrations.litestar import DishkaPlugin + + container = make_async_container(provider) + app = litestar.Litestar(request_class=request_class, + plugins=[DishkaPlugin(container)]) + app.register(get("/")(inject(view))) + app.register(websocket_listener("/ws")(websocket_handler)) + async with LifespanManager(app): + yield TestClient(app) + await container.close() @asynccontextmanager async def dishka_auto_app( @@ -90,6 +113,10 @@ async def handler( (HTMXRequest, dishka_app), (Request, dishka_auto_app), (HTMXRequest, dishka_auto_app), + *(( + (Request, dishka_plugin_app), + (HTMXRequest, dishka_plugin_app), + ) if HAS_PLUGINS else ()), ], ) @pytest.mark.asyncio @@ -116,6 +143,10 @@ async def test_app_dependency( (HTMXRequest, dishka_app), (Request, dishka_auto_app), (HTMXRequest, dishka_auto_app), + *(( + (Request, dishka_plugin_app), + (HTMXRequest, dishka_plugin_app), + ) if HAS_PLUGINS else ()), ], ) @pytest.mark.asyncio @@ -141,6 +172,10 @@ async def test_request_dependency( (HTMXRequest, dishka_app), (Request, dishka_auto_app), (HTMXRequest, dishka_auto_app), + *(( + (Request, dishka_plugin_app), + (HTMXRequest, dishka_plugin_app), + ) if HAS_PLUGINS else ()), ], ) @pytest.mark.asyncio diff --git a/tests/integrations/litestar/test_litestar_websockets.py b/tests/integrations/litestar/test_litestar_websockets.py index f7c398a72..acd0e9a46 100644 --- a/tests/integrations/litestar/test_litestar_websockets.py +++ b/tests/integrations/litestar/test_litestar_websockets.py @@ -1,3 +1,4 @@ +import importlib.util from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from unittest.mock import Mock @@ -8,6 +9,12 @@ from litestar.handlers import WebsocketListener from litestar.testing import TestClient +try: + importlib.util.find_spec("litestar.plugins.InitPlugin") + HAS_PLUGINS = True +except ImportError: + HAS_PLUGINS = False + from dishka import make_async_container from dishka.integrations.litestar import ( DishkaRouter, @@ -48,6 +55,21 @@ async def dishka_auto_app(view, provider) -> AsyncGenerator[TestClient, None]: await container.close() +@asynccontextmanager +async def dishka_plugin_app( + view, provider, +) -> AsyncGenerator[TestClient, None]: + from dishka.integrations.litestar import DishkaPlugin + + router = DishkaRouter("", route_handlers=[]) + router.register(view) + container = make_async_container(provider) + app = Litestar([router], debug=True, plugins=[DishkaPlugin(container)]) + async with LifespanManager(app): + yield TestClient(app) + await container.close() + + @websocket_listener("/") @inject_websocket async def get_with_app( @@ -104,6 +126,12 @@ async def on_receive( (dishka_auto_app, auto_get_with_app), (dishka_app, GetWithApp), (dishka_auto_app, AutoGetWithApp), + *(( + (dishka_plugin_app, get_with_app), + (dishka_plugin_app, auto_get_with_app), + (dishka_plugin_app, GetWithApp), + (dishka_plugin_app, AutoGetWithApp), + ) if HAS_PLUGINS else ()), ], ) async def test_app_dependency( @@ -177,6 +205,12 @@ async def on_receive( (dishka_auto_app, auto_get_with_request), (dishka_app, GetWithRequest), (dishka_auto_app, AutoGetWithRequest), + *(( + (dishka_plugin_app, get_with_request), + (dishka_plugin_app, auto_get_with_request), + (dishka_plugin_app, GetWithRequest), + (dishka_plugin_app, AutoGetWithRequest), + ) if HAS_PLUGINS else ()), ], ) async def test_request_dependency( @@ -200,6 +234,12 @@ async def test_request_dependency( (dishka_auto_app, auto_get_with_request), (dishka_app, GetWithRequest), (dishka_auto_app, AutoGetWithRequest), + *(( + (dishka_plugin_app, get_with_request), + (dishka_plugin_app, auto_get_with_request), + (dishka_plugin_app, GetWithRequest), + (dishka_plugin_app, AutoGetWithRequest), + ) if HAS_PLUGINS else ()), ], ) async def test_request_dependency2( @@ -279,6 +319,12 @@ async def on_receive( (dishka_auto_app, auto_get_with_websocket), (dishka_app, GetWithWebsocket), (dishka_auto_app, AutoGetWithWebsocket), + *(( + (dishka_plugin_app, get_with_websocket), + (dishka_plugin_app, auto_get_with_websocket), + (dishka_plugin_app, GetWithWebsocket), + (dishka_plugin_app, AutoGetWithWebsocket), + ) if HAS_PLUGINS else ()), ], ) async def test_websocket_dependency(