Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion marimo/_runtime/context/kernel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def create_kernel_context(
from marimo._plugins.ui._core.registry import UIElementRegistry
from marimo._runtime.state import StateRegistry
from marimo._runtime.virtual_file import (
DiskStorage,
FallbackStorage,
InMemoryStorage,
SharedMemoryStorage,
VirtualFileRegistry,
Expand All @@ -163,8 +165,11 @@ def create_kernel_context(
# Storage is chosen explicitly by the caller. None means virtual files
# are not supported; we still construct an (inert) InMemoryStorage so
# the registry has a backend, but ctx.virtual_files_supported is False.
# In EDIT mode, fall back to disk if shared memory cannot allocate
# (e.g., /dev/shm is full). Disk is cross-process readable, so the
# main-process server can still serve files via the /@file endpoint.
storage: VirtualFileStorage = (
SharedMemoryStorage()
FallbackStorage([SharedMemoryStorage(), DiskStorage()])
if virtual_file_storage == "shared_memory"
else InMemoryStorage()
)
Expand Down
4 changes: 4 additions & 0 deletions marimo/_runtime/virtual_file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from __future__ import annotations

from marimo._runtime.virtual_file.storage import (
DiskStorage,
FallbackStorage,
InMemoryStorage,
SharedMemoryStorage,
VirtualFileStorage,
Expand All @@ -31,6 +33,8 @@
"VirtualFileStorageType",
"SharedMemoryStorage",
"InMemoryStorage",
"DiskStorage",
"FallbackStorage",
"VirtualFileStorageManager",
# Virtual files
"VirtualFile",
Expand Down
219 changes: 214 additions & 5 deletions marimo/_runtime/virtual_file/storage.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright 2026 Marimo. All rights reserved.
from __future__ import annotations

import os
import sys
from typing import TYPE_CHECKING, Literal, Protocol
import tempfile
from pathlib import Path
from typing import IO, TYPE_CHECKING, Literal, Protocol

from marimo import _loggers
from marimo._utils.platform import is_pyodide

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator

LOGGER = _loggers.marimo_logger()

DEFAULT_CHUNK_SIZE = 256 * 1024 # 256KB

VirtualFileStorageType = Literal["in_memory", "shared_memory"]
Expand Down Expand Up @@ -196,7 +202,18 @@ def shutdown(self, keys: Iterable[str] | None = None) -> None:
self._stale = True

def has(self, key: str) -> bool:
return key in self._storage
if key in self._storage:
return True
# Cross-process probe: try to open the segment by name. Used by
# FallbackStorage to locate keys written by another process.
if is_pyodide():
return False
try:
shm = shared_memory.SharedMemory(name=key)
except FileNotFoundError:
return False
shm.close()
return True


class InMemoryStorage(VirtualFileStorage):
Expand Down Expand Up @@ -248,6 +265,190 @@ def has(self, key: str) -> bool:
return key in self._storage


class DiskStorage(VirtualFileStorage):
"""Storage backend that writes virtual file bytes to a temp directory.

Cross-process safe: any process can read a key by computing the same path.
Used as a fallback tier when ``SharedMemoryStorage`` cannot allocate
(e.g., ``/dev/shm`` is full).
"""

def __init__(self, base_dir: str | Path | None = None) -> None:
self._base_dir = (
Path(base_dir)
if base_dir is not None
else Path(tempfile.gettempdir()) / "marimo-vfs"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be verifying that this gets mapped to tmpfs on molab, otherwise this will be very slow when actually used. Currently maps to /tmp (expected) but our mounts should this to be on disk

)
self._owned_keys: set[str] = set()
self._stale = False

@property
def stale(self) -> bool:
return self._stale

def _path(self, key: str) -> Path:
return self._base_dir / key
Comment thread
dmadisetti marked this conversation as resolved.

def _open(self, key: str) -> IO[bytes]:
try:
return self._path(key).open("rb")
except FileNotFoundError as err:
raise KeyError(f"Virtual file not found: {key}") from err

def store(self, key: str, buffer: bytes) -> None:
self._base_dir.mkdir(parents=True, exist_ok=True)
# Atomic write: write to a unique temp sibling then rename. Prevents
# cross-process readers from seeing a half-written file; mkstemp's
# uniqueness avoids collisions when two stores race on one key.
fd, tmp_str = tempfile.mkstemp(
prefix=f"{key}.", suffix=".tmp", dir=self._base_dir
)
tmp = Path(tmp_str)
try:
with os.fdopen(fd, "wb") as f:
f.write(buffer)
tmp.replace(self._path(key))
except BaseException:
tmp.unlink(missing_ok=True)
raise
self._owned_keys.add(key)

def read(self, key: str, byte_length: int) -> bytes:
with self._open(key) as f:
return f.read(byte_length)

def read_chunked(
self,
key: str,
byte_length: int,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> Iterator[bytes]:
with self._open(key) as f:
remaining = byte_length
while remaining > 0:
chunk = f.read(min(chunk_size, remaining))
if not chunk:
break
remaining -= len(chunk)
yield chunk

def remove(self, key: str) -> None:
self._path(key).unlink(missing_ok=True)
self._owned_keys.discard(key)

def shutdown(self, keys: Iterable[str] | None = None) -> None:
# When keys is None, only remove keys this instance wrote — avoids
# racing other kernels that share the temp directory.
target = list(keys) if keys is not None else list(self._owned_keys)
for key in target:
self.remove(key)
if keys is None:
self._owned_keys.clear()
self._stale = True

def has(self, key: str) -> bool:
return self._path(key).exists()


class FallbackStorage(VirtualFileStorage):
"""Composite storage that tries each backend in order.

On ``store``, attempts each backend in order and falls back to the next
when the current one raises ``OSError`` (e.g., shared memory exhaustion).
Reads, removes, and existence checks are routed to the backend that
accepted the original ``store``; if no routing is recorded (e.g., a fresh
instance used only for cross-process reads), each backend is probed in
order via ``has``.
"""

def __init__(self, backends: list[VirtualFileStorage]) -> None:
if not backends:
raise ValueError("FallbackStorage requires at least one backend")
self._backends = backends
self._routing: dict[str, int] = {}

@property
def stale(self) -> bool:
return all(b.stale for b in self._backends)

def store(self, key: str, buffer: bytes) -> None:
if key in self._routing:
return
last_err: OSError | None = None
for i, backend in enumerate(self._backends):
if backend.stale:
continue
try:
backend.store(key, buffer)
except OSError as err:
last_err = err
LOGGER.warning(
"Virtual file storage backend %d (%s) failed to store "
"key %r (%d bytes): %s; trying next backend",
i,
type(backend).__name__,
key,
len(buffer),
err,
)
continue
self._routing[key] = i
return
assert last_err is not None
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
Outdated
raise last_err

def _backend_for(self, key: str) -> VirtualFileStorage | None:
idx = self._routing.get(key)
if idx is not None:
return self._backends[idx]
for backend in self._backends:
if backend.has(key):
return backend
return None

def read(self, key: str, byte_length: int) -> bytes:
backend = self._backend_for(key)
if backend is None:
raise KeyError(f"Virtual file not found: {key}")
return backend.read(key, byte_length)

def read_chunked(
self,
key: str,
byte_length: int,
chunk_size: int = DEFAULT_CHUNK_SIZE,
) -> Iterator[bytes]:
backend = self._backend_for(key)
if backend is None:
raise KeyError(f"Virtual file not found: {key}")
yield from backend.read_chunked(key, byte_length, chunk_size)

def remove(self, key: str) -> None:
idx = self._routing.pop(key, None)
if idx is not None:
self._backends[idx].remove(key)
return
for backend in self._backends:
backend.remove(key)

def shutdown(self, keys: Iterable[str] | None = None) -> None:
if keys is None:
for backend in self._backends:
backend.shutdown()
self._routing.clear()
return
key_list = list(keys)
for backend in self._backends:
backend.shutdown(keys=key_list)
for key in key_list:
self._routing.pop(key, None)

def has(self, key: str) -> bool:
if key in self._routing:
return self._backends[self._routing[key]].has(key)
return any(b.has(key) for b in self._backends)


class VirtualFileStorageManager:
"""Singleton manager for virtual file storage access."""

Expand All @@ -269,6 +470,13 @@ def storage(self) -> VirtualFileStorage | None:
def storage(self, value: VirtualFileStorage | None) -> None:
self._storage = value

@staticmethod
def _cross_process_storage() -> VirtualFileStorage:
"""Build an ad-hoc reader that can locate files in either tier
written by the kernel subprocess in EDIT mode.
"""
return FallbackStorage([SharedMemoryStorage(), DiskStorage()])

def read(self, filename: str, byte_length: int) -> bytes:
"""Read from storage, with cross-process fallback for EDIT mode server.

Expand All @@ -279,8 +487,9 @@ def read(self, filename: str, byte_length: int) -> bytes:
storage = self.storage
if storage is None:
# Never initialized so in a separate thread from the kernel.
# Use SharedMemoryStorage to read by name across processes
return SharedMemoryStorage().read(filename, byte_length)
# Probe shared memory first, then disk, to locate files written
# by the kernel subprocess.
return self._cross_process_storage().read(filename, byte_length)
return storage.read(filename, byte_length)

def read_chunked(
Expand All @@ -300,7 +509,7 @@ def read_chunked(
"""
storage = self.storage
if storage is None:
yield from SharedMemoryStorage().read_chunked(
yield from self._cross_process_storage().read_chunked(
filename, byte_length, chunk_size
)
else:
Expand Down
Loading
Loading