Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/dishka/integrations/litestar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import ParamSpec, TypeVar, get_type_hints

from litestar import Controller, Litestar, Request, Router, WebSocket
from litestar.config.app import AppConfig
from litestar.enums import ScopeType
from litestar.handlers import (
BaseRouteHandler,
Expand All @@ -29,6 +30,12 @@
Send,
)

try:
from litestar.plugins import InitPlugin
HAS_PLUGINS = True
except ImportError:
HAS_PLUGINS = False

from dishka import AsyncContainer, FromDishka, Provider, from_context
from dishka import Scope as DIScope
from dishka.integrations.base import wrap_injection
Expand Down Expand Up @@ -165,3 +172,15 @@ def setup_dishka(container: AsyncContainer, app: Litestar) -> None:
app.asgi_handler,
)
app.state.dishka_container = container

if HAS_PLUGINS:
__all__ += ["DishkaPlugin"]

class DishkaPlugin(InitPlugin):
def __init__(self, container: AsyncContainer):
self.container = container

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.state.dishka_container = self.container
app_config.middleware.append(make_add_request_container_middleware)
return app_config
35 changes: 35 additions & 0 deletions tests/integrations/litestar/test_litestar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.util
from contextlib import asynccontextmanager
from unittest.mock import Mock

Expand All @@ -8,6 +9,12 @@
from litestar.contrib.htmx.request import HTMXRequest
from litestar.testing import TestClient

try:
importlib.util.find_spec("litestar.plugins.InitPlugin")
HAS_PLUGINS = True
except ImportError:
HAS_PLUGINS = False

from dishka import make_async_container
from dishka.integrations.litestar import (
DishkaRouter,
Expand Down Expand Up @@ -39,6 +46,22 @@ async def dishka_app(
yield TestClient(app)
await container.close()

@asynccontextmanager
async def dishka_plugin_app(
view,
provider,
request_class: type[Request] = Request,
) -> TestClient:
from dishka.integrations.litestar import DishkaPlugin

container = make_async_container(provider)
app = litestar.Litestar(request_class=request_class,
plugins=[DishkaPlugin(container)])
app.register(get("/")(inject(view)))
app.register(websocket_listener("/ws")(websocket_handler))
async with LifespanManager(app):
yield TestClient(app)
await container.close()

@asynccontextmanager
async def dishka_auto_app(
Expand Down Expand Up @@ -90,6 +113,10 @@ async def handler(
(HTMXRequest, dishka_app),
(Request, dishka_auto_app),
(HTMXRequest, dishka_auto_app),
*((
(Request, dishka_plugin_app),
(HTMXRequest, dishka_plugin_app),
) if HAS_PLUGINS else ()),
],
)
@pytest.mark.asyncio
Expand All @@ -116,6 +143,10 @@ async def test_app_dependency(
(HTMXRequest, dishka_app),
(Request, dishka_auto_app),
(HTMXRequest, dishka_auto_app),
*((
(Request, dishka_plugin_app),
(HTMXRequest, dishka_plugin_app),
) if HAS_PLUGINS else ()),
],
)
@pytest.mark.asyncio
Expand All @@ -141,6 +172,10 @@ async def test_request_dependency(
(HTMXRequest, dishka_app),
(Request, dishka_auto_app),
(HTMXRequest, dishka_auto_app),
*((
(Request, dishka_plugin_app),
(HTMXRequest, dishka_plugin_app),
) if HAS_PLUGINS else ()),
],
)
@pytest.mark.asyncio
Expand Down
46 changes: 46 additions & 0 deletions tests/integrations/litestar/test_litestar_websockets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.util
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from unittest.mock import Mock
Expand All @@ -8,6 +9,12 @@
from litestar.handlers import WebsocketListener
from litestar.testing import TestClient

try:
importlib.util.find_spec("litestar.plugins.InitPlugin")
HAS_PLUGINS = True
except ImportError:
HAS_PLUGINS = False

from dishka import make_async_container
from dishka.integrations.litestar import (
DishkaRouter,
Expand Down Expand Up @@ -48,6 +55,21 @@ async def dishka_auto_app(view, provider) -> AsyncGenerator[TestClient, None]:
await container.close()


@asynccontextmanager
async def dishka_plugin_app(
view, provider,
) -> AsyncGenerator[TestClient, None]:
from dishka.integrations.litestar import DishkaPlugin

router = DishkaRouter("", route_handlers=[])
router.register(view)
container = make_async_container(provider)
app = Litestar([router], debug=True, plugins=[DishkaPlugin(container)])
async with LifespanManager(app):
yield TestClient(app)
await container.close()


@websocket_listener("/")
@inject_websocket
async def get_with_app(
Expand Down Expand Up @@ -104,6 +126,12 @@ async def on_receive(
(dishka_auto_app, auto_get_with_app),
(dishka_app, GetWithApp),
(dishka_auto_app, AutoGetWithApp),
*((
(dishka_plugin_app, get_with_app),
(dishka_plugin_app, auto_get_with_app),
(dishka_plugin_app, GetWithApp),
(dishka_plugin_app, AutoGetWithApp),
) if HAS_PLUGINS else ()),
],
)
async def test_app_dependency(
Expand Down Expand Up @@ -177,6 +205,12 @@ async def on_receive(
(dishka_auto_app, auto_get_with_request),
(dishka_app, GetWithRequest),
(dishka_auto_app, AutoGetWithRequest),
*((
(dishka_plugin_app, get_with_request),
(dishka_plugin_app, auto_get_with_request),
(dishka_plugin_app, GetWithRequest),
(dishka_plugin_app, AutoGetWithRequest),
) if HAS_PLUGINS else ()),
],
)
async def test_request_dependency(
Expand All @@ -200,6 +234,12 @@ async def test_request_dependency(
(dishka_auto_app, auto_get_with_request),
(dishka_app, GetWithRequest),
(dishka_auto_app, AutoGetWithRequest),
*((
(dishka_plugin_app, get_with_request),
(dishka_plugin_app, auto_get_with_request),
(dishka_plugin_app, GetWithRequest),
(dishka_plugin_app, AutoGetWithRequest),
) if HAS_PLUGINS else ()),
],
)
async def test_request_dependency2(
Expand Down Expand Up @@ -279,6 +319,12 @@ async def on_receive(
(dishka_auto_app, auto_get_with_websocket),
(dishka_app, GetWithWebsocket),
(dishka_auto_app, AutoGetWithWebsocket),
*((
(dishka_plugin_app, get_with_websocket),
(dishka_plugin_app, auto_get_with_websocket),
(dishka_plugin_app, GetWithWebsocket),
(dishka_plugin_app, AutoGetWithWebsocket),
) if HAS_PLUGINS else ()),
],
)
async def test_websocket_dependency(
Expand Down
Loading