Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/dishka/integrations/litestar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
__all__ = [
"DishkaMiddleware",
"DishkaPlugin",
"FromDishka",
"LitestarProvider",
"inject",
Expand All @@ -11,7 +13,10 @@
from typing import ParamSpec, TypeVar, get_type_hints

from litestar import Litestar, Request, WebSocket
from litestar.config.app import AppConfig
from litestar.enums import ScopeType
from litestar.middleware import ASGIMiddleware
from litestar.plugins import InitPlugin
from litestar.types import ASGIApp, Receive, Scope, Send

from dishka import AsyncContainer, FromDishka, Provider, from_context
Expand Down Expand Up @@ -94,3 +99,46 @@ def setup_dishka(container: AsyncContainer, app: Litestar) -> None:
app.asgi_handler,
)
app.state.dishka_container = container


class DishkaMiddleware(ASGIMiddleware):
scopes = (ScopeType.HTTP, ScopeType.WEBSOCKET)
async def handle(
self, scope: Scope, receive: Receive,
send: Send, next_app: ASGIApp,
) -> None:
if scope.get("type") not in (ScopeType.HTTP, ScopeType.WEBSOCKET):
await next_app(scope, receive, send)
return

if scope.get("type") == ScopeType.HTTP:
request: Request = Request(scope)
context = {Request: request}
di_scope = DIScope.REQUEST
async with request.app.state.dishka_container(
context,
scope=di_scope,
) as request_container:
request.state.dishka_container = request_container
await next_app(scope, receive, send)

elif scope.get("type") == ScopeType.WEBSOCKET:
websocket: WebSocket = WebSocket(scope)
context = {WebSocket: websocket}
di_scope = DIScope.SESSION
async with websocket.app.state.dishka_container(
context,
scope=di_scope,
) as request_container:
websocket.state.dishka_container = request_container
await next_app(scope, receive, send)


class DishkaPlugin(InitPlugin):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add plugin automatically inside setup_dishka?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think so,

the current api is

    app = litestar.Litestar()
    container = make_async_container(provider)
    setup_dishka(container, app)

and this PR just adds the possibility to do:

    container = make_async_container(provider)
    app = litestar.Litestar(plugins=[DishkaPlugin(container=container)])

I dont know a litestar's built-in way to add a plugin other than using the plugins kwarg unfortunately.

def __init__(self, container: AsyncContainer):
self.container = container

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.state.dishka_container = self.container
app_config.middleware.append(DishkaMiddleware())
return app_config
71 changes: 51 additions & 20 deletions tests/integrations/litestar/test_litestar.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from unittest.mock import Mock

import litestar
import pytest
from asgi_lifespan import LifespanManager
from litestar import Request, get, websocket_listener
from litestar.contrib.htmx.request import HTMXRequest
from litestar.testing import TestClient
from litestar.testing import AsyncTestClient

