diff --git a/marimo/_code_mode/_context.py b/marimo/_code_mode/_context.py index 767f182e121..e2d6c31d1c6 100644 --- a/marimo/_code_mode/_context.py +++ b/marimo/_code_mode/_context.py @@ -54,6 +54,10 @@ _UpdateOp, _validate_ops, ) +from marimo._code_mode.screenshot_meta import ( + SCREENSHOT_AUTH_TOKEN_KEY, + SCREENSHOT_SERVER_URL_KEY, +) from marimo._messaging.errors import Error from marimo._messaging.notebook.changes import ( CreateCell, @@ -1272,19 +1276,20 @@ async def screenshot( "available. screenshot() must be called during cell " "execution (e.g. from code-mode)." ) - # Read trusted server URL and auth token injected by the - # /execute endpoint (from server config, not request headers). + + # Trusted server URL + auth token injected by the /execute + # endpoint (from server config, not request headers). server_url = cast( - "str | None", request.meta.get("screenshot_server_url") + "str | None", request.meta.get(SCREENSHOT_SERVER_URL_KEY) ) if server_url is None: raise ScreenshotError( - "Cannot take screenshots: screenshot_server_url not " + "Cannot take screenshots: screenshot credentials not " "found in request.meta. This endpoint may not " "support screenshots." ) screenshot_auth_token = cast( - "str | None", request.meta.get("screenshot_auth_token") + "str | None", request.meta.get(SCREENSHOT_AUTH_TOKEN_KEY) ) # Lazy-init the screenshot session (browser reuse). diff --git a/marimo/_code_mode/screenshot_meta.py b/marimo/_code_mode/screenshot_meta.py new file mode 100644 index 00000000000..3ad996565d4 --- /dev/null +++ b/marimo/_code_mode/screenshot_meta.py @@ -0,0 +1,14 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Meta keys for wiring screenshot credentials through ``HTTPRequest.meta``. + +Code-mode tools running in the kernel need to call back into the marimo +server (e.g. ``ctx.screenshot()`` driving Playwright against the +kiosk page). The server stamps a trusted ``server_url`` and +``auth_token`` onto each control request's ``meta`` dict; the runtime +side reads them when building the screenshot session. +""" + +from __future__ import annotations + +SCREENSHOT_SERVER_URL_KEY = "screenshot_server_url" +SCREENSHOT_AUTH_TOKEN_KEY = "screenshot_auth_token" diff --git a/marimo/_server/ai/tools/code_mode.py b/marimo/_server/ai/tools/code_mode.py new file mode 100644 index 00000000000..667d59162f2 --- /dev/null +++ b/marimo/_server/ai/tools/code_mode.py @@ -0,0 +1,55 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from marimo._ai._tools.types import CodeExecutionResult +from marimo._server.api.deps import AppState +from marimo._server.api.utils import get_code_mode_credentials +from marimo._server.scratchpad import run_scratchpad_code + +if TYPE_CHECKING: + from pydantic_ai import FunctionToolset + from starlette.requests import Request + + from marimo._session.session import Session + + +def build_execute_code_toolset( + session: Session, + request: Request, +) -> FunctionToolset[None]: + """Build a ``FunctionToolset`` exposing one tool: ``execute_code``. + + The tool is bound to the caller's *session* and *request*; the model + never sees or passes a session id. Screenshot credentials are derived + per tool call from the request so ``ctx.screenshot()`` can call back + into this server (see ``marimo/_code_mode/_context.py``). + """ + + from pydantic_ai import FunctionToolset + + toolset: FunctionToolset[None] = FunctionToolset() + + async def execute_code(code: str) -> CodeExecutionResult: + """Run Python inside the running notebook's kernel scratchpad. + + Use this for all notebook mutations via ``marimo._code_mode``. + """ + server_url, auth_token = get_code_mode_credentials( + AppState(request), request + ) + return await run_scratchpad_code( + session, + request, + code=code, + server_url=server_url, + auth_token=auth_token, + ) + + toolset.add_function( + execute_code, + name="execute_code", + description=execute_code.__doc__, + ) + return toolset diff --git a/marimo/_server/api/endpoints/execution.py b/marimo/_server/api/endpoints/execution.py index e4aa718fbc1..1ded712fb47 100644 --- a/marimo/_server/api/endpoints/execution.py +++ b/marimo/_server/api/endpoints/execution.py @@ -9,6 +9,10 @@ from starlette.responses import JSONResponse, StreamingResponse from marimo import _loggers +from marimo._code_mode.screenshot_meta import ( + SCREENSHOT_AUTH_TOKEN_KEY, + SCREENSHOT_SERVER_URL_KEY, +) from marimo._messaging.notification import AlertNotification from marimo._runtime.commands import HTTPRequest, UpdateUIElementCommand from marimo._server.api.deps import AppState @@ -16,7 +20,11 @@ FILE_QUERY_PARAM_KEY, ) from marimo._server.api.endpoints.ws_endpoint import DOC_MANAGER -from marimo._server.api.utils import dispatch_control_request, parse_request +from marimo._server.api.utils import ( + dispatch_control_request, + get_code_mode_credentials, + parse_request, +) from marimo._server.models.models import ( BaseResponse, DebugCellRequest, @@ -312,19 +320,11 @@ async def sse_generator() -> AsyncGenerator[str, None]: with session.scoped(listener): async with session.scratchpad_lock: http_req = HTTPRequest.from_request(request) - # Inject trusted server URL and auth token for - # code-mode screenshot support. We use the - # server's own host/port (from config) rather - # than the request's Host header to prevent - # header-spoofing attacks. - http_req.meta["screenshot_auth_token"] = str( - app_state.session_manager.auth_token - ) - base_url = app_state.base_url.rstrip("/") - scheme = request.url.scheme or "http" - http_req.meta["screenshot_server_url"] = ( - f"{scheme}://{app_state.host}:{app_state.port}{base_url}" + server_url, auth_token = get_code_mode_credentials( + app_state, request ) + http_req.meta[SCREENSHOT_SERVER_URL_KEY] = server_url + http_req.meta[SCREENSHOT_AUTH_TOKEN_KEY] = auth_token session.put_control_request( ExecuteScratchpadCommand( code=body.code, diff --git a/marimo/_server/api/lifespans.py b/marimo/_server/api/lifespans.py index e767ac11384..b3316bb9381 100644 --- a/marimo/_server/api/lifespans.py +++ b/marimo/_server/api/lifespans.py @@ -3,8 +3,6 @@ import asyncio import contextlib -import ipaddress -import socket from typing import TYPE_CHECKING, Any from marimo import _loggers @@ -12,7 +10,10 @@ from marimo._server.ai.tools.tool_manager import setup_tool_manager from marimo._server.api.deps import AppState, AppStateBase from marimo._server.api.interrupt import InterruptHandler -from marimo._server.api.utils import open_url_in_browser +from marimo._server.api.utils import ( + format_url_host, + open_url_in_browser, +) from marimo._server.lsp import any_lsp_server_running from marimo._server.print import ( print_experimental_features, @@ -271,48 +272,10 @@ async def reap_subprocesses(app: Starlette) -> AsyncIterator[None]: await cancel_pending_reaps() -def _pretty_host(host: str, port: int) -> str: - """Replace loopback addresses with 'localhost' for display. - - Uses ipaddress for a reliable cross-platform loopback check (covers - 127.0.0.1, ::1, and the full 127.0.0.0/8 range). Falls back to - socket.getnameinfo only for non-IP hosts. getnameinfo is skipped for - raw IP addresses because it can hang on Windows/CI for link-local IPv6. - """ - try: - if ipaddress.ip_address(host).is_loopback: - return "localhost" - except ValueError: - # Not a valid IP literal — might be a hostname; try getnameinfo - try: - if ( - socket.getnameinfo((host, port), socket.NI_NOFQDN)[0] - == "localhost" - ): - return "localhost" - except Exception: - pass - return host - - def _startup_url(state: AppStateBase) -> str: - host = state.host.strip( - "[]" - ) # normalize: remove brackets if user passed [addr] + url_host = format_url_host(state.host, state.port) port = state.port - # Strip IPv6 zone ID (e.g. fe80::1%eth0 -> fe80::1); zone IDs are - # interface-specific and not valid in URLs. - # Must happen before _pretty_host — zone IDs can cause getnameinfo - # to hang on Windows/CI. - host = host.split("%")[0] - - # pretty printing: show "localhost" for loopback addresses - host = _pretty_host(host, port) - - url_host_bare = host - # IPv6 addresses must be wrapped in brackets in URLs (RFC 3986) - url_host = f"[{url_host_bare}]" if ":" in url_host_bare else url_host_bare url = f"http://{url_host}:{port}{state.base_url}" if port == 80: url = f"http://{url_host}{state.base_url}" @@ -325,18 +288,10 @@ def _startup_url(state: AppStateBase) -> str: def _mcp_startup_url(state: AppStateBase) -> str: - host = state.host.strip( - "[]" - ) # normalize: remove brackets if user passed [addr] + url_host = format_url_host(state.host, state.port) port = state.port base_url = state.base_url - # Strip zone ID, then pretty-print loopback (same logic as _startup_url) - host = host.split("%")[0] - host = _pretty_host(host, port) - - url_host_bare = host - url_host = f"[{url_host_bare}]" if ":" in url_host_bare else url_host_bare # Construct MCP endpoint URL mcp_prefix = "/mcp" mcp_name = "server" diff --git a/marimo/_server/api/utils.py b/marimo/_server/api/utils.py index 08a10fdda9a..91d468e3958 100644 --- a/marimo/_server/api/utils.py +++ b/marimo/_server/api/utils.py @@ -31,6 +31,7 @@ from starlette.datastructures import UploadFile from starlette.requests import Request + from marimo._server.api.deps import AppStateBase from marimo._session.session import Session @@ -113,6 +114,89 @@ async def dispatch_control_request( return SuccessResponse() +def pretty_host(host: str, port: int) -> str: + """Replace loopback addresses with 'localhost' for display. + + Uses ipaddress for a reliable cross-platform loopback check (covers + 127.0.0.1, ::1, and the full 127.0.0.0/8 range). Falls back to + socket.getnameinfo only for non-IP hosts. getnameinfo is skipped + for raw IP addresses because it can hang on Windows/CI for + link-local IPv6. + """ + import ipaddress + import socket + + try: + if ipaddress.ip_address(host).is_loopback: + return "localhost" + except ValueError: + # Not a valid IP literal — might be a hostname; try getnameinfo + try: + if ( + socket.getnameinfo((host, port), socket.NI_NOFQDN)[0] + == "localhost" + ): + return "localhost" + except Exception: + pass + return host + + +def format_url_host( + host: str, + port: int, + *, + route_bind_all_to_loopback: bool = False, +) -> str: + """Normalize ``host`` for use in a URL authority. + + Transforms (in order): + + 1. Strip user-supplied brackets (``[::1]`` -> ``::1``). + 2. Drop IPv6 zone IDs (``fe80::1%eth0`` -> ``fe80::1``); zone IDs + are interface-specific and not valid in URLs. (Must happen + before :func:`pretty_host` — zone IDs can cause getnameinfo + to hang on Windows/CI.) + 3. Replace loopback addresses with ``"localhost"``. + 4. If ``route_bind_all_to_loopback``, replace the bind-all + sentinels ``0.0.0.0`` / ``::`` with ``"localhost"`` so internal + callback URLs reach a routable target (the bind-all addresses + are not valid destinations for an HTTP client / Playwright). + Display URLs leave them alone — the user knows what they + configured. + 5. Wrap bare IPv6 literals in brackets per RFC 3986. + """ + host = host.strip("[]").split("%")[0] + host = pretty_host(host, port) + if route_bind_all_to_loopback and host in {"0.0.0.0", "::"}: + host = "localhost" + return f"[{host}]" if ":" in host else host + + +def get_code_mode_credentials( + app_state: AppStateBase, request: Request +) -> tuple[str, str]: + """Return ``(server_url, auth_token)`` for code-mode tools that + call back into this marimo server (e.g. ``ctx.screenshot()`` + driving Playwright against the running notebook UI). + + The URL is built from the server's configured ``host``/``port`` + rather than the request's ``Host`` header so a spoofed header + can't redirect the auth token to an attacker-controlled URL. + ``host`` is also normalized for URL safety (IPv6 bracketing, zone-ID + stripping) and for routability (bind-all addresses mapped to + ``localhost``) so the callback actually reaches the server. + """ + auth_token = str(app_state.session_manager.auth_token) + base_url = app_state.base_url.rstrip("/") + scheme = request.url.scheme or "http" + url_host = format_url_host( + app_state.host, app_state.port, route_bind_all_to_loopback=True + ) + server_url = f"{scheme}://{url_host}:{app_state.port}{base_url}" + return server_url, auth_token + + def parse_title(filepath: str | None) -> str: """ Create a title from a filename. diff --git a/marimo/_server/scratchpad.py b/marimo/_server/scratchpad.py index e5ddb989937..de7085d2480 100644 --- a/marimo/_server/scratchpad.py +++ b/marimo/_server/scratchpad.py @@ -6,20 +6,29 @@ import asyncio import json from typing import TYPE_CHECKING, Any, TypedDict +from uuid import uuid4 from marimo._ai._tools.types import CodeExecutionResult +from marimo._code_mode.screenshot_meta import ( + SCREENSHOT_AUTH_TOKEN_KEY, + SCREENSHOT_SERVER_URL_KEY, +) from marimo._messaging.cell_output import CellChannel from marimo._messaging.notification import ( CellNotification, CompletedRunNotification, ) from marimo._messaging.serde import deserialize_kernel_message +from marimo._runtime.commands import ExecuteScratchpadCommand, HTTPRequest from marimo._runtime.scratch import SCRATCH_CELL_ID +from marimo._server.models.models import InstantiateNotebookRequest from marimo._session.extensions.types import EventAwareExtension if TYPE_CHECKING: from collections.abc import AsyncGenerator + from starlette.requests import Request + from marimo._messaging.types import KernelMessage from marimo._session.session import Session @@ -286,3 +295,57 @@ def extract_result( stderr=stderr, errors=errors, ) + + +async def run_scratchpad_code( + session: Session, + request: Request, + *, + code: str, + server_url: str, + auth_token: str, + timeout: float = EXECUTION_TIMEOUT, +) -> CodeExecutionResult: + """Drive the kernel scratchpad on behalf of code-mode (AI tool). + + ``server_url`` and ``auth_token`` are stamped onto ``http_req.meta`` + so ``ctx.screenshot()`` from inside code-mode can authenticate + Playwright against this server (see ``marimo/_code_mode/_context.py``). + """ + http_req = HTTPRequest.from_request(request) + http_req.meta[SCREENSHOT_SERVER_URL_KEY] = server_url + http_req.meta[SCREENSHOT_AUTH_TOKEN_KEY] = auth_token + + session.instantiate( + InstantiateNotebookRequest(object_ids=[], values=[], auto_run=False), + http_request=http_req, + ) + + run_id = str(uuid4()) + listener = ScratchCellListener(run_id=run_id) + + with session.scoped(listener): + async with session.scratchpad_lock: + session.put_control_request( + ExecuteScratchpadCommand( + code=code, + request=http_req, + notebook_cells=tuple(session.document.cells), + run_id=run_id, + ), + from_consumer_id=None, + ) + settled = False + try: + await listener.wait(timeout=timeout) + settled = not listener.timed_out + finally: + if not settled: + session.try_interrupt() + if listener.timed_out: + return CodeExecutionResult( + success=False, + errors=[f"Execution timed out after {timeout}s"], + ) + + return extract_result(session, listener) diff --git a/tests/_server/api/endpoints/test_execution.py b/tests/_server/api/endpoints/test_execution.py index 8d5e28be7aa..6e7b16f2947 100644 --- a/tests/_server/api/endpoints/test_execution.py +++ b/tests/_server/api/endpoints/test_execution.py @@ -8,6 +8,10 @@ import pytest +from marimo._code_mode.screenshot_meta import ( + SCREENSHOT_AUTH_TOKEN_KEY, + SCREENSHOT_SERVER_URL_KEY, +) from marimo._types.ids import CellId_t, SessionId from marimo._utils.lists import first from tests._server.mocks import ( @@ -206,10 +210,13 @@ async def empty_stream(self: object): # noqa: ARG001 assert len(scratchpad_cmds) == 1, ( f"expected one ExecuteScratchpadCommand, got {captured!r}" ) - meta = scratchpad_cmds[0].request.meta - assert meta["screenshot_auth_token"] == "fake-token" + http_req = scratchpad_cmds[0].request + assert http_req is not None # Mock server uses host="localhost", port=1234, base_url="" - assert meta["screenshot_server_url"] == "http://localhost:1234" + assert http_req.meta[SCREENSHOT_SERVER_URL_KEY] == ( + "http://localhost:1234" + ) + assert http_req.meta[SCREENSHOT_AUTH_TOKEN_KEY] == "fake-token" @staticmethod @with_session(SESSION_ID) diff --git a/tests/_server/api/test_api_utils.py b/tests/_server/api/test_api_utils.py index 56e009e543d..58ff15d6518 100644 --- a/tests/_server/api/test_api_utils.py +++ b/tests/_server/api/test_api_utils.py @@ -1,19 +1,26 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING +from types import SimpleNamespace +from typing import TYPE_CHECKING, cast import msgspec import pytest from starlette.applications import Starlette +from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route from starlette.testclient import TestClient -from marimo._server.api.utils import parse_multipart_request +from marimo._server.api.utils import ( + get_code_mode_credentials, + parse_multipart_request, +) if TYPE_CHECKING: - from starlette.requests import Request + from starlette.types import Scope + + from marimo._server.api.deps import AppStateBase class _SampleForm(msgspec.Struct): @@ -70,3 +77,128 @@ def test_parse_multipart_request_raises_on_missing_field() -> None: client = _build_app(captured) with pytest.raises(msgspec.ValidationError): client.post("/test", data={"name": "marimo"}) + + +def _fake_app_state( + *, + host: str = "localhost", + port: int = 2718, + base_url: str = "", + auth_token: str = "tok", +) -> AppStateBase: + """Minimal duck-typed stand-in for AppStateBase exposing only the + attributes that get_code_mode_credentials reads.""" + state = SimpleNamespace( + host=host, + port=port, + base_url=base_url, + session_manager=SimpleNamespace(auth_token=auth_token), + ) + return cast("AppStateBase", cast(object, state)) + + +def _fake_request( + *, + scheme: str = "http", + host_header: str = "evil.example.com:80", +) -> Request: + """Build a Starlette Request with a chosen scheme and Host header + without spinning up an app.""" + scope: Scope = { + "type": "http", + "method": "POST", + "scheme": scheme, + "server": ("localhost", 2718), + "path": "/", + "raw_path": b"/", + "query_string": b"", + "headers": [(b"host", host_header.encode())], + } + return Request(scope) + + +def test_get_code_mode_credentials_uses_configured_host_not_request_host_header() -> ( + None +): + """Regression guard for the security property documented on + get_code_mode_credentials: the server URL must come from the server's + configured host/port, not from the (spoofable) Host header. + Loopback addresses are pretty-printed to ``localhost``.""" + url, token = get_code_mode_credentials( + _fake_app_state(host="127.0.0.1", port=2718, auth_token="secret"), + _fake_request(host_header="evil.example.com:80"), + ) + assert url == "http://localhost:2718" + assert token == "secret" + + +def test_get_code_mode_credentials_strips_trailing_slash_from_base_url() -> ( + None +): + url, _ = get_code_mode_credentials( + _fake_app_state(base_url="/notebook/"), + _fake_request(), + ) + assert url == "http://localhost:2718/notebook" + + +def test_get_code_mode_credentials_includes_non_empty_base_url() -> None: + url, _ = get_code_mode_credentials( + _fake_app_state(base_url="/notebook"), + _fake_request(), + ) + assert url == "http://localhost:2718/notebook" + + +def test_get_code_mode_credentials_empty_base_url() -> None: + url, _ = get_code_mode_credentials( + _fake_app_state(base_url=""), + _fake_request(), + ) + assert url == "http://localhost:2718" + + +def test_get_code_mode_credentials_propagates_https_scheme() -> None: + url, _ = get_code_mode_credentials( + _fake_app_state(), + _fake_request(scheme="https"), + ) + assert url == "https://localhost:2718" + + +@pytest.mark.parametrize( + ("host", "expected_url"), + [ + # Loopback addresses get pretty-printed to "localhost". + ("127.0.0.1", "http://localhost:2718"), + ("::1", "http://localhost:2718"), + # Bracket-wrapped loopback (user passed [::1]) — strip + map. + ("[::1]", "http://localhost:2718"), + # Bind-all sentinels are NOT routable for an internal callback; + # they must be mapped to localhost so Playwright can connect. + ("0.0.0.0", "http://localhost:2718"), + ("::", "http://localhost:2718"), + # Real IPv6 literals must be bracketed per RFC 3986. + ("fd00::cafe", "http://[fd00::cafe]:2718"), + ("2001:db8::1", "http://[2001:db8::1]:2718"), + # User-supplied brackets on a real address — strip and re-add. + ("[fd00::cafe]", "http://[fd00::cafe]:2718"), + # Zone IDs are interface-specific and not valid in URLs. + ("fe80::1%eth0", "http://[fe80::1]:2718"), + ("[fe80::1%lo0]", "http://[fe80::1]:2718"), + # Regular hostnames pass through unchanged. + ("example.com", "http://example.com:2718"), + ], +) +def test_get_code_mode_credentials_normalizes_host( + host: str, expected_url: str +) -> None: + """The server URL must be a routable, URL-safe destination — bind-all + addresses mapped to loopback, IPv6 bracketed, zone IDs stripped — + so code-mode callbacks (e.g. Playwright screenshots) actually reach + the server.""" + url, _ = get_code_mode_credentials( + _fake_app_state(host=host), + _fake_request(), + ) + assert url == expected_url diff --git a/tests/_server/test_scratchpad.py b/tests/_server/test_scratchpad.py index 4e1647b319e..ee36b31128a 100644 --- a/tests/_server/test_scratchpad.py +++ b/tests/_server/test_scratchpad.py @@ -1,28 +1,52 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations +import asyncio import json +from contextlib import contextmanager +from types import SimpleNamespace +from typing import TYPE_CHECKING, cast from unittest.mock import MagicMock import pytest from inline_snapshot import snapshot from marimo._ai._tools.types import CodeExecutionResult +from marimo._code_mode.screenshot_meta import ( + SCREENSHOT_AUTH_TOKEN_KEY, + SCREENSHOT_SERVER_URL_KEY, +) from marimo._messaging.cell_output import CellChannel, CellOutput from marimo._messaging.errors import MarimoExceptionRaisedError from marimo._messaging.notification import ( CellNotification, CompletedRunNotification, + NotificationMessage, +) +from marimo._messaging.serde import serialize_kernel_message +from marimo._runtime.commands import ( + CommandMessage, + ExecuteScratchpadCommand, + HTTPRequest, ) from marimo._runtime.scratch import SCRATCH_CELL_ID +from marimo._server.models.models import InstantiateNotebookRequest from marimo._server.scratchpad import ( ScratchCellListener, _format_console, _format_sse, build_done_event, extract_result, + run_scratchpad_code, ) +if TYPE_CHECKING: + from collections.abc import Iterator + + from starlette.types import Scope + + from marimo._session.session import Session + _TEST_RUN_ID = "test-run-id" @@ -52,6 +76,122 @@ def _parse_sse(sse: str) -> tuple[str, dict[str, object]]: return event, json.loads(data) +def _build_request( + *, + scheme: str = "http", + host: str = "localhost", + port: int = 1234, +): + """Build a real Starlette Request so ``HTTPRequest.from_request`` + can read it without an exhaustive MagicMock setup.""" + from starlette.requests import Request + + scope: Scope = { + "type": "http", + "method": "POST", + "scheme": scheme, + "server": (host, port), + "path": "/", + "raw_path": b"/", + "query_string": b"", + "headers": [(b"host", f"{host}:{port}".encode())], + } + return Request(scope) + + +class _FakeSession: + """Minimal duck-typed Session for ``run_scratchpad_code`` tests. + + Behaves like a real Session in the ways the runner cares about: + + * ``scoped`` registers a listener for the duration of the context. + * ``put_control_request`` routes a ``CompletedRunNotification`` to + the active listener so ``listener.wait()`` returns naturally — + no need to monkey-patch listener internals. + + Set ``auto_complete=False`` to drive the timeout path (the listener + will never see a completion event). + """ + + document: SimpleNamespace + session_view: SimpleNamespace + scratchpad_lock: asyncio.Lock + control_requests: list[CommandMessage] + instantiate_calls: list[ + tuple[InstantiateNotebookRequest, HTTPRequest | None] + ] + interrupt_count: int + _auto_complete: bool + _active_listener: ScratchCellListener | None + _pre_complete_notifs: list[NotificationMessage] + + def __init__(self, *, auto_complete: bool = True) -> None: + self.document = SimpleNamespace(cells=()) + self.session_view = SimpleNamespace(cell_notifications={}) + self.scratchpad_lock = asyncio.Lock() + self.control_requests = [] + self.instantiate_calls = [] + self.interrupt_count = 0 + self._auto_complete = auto_complete + self._active_listener = None + self._pre_complete_notifs = [] + + def as_session(self) -> Session: + """Type-only cast to ``Session`` for the runner's signature.""" + return cast("Session", cast(object, self)) + + def emit(self, notification: NotificationMessage) -> None: + """Schedule a notification to be delivered to the active + listener just before the auto-generated completion event.""" + self._pre_complete_notifs.append(notification) + + @contextmanager + def scoped( + self, listener: ScratchCellListener + ) -> Iterator[ScratchCellListener]: + self._active_listener = listener + try: + yield listener + finally: + self._active_listener = None + + def instantiate( + self, + request: InstantiateNotebookRequest, + *, + http_request: HTTPRequest | None, + ) -> None: + self.instantiate_calls.append((request, http_request)) + + def put_control_request( + self, + req: CommandMessage, + from_consumer_id: object = None, + ) -> None: + del from_consumer_id + self.control_requests.append(req) + if not ( + self._auto_complete + and isinstance(req, ExecuteScratchpadCommand) + and self._active_listener is not None + ): + return + session = self.as_session() + for notif in self._pre_complete_notifs: + self._active_listener.on_notification_sent( + session, serialize_kernel_message(notif) + ) + self._active_listener.on_notification_sent( + session, + serialize_kernel_message( + CompletedRunNotification(run_id=req.run_id) + ), + ) + + def try_interrupt(self) -> None: + self.interrupt_count += 1 + + class TestExtractResult: def test_no_notification(self) -> None: result = extract_result(_make_session()) @@ -521,3 +661,232 @@ async def consume() -> None: name, payload = _parse_sse(events[0]) assert name == "stdout" assert payload["data"] == "partial\n" + + +class TestRunScratchpadCode: + """Regression guards for ``run_scratchpad_code`` — the runner that + backs the AI ``execute_code`` tool.""" + + @staticmethod + def _execute_command(session: _FakeSession) -> ExecuteScratchpadCommand: + cmds = [ + c + for c in session.control_requests + if isinstance(c, ExecuteScratchpadCommand) + ] + assert len(cmds) == 1, ( + f"expected one ExecuteScratchpadCommand, got {session.control_requests!r}" + ) + return cmds[0] + + @pytest.mark.asyncio + async def test_stamps_screenshot_meta_and_run_id_on_command( + self, + ) -> None: + """Regression guard: ``run_id`` and screenshot meta must reach + the ``ExecuteScratchpadCommand`` unchanged. Without ``run_id``, + ``ScratchCellListener`` filters out the completion event and + every code-mode tool call hangs ~30s before timing out.""" + session = _FakeSession() + + result = await run_scratchpad_code( + session.as_session(), + _build_request(), + code="x = 1", + server_url="http://localhost:1234", + auth_token="fake-token", + ) + + assert result.success is True + cmd = self._execute_command(session) + assert cmd.run_id is not None + assert cmd.request is not None + assert cmd.request.meta[SCREENSHOT_SERVER_URL_KEY] == ( + "http://localhost:1234" + ) + assert cmd.request.meta[SCREENSHOT_AUTH_TOKEN_KEY] == "fake-token" + + @pytest.mark.asyncio + async def test_instantiates_session_without_auto_run(self) -> None: + """The runner must seed the dependency graph before executing + (so ``_code_mode.run_cell`` can resolve cell IDs) but it must + NOT auto-run the notebook's cells.""" + session = _FakeSession() + + await run_scratchpad_code( + session.as_session(), + _build_request(), + code="x = 1", + server_url="u", + auth_token="t", + ) + + assert len(session.instantiate_calls) == 1 + instantiate_req, _ = session.instantiate_calls[0] + assert instantiate_req.auto_run is False + + @pytest.mark.asyncio + async def test_holds_scratchpad_lock_while_putting_execute_command( + self, + ) -> None: + """``scratchpad_lock`` must be held while putting the execute + command so two concurrent code-mode calls can't interleave.""" + session = _FakeSession() + lock_held: list[bool] = [] + original_put = session.put_control_request + + def spy(req: CommandMessage, from_consumer_id: object = None) -> None: + if isinstance(req, ExecuteScratchpadCommand): + lock_held.append(session.scratchpad_lock.locked()) + original_put(req, from_consumer_id) + + session.put_control_request = spy # type: ignore[method-assign] + + await run_scratchpad_code( + session.as_session(), + _build_request(), + code="x = 1", + server_url="u", + auth_token="t", + ) + + assert lock_held == [True] + + @pytest.mark.asyncio + async def test_timeout_interrupts_kernel_and_surfaces_in_errors( + self, + ) -> None: + """On timeout the kernel is still running the (likely hung) + scratchpad code; ``run_scratchpad_code`` must interrupt it so the + next code-mode call doesn't block on ``scratchpad_lock`` — and + the timeout must be reported in ``errors`` (plural), matching + the success path's shape.""" + session = _FakeSession(auto_complete=False) + + result = await run_scratchpad_code( + session.as_session(), + _build_request(), + code="while True: pass", + server_url="u", + auth_token="t", + timeout=0.05, + ) + + assert result == snapshot( + CodeExecutionResult( + success=False, + errors=["Execution timed out after 0.05s"], + ) + ) + assert session.interrupt_count == 1 + + @pytest.mark.asyncio + async def test_cancelled_wait_interrupts_kernel(self) -> None: + """If the caller cancels mid-tool-call (e.g. the chat session + ends or pydantic_ai aborts the run), the kernel is still + processing our ``ExecuteScratchpadCommand``. We must interrupt + before releasing ``scratchpad_lock`` so the next code-mode + call isn't blocked behind the abandoned code, and we must + re-raise the cancellation so the caller still observes it.""" + session = _FakeSession(auto_complete=False) + + async def run() -> None: + await run_scratchpad_code( + session.as_session(), + _build_request(), + code="while True: pass", + server_url="u", + auth_token="t", + ) + + task = asyncio.create_task(run()) + # Yield enough for the runner to reach `await listener.wait(...)`. + for _ in range(10): + await asyncio.sleep(0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert session.interrupt_count == 1 + + @pytest.mark.asyncio + async def test_timeout_interrupt_happens_while_holding_lock( + self, + ) -> None: + """Regression guard: ``try_interrupt()`` must run BEFORE releasing + ``scratchpad_lock``. Otherwise a concurrent code-mode call could + acquire the lock and start running between timeout detection and + the interrupt — and get its brand-new execution killed by us.""" + session = _FakeSession(auto_complete=False) + lock_held_during_interrupt: list[bool] = [] + original_interrupt = session.try_interrupt + + def spy() -> None: + lock_held_during_interrupt.append(session.scratchpad_lock.locked()) + original_interrupt() + + session.try_interrupt = spy # type: ignore[method-assign] + + await run_scratchpad_code( + session.as_session(), + _build_request(), + code="while True: pass", + server_url="u", + auth_token="t", + timeout=0.05, + ) + + assert lock_held_during_interrupt == [True] + + @pytest.mark.asyncio + async def test_child_cell_errors_flow_into_result_errors(self) -> None: + """End-to-end: child-cell errors captured by the listener during + execution must surface in ``result.errors`` — otherwise the AI + never learns its ``run_cell`` calls failed. This pins down the + ``extract_result(session, listener)`` plumbing as well; dropping + the ``listener`` arg silently loses every child-cell error.""" + from marimo._types.ids import CellId_t + + session = _FakeSession() + # Real scratch cell runs emit an idle notification even when + # their own output is empty; without one, extract_result + # short-circuits before consulting the listener. + session.session_view.cell_notifications[SCRATCH_CELL_ID] = ( + CellNotification( + cell_id=SCRATCH_CELL_ID, + output=CellOutput( + channel=CellChannel.OUTPUT, + mimetype="text/plain", + data="", + ), + console=None, + status="idle", + ) + ) + session.emit( + CellNotification( + cell_id=CellId_t("child-cell"), + output=CellOutput.errors( + [ + MarimoExceptionRaisedError( + msg="division by zero", + exception_type="ZeroDivisionError", + raising_cell=None, + ) + ] + ), + console=None, + status="idle", + ) + ) + + result = await run_scratchpad_code( + session.as_session(), + _build_request(), + code="run_cell('child-cell')", + server_url="u", + auth_token="t", + ) + + assert result.success is False + assert result.errors == ["cell 'child-cell' raised ZeroDivisionError"]