Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions marimo/_code_mode/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
_UpdateOp,
_validate_ops,
)
from marimo._code_mode.screenshot_meta import (
SCREENSHOT_AUTH_TOKEN_KEY,
SCREENSHOT_SERVER_URL_KEY,
)
from marimo._messaging.errors import Error
from marimo._messaging.notebook.changes import (
CreateCell,
Expand Down Expand Up @@ -1272,19 +1276,20 @@ async def screenshot(
"available. screenshot() must be called during cell "
"execution (e.g. from code-mode)."
)
# Read trusted server URL and auth token injected by the
# /execute endpoint (from server config, not request headers).

# Trusted server URL + auth token injected by the /execute
# endpoint (from server config, not request headers).
server_url = cast(
"str | None", request.meta.get("screenshot_server_url")
"str | None", request.meta.get(SCREENSHOT_SERVER_URL_KEY)
)
if server_url is None:
raise ScreenshotError(
"Cannot take screenshots: screenshot_server_url not "
"Cannot take screenshots: screenshot credentials not "
"found in request.meta. This endpoint may not "
"support screenshots."
)
screenshot_auth_token = cast(
"str | None", request.meta.get("screenshot_auth_token")
"str | None", request.meta.get(SCREENSHOT_AUTH_TOKEN_KEY)
)

# Lazy-init the screenshot session (browser reuse).
Expand Down
14 changes: 14 additions & 0 deletions marimo/_code_mode/screenshot_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2026 Marimo. All rights reserved.
"""Meta keys for wiring screenshot credentials through ``HTTPRequest.meta``.

Code-mode tools running in the kernel need to call back into the marimo
server (e.g. ``ctx.screenshot()`` driving Playwright against the
kiosk page). The server stamps a trusted ``server_url`` and
``auth_token`` onto each control request's ``meta`` dict; the runtime
side reads them when building the screenshot session.
"""

from __future__ import annotations

SCREENSHOT_SERVER_URL_KEY = "screenshot_server_url"
SCREENSHOT_AUTH_TOKEN_KEY = "screenshot_auth_token"
55 changes: 55 additions & 0 deletions marimo/_server/ai/tools/code_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2026 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING

from marimo._ai._tools.types import CodeExecutionResult
from marimo._server.api.deps import AppState
from marimo._server.api.utils import get_code_mode_credentials
from marimo._server.scratchpad import run_scratchpad_code

if TYPE_CHECKING:
from pydantic_ai import FunctionToolset
from starlette.requests import Request

from marimo._session.session import Session


def build_execute_code_toolset(
    session: Session,
    request: Request,
) -> FunctionToolset[None]:
    """Create a ``FunctionToolset`` whose only tool is ``execute_code``.

    *session* and *request* are captured by the tool's closure, so the
    model never sees or supplies a session id. The screenshot
    credentials are looked up from the request on every tool call and
    forwarded to ``run_scratchpad_code`` so that ``ctx.screenshot()``
    can call back into this server (see
    ``marimo/_code_mode/_context.py``).
    """
    # At module level this name is only imported for type checking, so
    # the runtime import happens here.
    from pydantic_ai import FunctionToolset

    async def execute_code(code: str) -> CodeExecutionResult:
        """Run Python inside the running notebook's kernel scratchpad.

        Use this for all notebook mutations via ``marimo._code_mode``.
        """
        url, token = get_code_mode_credentials(AppState(request), request)
        result = await run_scratchpad_code(
            session,
            request,
            code=code,
            server_url=url,
            auth_token=token,
        )
        return result

    tools: FunctionToolset[None] = FunctionToolset()
    # The tool's docstring doubles as the description handed to the toolset.
    tools.add_function(
        execute_code,
        name="execute_code",
        description=execute_code.__doc__,
    )
    return tools
26 changes: 13 additions & 13 deletions marimo/_server/api/endpoints/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,22 @@
from starlette.responses import JSONResponse, StreamingResponse

from marimo import _loggers
from marimo._code_mode.screenshot_meta import (
SCREENSHOT_AUTH_TOKEN_KEY,
SCREENSHOT_SERVER_URL_KEY,
)
from marimo._messaging.notification import AlertNotification
from marimo._runtime.commands import HTTPRequest, UpdateUIElementCommand
from marimo._server.api.deps import AppState
from marimo._server.api.endpoints.ws.ws_connection_validator import (
FILE_QUERY_PARAM_KEY,
)
from marimo._server.api.endpoints.ws_endpoint import DOC_MANAGER
from marimo._server.api.utils import dispatch_control_request, parse_request
from marimo._server.api.utils import (
dispatch_control_request,
get_code_mode_credentials,
parse_request,
)
from marimo._server.models.models import (
BaseResponse,
DebugCellRequest,
Expand Down Expand Up @@ -312,19 +320,11 @@ async def sse_generator() -> AsyncGenerator[str, None]:
with session.scoped(listener):
async with session.scratchpad_lock:
http_req = HTTPRequest.from_request(request)
# Inject trusted server URL and auth token for
# code-mode screenshot support. We use the
# server's own host/port (from config) rather
# than the request's Host header to prevent
# header-spoofing attacks.
http_req.meta["screenshot_auth_token"] = str(
app_state.session_manager.auth_token
)
base_url = app_state.base_url.rstrip("/")
scheme = request.url.scheme or "http"
http_req.meta["screenshot_server_url"] = (
f"{scheme}://{app_state.host}:{app_state.port}{base_url}"
server_url, auth_token = get_code_mode_credentials(
app_state, request
)
http_req.meta[SCREENSHOT_SERVER_URL_KEY] = server_url
http_req.meta[SCREENSHOT_AUTH_TOKEN_KEY] = auth_token
session.put_control_request(
ExecuteScratchpadCommand(
code=body.code,
Expand Down
19 changes: 19 additions & 0 deletions marimo/_server/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from starlette.datastructures import UploadFile
from starlette.requests import Request

from marimo._server.api.deps import AppStateBase
from marimo._session.session import Session


Expand Down Expand Up @@ -113,6 +114,24 @@ async def dispatch_control_request(
return SuccessResponse()


def get_code_mode_credentials(
    app_state: AppStateBase, request: Request
) -> tuple[str, str]:
    """Return ``(server_url, auth_token)`` for code-mode tools that
    call back into this marimo server (e.g. ``ctx.screenshot()``
    driving Playwright against the running notebook UI).

    The URL is built from the server's configured ``host``/``port``
    rather than the request's ``Host`` header so a spoofed header
    can't redirect the auth token to an attacker-controlled URL.
    Only the scheme is taken from the request, falling back to
    ``http`` when the request reports none.

    Args:
        app_state: Source of the configured ``host``, ``port``,
            ``base_url`` and the session manager's ``auth_token``.
        request: Incoming request; only ``request.url.scheme`` is read.

    Returns:
        A ``(server_url, auth_token)`` tuple. ``server_url`` has the
        form ``scheme://host:port`` followed by ``base_url`` with any
        trailing ``/`` removed.
    """
    auth_token = str(app_state.session_manager.auth_token)
    base_url = app_state.base_url.rstrip("/")
    scheme = request.url.scheme or "http"
    host = str(app_state.host)
    # An IPv6 literal (e.g. "::1") must be wrapped in brackets inside a
    # URL, otherwise its colons are indistinguishable from the port
    # separator. Hosts that are already bracketed are left alone.
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"
    server_url = f"{scheme}://{host}:{app_state.port}{base_url}"
    return server_url, auth_token
Comment thread
Light2Dark marked this conversation as resolved.


def parse_title(filepath: str | None) -> str:
"""
Create a title from a filename.
Expand Down
61 changes: 61 additions & 0 deletions marimo/_server/scratchpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,29 @@
import asyncio
import json
from typing import TYPE_CHECKING, Any, TypedDict
from uuid import uuid4

from marimo._ai._tools.types import CodeExecutionResult
from marimo._code_mode.screenshot_meta import (
SCREENSHOT_AUTH_TOKEN_KEY,
SCREENSHOT_SERVER_URL_KEY,
)
from marimo._messaging.cell_output import CellChannel
from marimo._messaging.notification import (
CellNotification,
CompletedRunNotification,
)
from marimo._messaging.serde import deserialize_kernel_message
from marimo._runtime.commands import ExecuteScratchpadCommand, HTTPRequest
from marimo._runtime.scratch import SCRATCH_CELL_ID
from marimo._server.models.models import InstantiateNotebookRequest
from marimo._session.extensions.types import EventAwareExtension

if TYPE_CHECKING:
from collections.abc import AsyncGenerator

from starlette.requests import Request

from marimo._messaging.types import KernelMessage
from marimo._session.session import Session

Expand Down Expand Up @@ -286,3 +295,55 @@ def extract_result(
stderr=stderr,
errors=errors,
)


async def run_scratchpad_code(
session: Session,
request: Request,
*,
code: str,
server_url: str,
auth_token: str,
timeout: float = EXECUTION_TIMEOUT,
) -> CodeExecutionResult:
"""Drive the kernel scratchpad on behalf of code-mode (AI tool).

``server_url`` and ``auth_token`` are stamped onto ``http_req.meta``
so ``ctx.screenshot()`` from inside code-mode can authenticate
Playwright against this server (see ``marimo/_code_mode/_context.py``).

Args:
    session: Session that is instantiated and that receives the
        scratchpad command.
    request: Incoming request, converted with
        ``HTTPRequest.from_request``.
    code: Source passed to ``ExecuteScratchpadCommand``.
    server_url: Stored in the request meta under
        ``SCREENSHOT_SERVER_URL_KEY``.
    auth_token: Stored in the request meta under
        ``SCREENSHOT_AUTH_TOKEN_KEY``.
    timeout: Passed to ``listener.wait``; defaults to
        ``EXECUTION_TIMEOUT``.

Returns:
    ``extract_result(session, listener)`` normally. If the listener
    reports a timeout, the session is asked to interrupt and a failed
    ``CodeExecutionResult`` carrying a timeout message is returned.
"""
# Request object handed to the kernel, with the screenshot credentials
# attached under the shared meta keys.
http_req = HTTPRequest.from_request(request)
http_req.meta[SCREENSHOT_SERVER_URL_KEY] = server_url
http_req.meta[SCREENSHOT_AUTH_TOKEN_KEY] = auth_token

# NOTE(review): the instantiate call builds a second HTTPRequest from
# `request`, so it does not carry the screenshot meta stamped on
# `http_req` above -- confirm this is intentional.
session.instantiate(
InstantiateNotebookRequest(object_ids=[], values=[], auto_run=False),
http_request=HTTPRequest.from_request(request),
Comment thread
Light2Dark marked this conversation as resolved.
Outdated
Comment thread
Light2Dark marked this conversation as resolved.
Outdated
)

# See /api/kernel/execute for rationale.
# The same run_id is given to the listener and to the command below.
run_id = str(uuid4())
listener = ScratchCellListener(run_id=run_id)

with session.scoped(listener):
async with session.scratchpad_lock:
session.put_control_request(
ExecuteScratchpadCommand(
code=code,
request=http_req,
notebook_cells=tuple(session.document.cells),
run_id=run_id,
),
from_consumer_id=None,
)
# Wait on the listener, bounded by `timeout`.
await listener.wait(timeout=timeout)
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
Outdated
if listener.timed_out:
# The kernel is still running the (likely hung) scratchpad
# code; so we interrupt it to unblock the next code-mode call.
session.try_interrupt()
Comment thread
Light2Dark marked this conversation as resolved.
Outdated
return CodeExecutionResult(
success=False,
errors=[f"Execution timed out after {timeout}s"],
)
Comment thread
Light2Dark marked this conversation as resolved.
Outdated

return extract_result(session, listener)
13 changes: 10 additions & 3 deletions tests/_server/api/endpoints/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

import pytest

from marimo._code_mode.screenshot_meta import (
SCREENSHOT_AUTH_TOKEN_KEY,
SCREENSHOT_SERVER_URL_KEY,
)
from marimo._types.ids import CellId_t, SessionId
from marimo._utils.lists import first
from tests._server.mocks import (
Expand Down Expand Up @@ -206,10 +210,13 @@ async def empty_stream(self: object): # noqa: ARG001
assert len(scratchpad_cmds) == 1, (
f"expected one ExecuteScratchpadCommand, got {captured!r}"
)
meta = scratchpad_cmds[0].request.meta
assert meta["screenshot_auth_token"] == "fake-token"
http_req = scratchpad_cmds[0].request
assert http_req is not None
# Mock server uses host="localhost", port=1234, base_url=""
assert meta["screenshot_server_url"] == "http://localhost:1234"
assert http_req.meta[SCREENSHOT_SERVER_URL_KEY] == (
"http://localhost:1234"
)
assert http_req.meta[SCREENSHOT_AUTH_TOKEN_KEY] == "fake-token"

@staticmethod
@with_session(SESSION_ID)
Expand Down
Loading
Loading