from dishka import make_async_container
from dishka.integrations.litestar import (
DishkaPlugin,
FromDishka,
inject,
setup_dishka,
Expand All @@ -24,21 +25,37 @@


@asynccontextmanager
async def dishka_app(
async def dishka_app_via_setup(
view,
provider,
request_class: type[Request] = Request,
) -> TestClient:
app = litestar.Litestar(request_class=request_class)
) -> AsyncGenerator[AsyncTestClient]:
app = litestar.Litestar(request_class=request_class, debug=True)
app.register(get("/")(inject(view)))
app.register(websocket_listener("/ws")(websocket_handler))
container = make_async_container(provider)
setup_dishka(container, app)
async with LifespanManager(app):
yield litestar.testing.TestClient(app)
async with AsyncTestClient(app) as client:
yield client
await container.close()

@asynccontextmanager
async def dishka_app_via_plugin(
view,
provider,
request_class: type[Request] = Request,
) -> AsyncGenerator[AsyncTestClient]:
container = make_async_container(provider)
app = litestar.Litestar(request_class=request_class,
plugins=[DishkaPlugin(container=container)])
app.register(get("/")(inject(view)))
app.register(websocket_listener("/ws")(websocket_handler))
async with AsyncTestClient(app) as client:
yield client
await container.close()



async def websocket_handler(data: str):
pass

Expand All @@ -65,53 +82,67 @@ async def handler(
return handler



@pytest.mark.parametrize("client_setup", [dishka_app_via_setup,
dishka_app_via_plugin])
@pytest.mark.parametrize("request_class", [Request, HTMXRequest])
@pytest.mark.asyncio
async def test_app_dependency(request_class, app_provider: AppProvider):
async with dishka_app(get_with_app(request_class), app_provider) as client:
client.get("/")
async def test_app_dependency(request_class, app_provider: AppProvider,
client_setup: AsyncTestClient) -> None:
async with client_setup(get_with_app(request_class),
app_provider) as client:
await client.get("/")
app_provider.mock.assert_called_with(APP_DEP_VALUE)
app_provider.app_released.assert_not_called()
app_provider.app_released.assert_called()


@pytest.mark.parametrize("client_setup", [dishka_app_via_setup,
dishka_app_via_plugin])
@pytest.mark.parametrize("request_class", [Request, HTMXRequest])
@pytest.mark.asyncio
async def test_request_dependency(request_class, app_provider: AppProvider):
async with dishka_app(
async def test_request_dependency(request_class, app_provider: AppProvider,
client_setup: AsyncTestClient) -> None:
async with client_setup(
get_with_request(request_class),
app_provider,
request_class,
) as client:
client.get("/")
await client.get("/")
app_provider.mock.assert_called_with(REQUEST_DEP_VALUE)
app_provider.request_released.assert_called_once()


@pytest.mark.parametrize("client_setup", [dishka_app_via_setup,
dishka_app_via_plugin])
@pytest.mark.parametrize("request_class", [Request, HTMXRequest])
@pytest.mark.asyncio
async def test_request_dependency2(request_class, app_provider: AppProvider):
async with dishka_app(
async def test_request_dependency2(request_class, app_provider: AppProvider,
client_setup: AsyncTestClient) -> None:
async with client_setup(
get_with_request(request_class),
app_provider,
request_class,
) as client:
client.get("/")
await client.get("/")
app_provider.mock.assert_called_with(REQUEST_DEP_VALUE)
app_provider.mock.reset_mock()
app_provider.request_released.assert_called_once()
app_provider.request_released.reset_mock()
client.get("/")
await client.get("/")
app_provider.mock.assert_called_with(REQUEST_DEP_VALUE)
app_provider.request_released.assert_called_once()


@pytest.mark.parametrize("client_setup", [dishka_app_via_setup,
dishka_app_via_plugin])
@pytest.mark.asyncio
async def test_request_middleware(app_provider: AppProvider):
async with dishka_app(
async def test_request_middleware(app_provider: AppProvider,
client_setup: AsyncTestClient) -> None:
async with client_setup(
get_with_request(Request),
app_provider,
Request,
) as client:
with client.websocket_connect("/ws") as websocket:
with await client.websocket_connect("/ws") as websocket:
websocket.send("test")
27 changes: 17 additions & 10 deletions tests/integrations/litestar/test_litestar_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from unittest.mock import Mock

import pytest
from asgi_lifespan import LifespanManager
from litestar import Litestar, websocket_listener
from litestar.handlers import WebsocketListener
from litestar.testing import TestClient
Expand All @@ -12,7 +11,7 @@
from dishka.integrations.litestar import (
FromDishka,
inject_websocket,
setup_dishka,
setup_dishka, DishkaPlugin,
)
from ..common import (
APP_DEP_VALUE,
Expand All @@ -26,14 +25,21 @@


@asynccontextmanager
async def dishka_app(view, provider) -> AsyncGenerator[TestClient, None]:
async def dishka_app_via_setup(view, provider) -> AsyncGenerator[TestClient, None]:
app = Litestar([view], debug=True)
container = make_async_container(provider)
setup_dishka(container, app)
async with LifespanManager(app):
yield TestClient(app)
with TestClient(app) as client:
yield client
Comment thread
euri10 marked this conversation as resolved.
Outdated
await container.close()

@asynccontextmanager
async def dishka_app_via_plugin(view, provider) -> AsyncGenerator[TestClient, None]:
container = make_async_container(provider)
app = Litestar([view], debug=True, plugins=[DishkaPlugin(container=container)])
with TestClient(app) as client:
yield client
await container.close()

@websocket_listener("/")
@inject_websocket
Expand Down Expand Up @@ -61,9 +67,10 @@ async def on_receive(


@pytest.mark.asyncio
@pytest.mark.parametrize("setup_client", [dishka_app_via_setup, dishka_app_via_plugin])
@pytest.mark.parametrize("view", [get_with_app, GetWithApp])
async def test_app_dependency(view, ws_app_provider: WebSocketAppProvider):
async with dishka_app(view, ws_app_provider) as client:
async def test_app_dependency(view, ws_app_provider: WebSocketAppProvider, setup_client: TestClient) -> None:
async with setup_client(view, ws_app_provider) as client:
with client.websocket_connect("/") as connection:
connection.send_text("...")
assert connection.receive_text() == "passed"
Expand Down Expand Up @@ -104,7 +111,7 @@ async def test_request_dependency(
view,
ws_app_provider: WebSocketAppProvider,
):
async with dishka_app(view, ws_app_provider) as client:
async with dishka_app_via_setup(view, ws_app_provider) as client:
with client.websocket_connect("/") as connection:
connection.send_text("...")
assert connection.receive_text() == "passed"
Expand All @@ -118,7 +125,7 @@ async def test_request_dependency2(
view,
ws_app_provider: WebSocketAppProvider,
):
async with dishka_app(view, ws_app_provider) as client:
async with dishka_app_via_setup(view, ws_app_provider) as client:
with client.websocket_connect("/") as connection:
connection.send_text("...")
assert connection.receive_text() == "passed"
Expand Down Expand Up @@ -165,7 +172,7 @@ async def test_websocket_dependency(
view,
ws_app_provider: WebSocketAppProvider,
):
async with dishka_app(view, ws_app_provider) as client:
async with dishka_app_via_setup(view, ws_app_provider) as client:
with client.websocket_connect("/") as connection:
connection.send_text("...")
assert connection.receive_text() == "passed"
Expand Down