Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion marimo/_runtime/context/kernel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def create_kernel_context(
from marimo._plugins.ui._core.registry import UIElementRegistry
from marimo._runtime.state import StateRegistry
from marimo._runtime.virtual_file import (
DiskStorage,
FallbackStorage,
InMemoryStorage,
SharedMemoryStorage,
VirtualFileRegistry,
Expand All @@ -163,8 +165,11 @@ def create_kernel_context(
# Storage is chosen explicitly by the caller. None means virtual files
# are not supported; we still construct an (inert) InMemoryStorage so
# the registry has a backend, but ctx.virtual_files_supported is False.
# In EDIT mode, fall back to disk if shared memory cannot allocate
# (e.g., /dev/shm is full). Disk is cross-process readable, so the
# main-process server can still serve files via the /@file endpoint.
storage: VirtualFileStorage = (
SharedMemoryStorage()
FallbackStorage([SharedMemoryStorage(), DiskStorage()])
if virtual_file_storage == "shared_memory"
else InMemoryStorage()
)
Expand Down
4 changes: 4 additions & 0 deletions marimo/_runtime/virtual_file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from __future__ import annotations

from marimo._runtime.virtual_file.storage import (
DiskStorage,
FallbackStorage,
InMemoryStorage,
SharedMemoryStorage,
VirtualFileStorage,
Expand All @@ -31,6 +33,8 @@
"VirtualFileStorageType",
"SharedMemoryStorage",
"InMemoryStorage",
"DiskStorage",
"FallbackStorage",
"VirtualFileStorageManager",
# Virtual files
"VirtualFile",
Expand Down
279 changes: 273 additions & 6 deletions marimo/_runtime/virtual_file/storage.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright 2026 Marimo. All rights reserved.
from __future__ import annotations

import os
import sys
from typing import TYPE_CHECKING, Literal, Protocol
import tempfile
from pathlib import Path
from typing import IO, TYPE_CHECKING, Literal, Protocol

from marimo import _loggers
from marimo._utils.platform import is_pyodide

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator

LOGGER = _loggers.marimo_logger()

DEFAULT_CHUNK_SIZE = 256 * 1024 # 256KB

VirtualFileStorageType = Literal["in_memory", "shared_memory"]
Expand Down Expand Up @@ -155,8 +161,13 @@ def read_chunked(
view = None
try:
shm = shared_memory.SharedMemory(name=key)
# Slice clamps to actual segment size; iterate over the
# clamped length, not the requested byte_length, so a huge
# byte_length from the URL doesn't trigger a long sequence
# of empty yields after the segment is exhausted.
view = shm.buf[:byte_length]
for i in range(0, byte_length, chunk_size):
actual_length = len(view)
for i in range(0, actual_length, chunk_size):
yield bytes(view[i : i + chunk_size])
except FileNotFoundError as err:
raise KeyError(f"Virtual file not found: {key}") from err
Expand Down Expand Up @@ -196,7 +207,22 @@ def shutdown(self, keys: Iterable[str] | None = None) -> None:
self._stale = True

def has(self, key: str) -> bool:
return key in self._storage
if key in self._storage:
return True
# Cross-process probe: try to open the segment by name. Used by
# FallbackStorage to locate keys written by another process.
# Catch OSError (parent of FileNotFoundError, plus errno variants
# like ENOENT/EACCES on hostile keys) and ValueError (invalid name
# characters on some platforms) so probing falls through to the
# next backend instead of crashing.
if is_pyodide():
return False
try:
shm = shared_memory.SharedMemory(name=key)
except (OSError, ValueError):
return False
shm.close()
return True


class InMemoryStorage(VirtualFileStorage):
Expand Down Expand Up @@ -248,6 +274,239 @@ def has(self, key: str) -> bool:
return key in self._storage


class DiskStorage(VirtualFileStorage):
"""Storage backend that writes virtual file bytes to a temp directory.

Cross-process safe: any process can read a key by computing the same path.
Used as a fallback tier when ``SharedMemoryStorage`` cannot allocate
(e.g., ``/dev/shm`` is full).
"""

def __init__(self, base_dir: str | Path | None = None) -> None:
self._base_dir = (
Path(base_dir)
if base_dir is not None
else Path(tempfile.gettempdir()) / "marimo-vfs"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be verifying that this gets mapped to tmpfs on molab, otherwise this will be very slow when actually used. Currently maps to /tmp (expected) but our mounts should this to be on disk

)
# Refuse a base dir that already exists as a symlink. Defends
# against the classic shared-/tmp pre-creation attack where a
# co-tenant pre-creates `/tmp/marimo-vfs` as a symlink to
# somewhere they control, redirecting all our writes.
if self._base_dir.is_symlink():
raise OSError(
f"Refusing to use {self._base_dir}: path is a symlink"
)
self._owned_keys: set[str] = set()
self._stale = False

@property
def stale(self) -> bool:
return self._stale

def _path(self, key: str) -> Path:
# Reject keys that would escape base_dir (path separators, traversal
# components, null bytes). Virtual file keys are generated by
# `random_filename` as a single filename component, so any key with
# path-like structure is either a bug or a malicious request via
# the `/@file/{...:path}` endpoint. Treating these as missing keeps
# the disk fallback safe even if the read endpoint forwards the
# raw URL segment.
if (
not key
or "/" in key
or "\\" in key
or "\x00" in key
or key in (".", "..")
):
raise KeyError(f"Invalid virtual file key: {key!r}")
return self._base_dir / key
Comment thread
dmadisetti marked this conversation as resolved.

def _open(self, key: str) -> IO[bytes]:
try:
return self._path(key).open("rb")
except FileNotFoundError as err:
raise KeyError(f"Virtual file not found: {key}") from err

def store(self, key: str, buffer: bytes) -> None:
# Validate up front: rejects path-traversal keys before mkstemp
# uses `key` as a filename prefix.
target = self._path(key)
self._base_dir.mkdir(parents=True, exist_ok=True)
# Atomic write: write to a unique temp sibling then rename. Prevents
# cross-process readers from seeing a half-written file; mkstemp's
# uniqueness avoids collisions when two stores race on one key.
fd, tmp_str = tempfile.mkstemp(
prefix=f"{key}.", suffix=".tmp", dir=self._base_dir
)
tmp = Path(tmp_str)
try:
with os.fdopen(fd, "wb") as f:
f.write(buffer)
tmp.replace(target)
except BaseException:
tmp.unlink(missing_ok=True)
raise
self._owned_keys.add(key)

def read(self, key: str, byte_length: int) -> bytes:
with self._open(key) as f:
return f.read(byte_length)

def read_chunked(
self,
key: str,
byte_length: int,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> Iterator[bytes]:
with self._open(key) as f:
remaining = byte_length
while remaining > 0:
chunk = f.read(min(chunk_size, remaining))
if not chunk:
break
remaining -= len(chunk)
yield chunk

def remove(self, key: str) -> None:
self._path(key).unlink(missing_ok=True)
self._owned_keys.discard(key)

def shutdown(self, keys: Iterable[str] | None = None) -> None:
# When keys is None, only remove keys this instance wrote — avoids
# racing other kernels that share the temp directory.
target = list(keys) if keys is not None else list(self._owned_keys)
for key in target:
self.remove(key)
if keys is None:
self._owned_keys.clear()
self._stale = True

def has(self, key: str) -> bool:
# Invalid keys never exist — return False rather than raising so
# FallbackStorage probing handles hostile keys gracefully.
try:
return self._path(key).exists()
except KeyError:
return False


class FallbackStorage(VirtualFileStorage):
"""Composite storage that tries each backend in order.

On ``store``, attempts each backend in order and falls back to the next
when the current one raises ``OSError`` (e.g., shared memory exhaustion).
Reads, removes, and existence checks are routed to the backend that
accepted the original ``store``; if no routing is recorded (e.g., a fresh
instance used only for cross-process reads), each backend is probed in
order via ``has``.
"""

def __init__(self, backends: list[VirtualFileStorage]) -> None:
if not backends:
raise ValueError("FallbackStorage requires at least one backend")
self._backends = backends
self._routing: dict[str, int] = {}

@property
def stale(self) -> bool:
return all(b.stale for b in self._backends)

def store(self, key: str, buffer: bytes) -> None:
if key in self._routing:
return
last_err: OSError | None = None
for i, backend in enumerate(self._backends):
if backend.stale:
continue
try:
backend.store(key, buffer)
except OSError as err:
last_err = err
LOGGER.warning(
"Virtual file storage backend %d (%s) failed to store "
"key %r (%d bytes): %s; trying next backend",
i,
type(backend).__name__,
key,
len(buffer),
err,
)
continue
self._routing[key] = i
return
if last_err is not None:
raise last_err
# All backends were stale (no store was attempted).
raise OSError(
"All virtual file storage backends are stale; cannot store "
f"key {key!r}"
)

def _backend_for(self, key: str) -> VirtualFileStorage | None:
idx = self._routing.get(key)
if idx is not None:
return self._backends[idx]
for backend in self._backends:
if backend.has(key):
return backend
return None

def read(self, key: str, byte_length: int) -> bytes:
backend = self._backend_for(key)
if backend is None:
raise KeyError(f"Virtual file not found: {key}")
return backend.read(key, byte_length)

def read_chunked(
self,
key: str,
byte_length: int,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> Iterator[bytes]:
backend = self._backend_for(key)
if backend is None:
raise KeyError(f"Virtual file not found: {key}")
yield from backend.read_chunked(key, byte_length, chunk_size)

def remove(self, key: str) -> None:
idx = self._routing.pop(key, None)
if idx is not None:
self._backends[idx].remove(key)
return
for backend in self._backends:
backend.remove(key)

def shutdown(self, keys: Iterable[str] | None = None) -> None:
# Forward to every backend, isolating failures: one backend's
# cleanup error must not prevent the others from releasing
# resources (shared memory segments, disk files).
key_list = list(keys) if keys is not None else None
for i, backend in enumerate(self._backends):
try:
if key_list is None:
backend.shutdown()
else:
backend.shutdown(keys=key_list)
except Exception as err:
LOGGER.warning(
"Virtual file storage backend %d (%s) failed during "
"shutdown: %s",
i,
type(backend).__name__,
err,
)
if key_list is None:
self._routing.clear()
else:
for key in key_list:
self._routing.pop(key, None)

def has(self, key: str) -> bool:
if key in self._routing:
return self._backends[self._routing[key]].has(key)
return any(b.has(key) for b in self._backends)


class VirtualFileStorageManager:
"""Singleton manager for virtual file storage access."""

Expand All @@ -269,6 +528,13 @@ def storage(self) -> VirtualFileStorage | None:
def storage(self, value: VirtualFileStorage | None) -> None:
self._storage = value

@staticmethod
def _cross_process_storage() -> VirtualFileStorage:
"""Build an ad-hoc reader that can locate files in either tier
written by the kernel subprocess in EDIT mode.
"""
return FallbackStorage([SharedMemoryStorage(), DiskStorage()])

def read(self, filename: str, byte_length: int) -> bytes:
"""Read from storage, with cross-process fallback for EDIT mode server.

Expand All @@ -279,8 +545,9 @@ def read(self, filename: str, byte_length: int) -> bytes:
storage = self.storage
if storage is None:
# Never initialized so in a separate thread from the kernel.
# Use SharedMemoryStorage to read by name across processes
return SharedMemoryStorage().read(filename, byte_length)
# Probe shared memory first, then disk, to locate files written
# by the kernel subprocess.
return self._cross_process_storage().read(filename, byte_length)
return storage.read(filename, byte_length)

def read_chunked(
Expand All @@ -300,7 +567,7 @@ def read_chunked(
"""
storage = self.storage
if storage is None:
yield from SharedMemoryStorage().read_chunked(
yield from self._cross_process_storage().read_chunked(
filename, byte_length, chunk_size
)
else:
Expand Down
Loading
Loading