Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions marimo/_code_mode/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
_UpdateOp,
_validate_ops,
)
from marimo._code_mode.screenshot_meta import (
SCREENSHOT_AUTH_TOKEN_KEY,
SCREENSHOT_SERVER_URL_KEY,
)
from marimo._messaging.errors import Error
from marimo._messaging.notebook.changes import (
CreateCell,
Expand Down Expand Up @@ -1272,19 +1276,20 @@ async def screenshot(
"available. screenshot() must be called during cell "
"execution (e.g. from code-mode)."
)
# Read trusted server URL and auth token injected by the
# /execute endpoint (from server config, not request headers).

# Trusted server URL + auth token injected by the /execute
# endpoint (from server config, not request headers).
server_url = cast(
"str | None", request.meta.get("screenshot_server_url")
"str | None", request.meta.get(SCREENSHOT_SERVER_URL_KEY)
)
if server_url is None:
raise ScreenshotError(
"Cannot take screenshots: screenshot_server_url not "
"Cannot take screenshots: screenshot credentials not "
"found in request.meta. This endpoint may not "
"support screenshots."
)
screenshot_auth_token = cast(
"str | None", request.meta.get("screenshot_auth_token")
"str | None", request.meta.get(SCREENSHOT_AUTH_TOKEN_KEY)
)

# Lazy-init the screenshot session (browser reuse).
Expand Down
14 changes: 14 additions & 0 deletions marimo/_code_mode/screenshot_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2026 Marimo. All rights reserved.
"""Meta keys for wiring screenshot credentials through ``HTTPRequest.meta``.
Code-mode tools running in the kernel need to call back into the marimo
server (e.g. ``ctx.screenshot()`` driving Playwright against the
kiosk page). The server stamps a trusted ``server_url`` and
``auth_token`` onto each control request's ``meta`` dict; the runtime
side reads them when building the screenshot session.
"""

from __future__ import annotations

# Key in ``HTTPRequest.meta`` under which the server stores the trusted
# server URL that code-mode screenshot callbacks should target.
SCREENSHOT_SERVER_URL_KEY = "screenshot_server_url"
# Key in ``HTTPRequest.meta`` under which the server stores the auth
# token that accompanies the screenshot server URL.
SCREENSHOT_AUTH_TOKEN_KEY = "screenshot_auth_token"
55 changes: 55 additions & 0 deletions marimo/_server/ai/tools/code_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2026 Marimo. All rights reserved.
from __future__ import annotations

from typing import TYPE_CHECKING

from marimo._ai._tools.types import CodeExecutionResult
from marimo._server.api.deps import AppState
from marimo._server.api.utils import get_code_mode_credentials
from marimo._server.scratchpad import run_scratchpad_code

if TYPE_CHECKING:
from pydantic_ai import FunctionToolset
from starlette.requests import Request

from marimo._session.session import Session


def build_execute_code_toolset(
    session: Session,
    request: Request,
) -> FunctionToolset[None]:
    """Create a toolset whose only tool is ``execute_code``.

    The returned tool closes over *session* and *request*, so the caller
    of the tool only supplies the code string. On every invocation the
    screenshot credentials (server URL and auth token) are recomputed
    from the request via ``get_code_mode_credentials`` and forwarded to
    ``run_scratchpad_code``.

    Args:
        session: Session the scratchpad code is run against.
        request: Incoming request the tool is bound to; also used to
            derive the screenshot credentials.

    Returns:
        A ``FunctionToolset`` with ``execute_code`` registered.
    """
    # Imported lazily: at module level pydantic_ai is only imported
    # under TYPE_CHECKING.
    from pydantic_ai import FunctionToolset

    async def execute_code(code: str) -> CodeExecutionResult:
        """Run Python inside the running notebook's kernel scratchpad.
        Use this for all notebook mutations via ``marimo._code_mode``.
        """
        # NOTE: the docstring above is passed verbatim as the tool
        # description when the function is registered below.
        credentials = get_code_mode_credentials(AppState(request), request)
        server_url, auth_token = credentials
        result = await run_scratchpad_code(
            session,
            request,
            code=code,
            server_url=server_url,
            auth_token=auth_token,
        )
        return result

    tools: FunctionToolset[None] = FunctionToolset()
    tools.add_function(
        execute_code,
        name="execute_code",
        description=execute_code.__doc__,
    )
    return tools
26 changes: 13 additions & 13 deletions marimo/_server/api/endpoints/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,22 @@
from starlette.responses import JSONResponse, StreamingResponse

from marimo import _loggers
from marimo._code_mode.screenshot_meta import (
SCREENSHOT_AUTH_TOKEN_KEY,
SCREENSHOT_SERVER_URL_KEY,
)
from marimo._messaging.notification import AlertNotification
from marimo._runtime.commands import HTTPRequest, UpdateUIElementCommand
from marimo._server.api.deps import AppState
from marimo._server.api.endpoints.ws.ws_connection_validator import (
FILE_QUERY_PARAM_KEY,
)
from marimo._server.api.endpoints.ws_endpoint import DOC_MANAGER
from marimo._server.api.utils import dispatch_control_request, parse_request
from marimo._server.api.utils import (
dispatch_control_request,
get_code_mode_credentials,
parse_request,
)
from marimo._server.models.models import (
BaseResponse,
DebugCellRequest,
Expand Down Expand Up @@ -312,19 +320,11 @@ async def sse_generator() -> AsyncGenerator[str, None]:
with session.scoped(listener):
async with session.scratchpad_lock:
http_req = HTTPRequest.from_request(request)
# Inject trusted server URL and auth token for
# code-mode screenshot support. We use the
# server's own host/port (from config) rather
# than the request's Host header to prevent
# header-spoofing attacks.
http_req.meta["screenshot_auth_token"] = str(
app_state.session_manager.auth_token
)
base_url = app_state.base_url.rstrip("/")
scheme = request.url.scheme or "http"
http_req.meta["screenshot_server_url"] = (
f"{scheme}://{app_state.host}:{app_state.port}{base_url}"
server_url, auth_token = get_code_mode_credentials(
app_state, request
)
http_req.meta[SCREENSHOT_SERVER_URL_KEY] = server_url
http_req.meta[SCREENSHOT_AUTH_TOKEN_KEY] = auth_token
session.put_control_request(
ExecuteScratchpadCommand(
code=body.code,
Expand Down
57 changes: 6 additions & 51 deletions marimo/_server/api/lifespans.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

import asyncio
import contextlib
import ipaddress
import socket
from typing import TYPE_CHECKING, Any

from marimo import _loggers
from marimo._server.ai.mcp.config import is_mcp_config_empty
from marimo._server.ai.tools.tool_manager import setup_tool_manager
from marimo._server.api.deps import AppState, AppStateBase
from marimo._server.api.interrupt import InterruptHandler
from marimo._server.api.utils import open_url_in_browser
from marimo._server.api.utils import (
format_url_host,
open_url_in_browser,
)
from marimo._server.lsp import any_lsp_server_running
from marimo._server.print import (
print_experimental_features,
Expand Down Expand Up @@ -271,48 +272,10 @@ async def reap_subprocesses(app: Starlette) -> AsyncIterator[None]:
await cancel_pending_reaps()


def _pretty_host(host: str, port: int) -> str:
"""Replace loopback addresses with 'localhost' for display.

Uses ipaddress for a reliable cross-platform loopback check (covers
127.0.0.1, ::1, and the full 127.0.0.0/8 range). Falls back to
socket.getnameinfo only for non-IP hosts. getnameinfo is skipped for
raw IP addresses because it can hang on Windows/CI for link-local IPv6.
"""
try:
if ipaddress.ip_address(host).is_loopback:
return "localhost"
except ValueError:
# Not a valid IP literal β€” might be a hostname; try getnameinfo
try:
if (
socket.getnameinfo((host, port), socket.NI_NOFQDN)[0]
== "localhost"
):
return "localhost"
except Exception:
pass
return host


def _startup_url(state: AppStateBase) -> str:
host = state.host.strip(
"[]"
) # normalize: remove brackets if user passed [addr]
url_host = format_url_host(state.host, state.port)
port = state.port

# Strip IPv6 zone ID (e.g. fe80::1%eth0 -> fe80::1); zone IDs are
# interface-specific and not valid in URLs.
# Must happen before _pretty_host β€” zone IDs can cause getnameinfo
# to hang on Windows/CI.
host = host.split("%")[0]

# pretty printing: show "localhost" for loopback addresses
host = _pretty_host(host, port)

url_host_bare = host
# IPv6 addresses must be wrapped in brackets in URLs (RFC 3986)
url_host = f"[{url_host_bare}]" if ":" in url_host_bare else url_host_bare
url = f"http://{url_host}:{port}{state.base_url}"
if port == 80:
url = f"http://{url_host}{state.base_url}"
Expand All @@ -325,18 +288,10 @@ def _startup_url(state: AppStateBase) -> str:


def _mcp_startup_url(state: AppStateBase) -> str:
host = state.host.strip(
"[]"
) # normalize: remove brackets if user passed [addr]
url_host = format_url_host(state.host, state.port)
port = state.port
base_url = state.base_url

# Strip zone ID, then pretty-print loopback (same logic as _startup_url)
host = host.split("%")[0]
host = _pretty_host(host, port)

url_host_bare = host
url_host = f"[{url_host_bare}]" if ":" in url_host_bare else url_host_bare
# Construct MCP endpoint URL
mcp_prefix = "/mcp"
mcp_name = "server"
Expand Down
84 changes: 84 additions & 0 deletions marimo/_server/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from starlette.datastructures import UploadFile
from starlette.requests import Request

from marimo._server.api.deps import AppStateBase
from marimo._session.session import Session


Expand Down Expand Up @@ -113,6 +114,89 @@ async def dispatch_control_request(
return SuccessResponse()


def pretty_host(host: str, port: int) -> str:
    """Map loopback hosts to ``"localhost"`` for display purposes.

    A host that parses as an IP address is checked with
    ``ipaddress`` only (covering 127.0.0.1, ::1 and all of
    127.0.0.0/8); ``socket.getnameinfo`` is not called for it. A host
    that does not parse as an IP address is handed to ``getnameinfo``,
    and the result is used only if it is exactly ``"localhost"``.
    Errors from that lookup are swallowed on purpose.

    Args:
        host: IP literal or hostname, without brackets.
        port: Port used for the ``getnameinfo`` fallback.

    Returns:
        ``"localhost"`` for loopback hosts, otherwise *host* unchanged.
    """
    import ipaddress
    import socket

    loopback_name = "localhost"

    try:
        parsed = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal: it may be a hostname, so try getnameinfo.
        try:
            looked_up = socket.getnameinfo((host, port), socket.NI_NOFQDN)[0]
        except Exception:
            # Best effort: keep the host as given when the lookup fails.
            return host
        if looked_up == loopback_name:
            return loopback_name
        return host
    else:
        if parsed.is_loopback:
            return loopback_name
        return host


def format_url_host(
    host: str,
    port: int,
    *,
    route_bind_all_to_loopback: bool = False,
) -> str:
    """Normalize ``host`` for use in a URL authority.

    Transforms (in order):

    1. Strip user-supplied brackets (``[::1]`` -> ``::1``).
    2. Drop IPv6 zone IDs (``fe80::1%eth0`` -> ``fe80::1``); zone IDs
       are interface-specific and not valid in URLs. (Must happen
       before :func:`pretty_host` — zone IDs can cause getnameinfo
       to hang on Windows/CI.)
    3. Replace loopback addresses with ``"localhost"``.
    4. If ``route_bind_all_to_loopback``, replace the unspecified
       ("bind-all") address with ``"localhost"`` so internal callback
       URLs reach a routable target (the bind-all addresses are not
       valid destinations for an HTTP client / Playwright). Any
       spelling of the unspecified address is recognised
       (``0.0.0.0``, ``::``, ``0:0:0:0:0:0:0:0``, ...), not just the
       two canonical strings. Display URLs leave them alone — the
       user knows what they configured.
    5. Wrap bare IPv6 literals in brackets per RFC 3986.

    Args:
        host: Configured host; may be a hostname, an IPv4/IPv6 literal,
            a bracketed IPv6 literal, or an IPv6 literal with a zone ID.
        port: Server port, forwarded to :func:`pretty_host`.
        route_bind_all_to_loopback: When true, map the unspecified
            address to ``"localhost"`` (step 4).

    Returns:
        The host string ready to be placed before ``:port`` in a URL.
    """
    import ipaddress

    host = host.strip("[]").split("%")[0]
    host = pretty_host(host, port)
    if route_bind_all_to_loopback:
        try:
            # Parse instead of comparing strings so that every textual
            # form of the unspecified address is caught.
            is_bind_all = ipaddress.ip_address(host).is_unspecified
        except ValueError:
            # Not an IP literal (a hostname, or "localhost" from step 3).
            is_bind_all = False
        if is_bind_all:
            host = "localhost"
    # A colon can only appear in an IPv6 literal at this point.
    return f"[{host}]" if ":" in host else host


def get_code_mode_credentials(
    app_state: AppStateBase, request: Request
) -> tuple[str, str]:
    """Build the ``(server_url, auth_token)`` pair used by code-mode
    tools that call back into this marimo server (for example
    ``ctx.screenshot()``).

    The host and port in the URL come from ``app_state`` (the server's
    own configuration), not from the request's ``Host`` header, so a
    spoofed header cannot steer the auth token towards another URL.
    Only the scheme is taken from the request, defaulting to ``http``.
    The host is passed through ``format_url_host`` with bind-all
    addresses routed to ``localhost`` so the resulting URL is usable
    as a callback target.

    Args:
        app_state: Server state providing ``host``, ``port``,
            ``base_url`` and the session manager's auth token.
        request: Current request; only its URL scheme is read.

    Returns:
        ``(server_url, auth_token)``; ``server_url`` carries no
        trailing slash from ``base_url``.
    """
    token = str(app_state.session_manager.auth_token)
    path_prefix = app_state.base_url.rstrip("/")

    scheme = request.url.scheme
    if not scheme:
        scheme = "http"

    authority_host = format_url_host(
        app_state.host,
        app_state.port,
        route_bind_all_to_loopback=True,
    )
    callback_url = (
        f"{scheme}://{authority_host}:{app_state.port}{path_prefix}"
    )
    return callback_url, token
Comment thread
Light2Dark marked this conversation as resolved.


def parse_title(filepath: str | None) -> str:
"""
Create a title from a filename.
Expand Down
Loading
Loading