diff --git a/marimo/_runtime/context/kernel_context.py b/marimo/_runtime/context/kernel_context.py index a4e7ba3dd51..adcef8b8c26 100644 --- a/marimo/_runtime/context/kernel_context.py +++ b/marimo/_runtime/context/kernel_context.py @@ -152,6 +152,8 @@ def create_kernel_context( from marimo._plugins.ui._core.registry import UIElementRegistry from marimo._runtime.state import StateRegistry from marimo._runtime.virtual_file import ( + DiskStorage, + FallbackStorage, InMemoryStorage, SharedMemoryStorage, VirtualFileRegistry, @@ -163,8 +165,11 @@ def create_kernel_context( # Storage is chosen explicitly by the caller. None means virtual files # are not supported; we still construct an (inert) InMemoryStorage so # the registry has a backend, but ctx.virtual_files_supported is False. + # In EDIT mode, fall back to disk if shared memory cannot allocate + # (e.g., /dev/shm is full). Disk is cross-process readable, so the + # main-process server can still serve files via the /@file endpoint. storage: VirtualFileStorage = ( - SharedMemoryStorage() + FallbackStorage([SharedMemoryStorage(), DiskStorage()]) if virtual_file_storage == "shared_memory" else InMemoryStorage() ) diff --git a/marimo/_runtime/virtual_file/__init__.py b/marimo/_runtime/virtual_file/__init__.py index ecf2974ce52..a9537cf4c7e 100644 --- a/marimo/_runtime/virtual_file/__init__.py +++ b/marimo/_runtime/virtual_file/__init__.py @@ -8,6 +8,8 @@ from __future__ import annotations from marimo._runtime.virtual_file.storage import ( + DiskStorage, + FallbackStorage, InMemoryStorage, SharedMemoryStorage, VirtualFileStorage, @@ -31,6 +33,8 @@ "VirtualFileStorageType", "SharedMemoryStorage", "InMemoryStorage", + "DiskStorage", + "FallbackStorage", "VirtualFileStorageManager", # Virtual files "VirtualFile", diff --git a/marimo/_runtime/virtual_file/storage.py b/marimo/_runtime/virtual_file/storage.py index 0a9143a9211..7e674799e13 100644 --- a/marimo/_runtime/virtual_file/storage.py +++ 
b/marimo/_runtime/virtual_file/storage.py @@ -1,14 +1,20 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations +import os import sys -from typing import TYPE_CHECKING, Literal, Protocol +import tempfile +from pathlib import Path +from typing import IO, TYPE_CHECKING, Literal, Protocol +from marimo import _loggers from marimo._utils.platform import is_pyodide if TYPE_CHECKING: from collections.abc import Iterable, Iterator +LOGGER = _loggers.marimo_logger() + DEFAULT_CHUNK_SIZE = 256 * 1024 # 256KB VirtualFileStorageType = Literal["in_memory", "shared_memory"] @@ -155,8 +161,13 @@ def read_chunked( view = None try: shm = shared_memory.SharedMemory(name=key) + # Slice clamps to actual segment size; iterate over the + # clamped length, not the requested byte_length, so a huge + # byte_length from the URL doesn't trigger a long sequence + # of empty yields after the segment is exhausted. view = shm.buf[:byte_length] - for i in range(0, byte_length, chunk_size): + actual_length = len(view) + for i in range(0, actual_length, chunk_size): yield bytes(view[i : i + chunk_size]) except FileNotFoundError as err: raise KeyError(f"Virtual file not found: {key}") from err @@ -196,7 +207,22 @@ def shutdown(self, keys: Iterable[str] | None = None) -> None: self._stale = True def has(self, key: str) -> bool: - return key in self._storage + if key in self._storage: + return True + # Cross-process probe: try to open the segment by name. Used by + # FallbackStorage to locate keys written by another process. + # Catch OSError (parent of FileNotFoundError, plus errno variants + # like ENOENT/EACCES on hostile keys) and ValueError (invalid name + # characters on some platforms) so probing falls through to the + # next backend instead of crashing. 
+ if is_pyodide(): + return False + try: + shm = shared_memory.SharedMemory(name=key) + except (OSError, ValueError): + return False + shm.close() + return True class InMemoryStorage(VirtualFileStorage): @@ -248,6 +274,239 @@ def has(self, key: str) -> bool: return key in self._storage +class DiskStorage(VirtualFileStorage): + """Storage backend that writes virtual file bytes to a temp directory. + + Cross-process safe: any process can read a key by computing the same path. + Used as a fallback tier when ``SharedMemoryStorage`` cannot allocate + (e.g., ``/dev/shm`` is full). + """ + + def __init__(self, base_dir: str | Path | None = None) -> None: + self._base_dir = ( + Path(base_dir) + if base_dir is not None + else Path(tempfile.gettempdir()) / "marimo-vfs" + ) + # Refuse a base dir that already exists as a symlink. Defends + # against the classic shared-/tmp pre-creation attack where a + # co-tenant pre-creates `/tmp/marimo-vfs` as a symlink to + # somewhere they control, redirecting all our writes. + if self._base_dir.is_symlink(): + raise OSError( + f"Refusing to use {self._base_dir}: path is a symlink" + ) + self._owned_keys: set[str] = set() + self._stale = False + + @property + def stale(self) -> bool: + return self._stale + + def _path(self, key: str) -> Path: + # Reject keys that would escape base_dir (path separators, traversal + # components, null bytes). Virtual file keys are generated by + # `random_filename` as a single filename component, so any key with + # path-like structure is either a bug or a malicious request via + # the `/@file/{...:path}` endpoint. Treating these as missing keeps + # the disk fallback safe even if the read endpoint forwards the + # raw URL segment. 
+ if ( + not key + or "/" in key + or "\\" in key + or "\x00" in key + or key in (".", "..") + ): + raise KeyError(f"Invalid virtual file key: {key!r}") + return self._base_dir / key + + def _open(self, key: str) -> IO[bytes]: + try: + return self._path(key).open("rb") + except FileNotFoundError as err: + raise KeyError(f"Virtual file not found: {key}") from err + + def store(self, key: str, buffer: bytes) -> None: + # Validate up front: rejects path-traversal keys before mkstemp + # uses `key` as a filename prefix. + target = self._path(key) + self._base_dir.mkdir(parents=True, exist_ok=True) + # Atomic write: write to a unique temp sibling then rename. Prevents + # cross-process readers from seeing a half-written file; mkstemp's + # uniqueness avoids collisions when two stores race on one key. + fd, tmp_str = tempfile.mkstemp( + prefix=f"{key}.", suffix=".tmp", dir=self._base_dir + ) + tmp = Path(tmp_str) + try: + with os.fdopen(fd, "wb") as f: + f.write(buffer) + tmp.replace(target) + except BaseException: + tmp.unlink(missing_ok=True) + raise + self._owned_keys.add(key) + + def read(self, key: str, byte_length: int) -> bytes: + with self._open(key) as f: + return f.read(byte_length) + + def read_chunked( + self, + key: str, + byte_length: int, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> Iterator[bytes]: + with self._open(key) as f: + remaining = byte_length + while remaining > 0: + chunk = f.read(min(chunk_size, remaining)) + if not chunk: + break + remaining -= len(chunk) + yield chunk + + def remove(self, key: str) -> None: + self._path(key).unlink(missing_ok=True) + self._owned_keys.discard(key) + + def shutdown(self, keys: Iterable[str] | None = None) -> None: + # When keys is None, only remove keys this instance wrote — avoids + # racing other kernels that share the temp directory. 
+ target = list(keys) if keys is not None else list(self._owned_keys) + for key in target: + self.remove(key) + if keys is None: + self._owned_keys.clear() + self._stale = True + + def has(self, key: str) -> bool: + # Invalid keys never exist — return False rather than raising so + # FallbackStorage probing handles hostile keys gracefully. + try: + return self._path(key).exists() + except KeyError: + return False + + +class FallbackStorage(VirtualFileStorage): + """Composite storage that tries each backend in order. + + On ``store``, attempts each backend in order and falls back to the next + when the current one raises ``OSError`` (e.g., shared memory exhaustion). + Reads, removes, and existence checks are routed to the backend that + accepted the original ``store``; if no routing is recorded (e.g., a fresh + instance used only for cross-process reads), each backend is probed in + order via ``has``. + """ + + def __init__(self, backends: list[VirtualFileStorage]) -> None: + if not backends: + raise ValueError("FallbackStorage requires at least one backend") + self._backends = backends + self._routing: dict[str, int] = {} + + @property + def stale(self) -> bool: + return all(b.stale for b in self._backends) + + def store(self, key: str, buffer: bytes) -> None: + if key in self._routing: + return + last_err: OSError | None = None + for i, backend in enumerate(self._backends): + if backend.stale: + continue + try: + backend.store(key, buffer) + except OSError as err: + last_err = err + LOGGER.warning( + "Virtual file storage backend %d (%s) failed to store " + "key %r (%d bytes): %s; trying next backend", + i, + type(backend).__name__, + key, + len(buffer), + err, + ) + continue + self._routing[key] = i + return + if last_err is not None: + raise last_err + # All backends were stale (no store was attempted). 
+ raise OSError( + "All virtual file storage backends are stale; cannot store " + f"key {key!r}" + ) + + def _backend_for(self, key: str) -> VirtualFileStorage | None: + idx = self._routing.get(key) + if idx is not None: + return self._backends[idx] + for backend in self._backends: + if backend.has(key): + return backend + return None + + def read(self, key: str, byte_length: int) -> bytes: + backend = self._backend_for(key) + if backend is None: + raise KeyError(f"Virtual file not found: {key}") + return backend.read(key, byte_length) + + def read_chunked( + self, + key: str, + byte_length: int, + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> Iterator[bytes]: + backend = self._backend_for(key) + if backend is None: + raise KeyError(f"Virtual file not found: {key}") + yield from backend.read_chunked(key, byte_length, chunk_size) + + def remove(self, key: str) -> None: + idx = self._routing.pop(key, None) + if idx is not None: + self._backends[idx].remove(key) + return + for backend in self._backends: + backend.remove(key) + + def shutdown(self, keys: Iterable[str] | None = None) -> None: + # Forward to every backend, isolating failures: one backend's + # cleanup error must not prevent the others from releasing + # resources (shared memory segments, disk files). 
+ key_list = list(keys) if keys is not None else None + for i, backend in enumerate(self._backends): + try: + if key_list is None: + backend.shutdown() + else: + backend.shutdown(keys=key_list) + except Exception as err: + LOGGER.warning( + "Virtual file storage backend %d (%s) failed during " + "shutdown: %s", + i, + type(backend).__name__, + err, + ) + if key_list is None: + self._routing.clear() + else: + for key in key_list: + self._routing.pop(key, None) + + def has(self, key: str) -> bool: + if key in self._routing: + return self._backends[self._routing[key]].has(key) + return any(b.has(key) for b in self._backends) + + class VirtualFileStorageManager: """Singleton manager for virtual file storage access.""" @@ -269,6 +528,13 @@ def storage(self) -> VirtualFileStorage | None: def storage(self, value: VirtualFileStorage | None) -> None: self._storage = value + @staticmethod + def _cross_process_storage() -> VirtualFileStorage: + """Build an ad-hoc reader that can locate files in either tier + written by the kernel subprocess in EDIT mode. + """ + return FallbackStorage([SharedMemoryStorage(), DiskStorage()]) + def read(self, filename: str, byte_length: int) -> bytes: """Read from storage, with cross-process fallback for EDIT mode server. @@ -279,8 +545,9 @@ def read(self, filename: str, byte_length: int) -> bytes: storage = self.storage if storage is None: # Never initialized so in a separate thread from the kernel. - # Use SharedMemoryStorage to read by name across processes - return SharedMemoryStorage().read(filename, byte_length) + # Probe shared memory first, then disk, to locate files written + # by the kernel subprocess. 
+ return self._cross_process_storage().read(filename, byte_length) return storage.read(filename, byte_length) def read_chunked( @@ -300,7 +567,7 @@ def read_chunked( """ storage = self.storage if storage is None: - yield from SharedMemoryStorage().read_chunked( + yield from self._cross_process_storage().read_chunked( filename, byte_length, chunk_size ) else: diff --git a/marimo/_runtime/virtual_file/virtual_file.py b/marimo/_runtime/virtual_file/virtual_file.py index 3f7413792df..8ed4697e002 100644 --- a/marimo/_runtime/virtual_file/virtual_file.py +++ b/marimo/_runtime/virtual_file/virtual_file.py @@ -5,6 +5,7 @@ import dataclasses import mimetypes import random +import re import string import threading from typing import TYPE_CHECKING, cast @@ -32,6 +33,21 @@ _ALPHABET = string.ascii_letters + string.digits +# Matches the shape produced by `random_filename`: +# {digits}-{8 alphanumerics}.{ext} +# The extension is restricted to a conservative set (alphanumeric plus +# `._-`) and capped in length so that an attacker who controls the +# `/@file/{...:path}` URL segment cannot smuggle through control +# characters, path separators, or a filename matching some other +# process's POSIX shared-memory segment name. 
+_VALID_VFILE_NAME = re.compile( + r"^[0-9]+-[A-Za-z0-9]{8}\.[A-Za-z0-9._-]{1,32}$" +) + + +def _is_valid_vfile_name(filename: str) -> bool: + return bool(_VALID_VFILE_NAME.fullmatch(filename)) + def random_filename(ext: str) -> str: # adapted from: https://stackoverflow.com/questions/13484726/safe-enough-8-character-short-unique-random-string @@ -300,6 +316,8 @@ def _without_leading_dot(ext: str) -> str: def read_virtual_file(filename: str, byte_length: int) -> bytes: + if not _is_valid_vfile_name(filename): + raise HTTPException(HTTPStatus.NOT_FOUND, detail="File not found") try: return VirtualFileStorageManager().read(filename, byte_length) except KeyError as err: @@ -307,6 +325,15 @@ def read_virtual_file(filename: str, byte_length: int) -> bytes: HTTPStatus.NOT_FOUND, detail="File not found", ) from err + except OSError as err: + # Storage backend hit an I/O error (e.g. shared memory unlinked + # mid-read, disk read failed). Treat as not-found from the + # client's perspective rather than crashing the server worker. + LOGGER.warning("I/O error reading virtual file %s: %s", filename, err) + raise HTTPException( + HTTPStatus.NOT_FOUND, + detail="File not readable", + ) from err def read_virtual_file_chunked( @@ -316,7 +343,18 @@ def read_virtual_file_chunked( Yields chunks of bytes, avoiding holding the entire file in memory as a single bytes object. + + Validation happens here (before returning the generator) rather than + inside `_read_chunks` so that an invalid filename raises *before* + StreamingResponse starts sending headers — otherwise the framework + raises "Caught handled exception, but response already started". 
""" + if not _is_valid_vfile_name(filename): + raise HTTPException(HTTPStatus.NOT_FOUND, detail="File not found") + return _read_chunks(filename, byte_length) + + +def _read_chunks(filename: str, byte_length: int) -> Iterator[bytes]: try: yield from VirtualFileStorageManager().read_chunked( filename, byte_length @@ -326,3 +364,9 @@ def read_virtual_file_chunked( HTTPStatus.NOT_FOUND, detail="File not found", ) from err + except OSError as err: + LOGGER.warning("I/O error reading virtual file %s: %s", filename, err) + raise HTTPException( + HTTPStatus.NOT_FOUND, + detail="File not readable", + ) from err diff --git a/tests/_runtime/test_storage.py b/tests/_runtime/test_storage.py index 672f805e623..6a3f9b1aa5c 100644 --- a/tests/_runtime/test_storage.py +++ b/tests/_runtime/test_storage.py @@ -1,13 +1,22 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations +import sys + import pytest from marimo._runtime.virtual_file.storage import ( + DiskStorage, + FallbackStorage, InMemoryStorage, SharedMemoryStorage, VirtualFileStorageManager, ) +from marimo._runtime.virtual_file.virtual_file import ( + read_virtual_file, + read_virtual_file_chunked, +) +from marimo._utils.http import HTTPException class TestInMemoryStorageReadChunked: @@ -268,6 +277,20 @@ def test_read_chunked_cross_process(self) -> None: finally: storage1.shutdown() + @pytest.mark.parametrize( + "key", + # Keys the OS shared-memory namespace will reject with OSError or + # ValueError rather than FileNotFoundError. Covers the broadened + # exception catch in has(): probing must return False, not crash. 
+ ["", "/", "//", "with/slash", "x" * 4096], + ) + def test_has_returns_false_on_invalid_keys(self, key: str) -> None: + storage = SharedMemoryStorage() + try: + assert storage.has(key) is False + finally: + storage.shutdown() + def test_read_chunked_data_integrity(self) -> None: """Test that chunked read produces identical data to regular read.""" storage = SharedMemoryStorage() @@ -285,6 +308,494 @@ def test_read_chunked_data_integrity(self) -> None: storage.shutdown() +class TestDiskStorage: + @pytest.fixture + def storage(self, tmp_path) -> DiskStorage: + return DiskStorage(base_dir=tmp_path) + + def test_store_and_read(self, storage: DiskStorage) -> None: + storage.store("k", b"hello world") + assert storage.read("k", 11) == b"hello world" + + def test_read_with_byte_length(self, storage: DiskStorage) -> None: + storage.store("k", b"hello world") + assert storage.read("k", 5) == b"hello" + + def test_read_nonexistent_raises_keyerror( + self, storage: DiskStorage + ) -> None: + with pytest.raises(KeyError, match="Virtual file not found"): + storage.read("missing", 10) + + def test_read_chunked(self, storage: DiskStorage) -> None: + data = b"a" * 100 + storage.store("k", data) + chunks = list(storage.read_chunked("k", 100, chunk_size=30)) + assert b"".join(chunks) == data + assert len(chunks) == 4 + + def test_read_chunked_with_byte_length(self, storage: DiskStorage) -> None: + storage.store("k", b"hello world") + chunks = list(storage.read_chunked("k", 5)) + assert b"".join(chunks) == b"hello" + + def test_read_chunked_nonexistent_raises_keyerror( + self, storage: DiskStorage + ) -> None: + with pytest.raises(KeyError, match="Virtual file not found"): + list(storage.read_chunked("missing", 10)) + + def test_remove(self, storage: DiskStorage) -> None: + storage.store("k", b"hi") + assert storage.has("k") + storage.remove("k") + assert not storage.has("k") + + def test_remove_nonexistent_no_error(self, storage: DiskStorage) -> None: + storage.remove("missing") + 
+ def test_has(self, storage: DiskStorage) -> None: + assert not storage.has("k") + storage.store("k", b"hi") + assert storage.has("k") + + def test_shutdown_only_removes_owned_keys(self, tmp_path) -> None: + # Two instances share a directory; shutdown of one must not nuke the + # other's files (mirrors the cross-process kernel/server scenario). + a = DiskStorage(base_dir=tmp_path) + b = DiskStorage(base_dir=tmp_path) + a.store("ak", b"a-data") + b.store("bk", b"b-data") + a.shutdown() + assert not (tmp_path / "ak").exists() + assert (tmp_path / "bk").exists() + b.shutdown() + + def test_shutdown_with_keys(self, storage: DiskStorage) -> None: + storage.store("k1", b"d1") + storage.store("k2", b"d2") + storage.shutdown(keys=["k1"]) + assert not storage.has("k1") + assert storage.has("k2") + assert not storage.stale # only full shutdown sets stale + + def test_shutdown_full_marks_stale(self, storage: DiskStorage) -> None: + assert not storage.stale + storage.shutdown() + assert storage.stale + + def test_cross_process_read_via_path(self, tmp_path) -> None: + # A second instance pointing at the same dir can read what the first + # wrote — the foundation of cross-process serving. + writer = DiskStorage(base_dir=tmp_path) + writer.store("shared", b"shared bytes") + reader = DiskStorage(base_dir=tmp_path) + assert reader.read("shared", 12) == b"shared bytes" + writer.shutdown() + + @pytest.mark.parametrize( + "key", + [ + "../escape", + "../../etc/passwd", + "foo/bar", + "foo\\bar", + "..", + ".", + "", + "with\x00null", + ], + ) + def test_rejects_path_traversal_keys( + self, storage: DiskStorage, key: str + ) -> None: + # The /@file/{...:path} HTTP route forwards raw URL segments as + # keys. DiskStorage must reject path-traversal keys before + # touching the filesystem so an attacker can't escape base_dir. 
+ with pytest.raises(KeyError, match="Invalid virtual file key"): + storage.store(key, b"data") + with pytest.raises(KeyError, match="Invalid virtual file key"): + storage.read(key, 10) + with pytest.raises(KeyError, match="Invalid virtual file key"): + list(storage.read_chunked(key, 10)) + with pytest.raises(KeyError, match="Invalid virtual file key"): + storage.remove(key) + # has() is a query and must return False (not raise) so that + # FallbackStorage probing doesn't crash on hostile keys. + assert storage.has(key) is False + + def test_traversal_does_not_create_outside_base_dir( + self, tmp_path + ) -> None: + outside = tmp_path / "outside" + outside.mkdir() + base = tmp_path / "base" + storage = DiskStorage(base_dir=base) + with pytest.raises(KeyError): + storage.store("../outside/leak", b"secret") + assert not (outside / "leak").exists() + + +class _RaisingStorage(InMemoryStorage): + """Test double: raises OSError on store, otherwise behaves normally.""" + + def __init__(self, errno: int = 28) -> None: + super().__init__() + self._errno = errno + self.store_calls = 0 + + def store(self, key: str, buffer: bytes) -> None: # noqa: ARG002 + self.store_calls += 1 + raise OSError(self._errno, "No space left on device") + + +class TestFallbackStorage: + def test_requires_at_least_one_backend(self) -> None: + with pytest.raises(ValueError, match="at least one"): + FallbackStorage([]) + + def test_uses_first_backend_when_healthy(self) -> None: + primary = InMemoryStorage() + secondary = InMemoryStorage() + fb = FallbackStorage([primary, secondary]) + fb.store("k", b"data") + assert primary.has("k") + assert not secondary.has("k") + assert fb.read("k", 4) == b"data" + + def test_falls_back_on_oserror(self) -> None: + primary = _RaisingStorage() + secondary = InMemoryStorage() + fb = FallbackStorage([primary, secondary]) + fb.store("k", b"data") + assert primary.store_calls == 1 + assert secondary.has("k") + assert fb.read("k", 4) == b"data" + + def 
test_reraises_when_all_backends_fail(self) -> None: + a = _RaisingStorage(errno=28) + b = _RaisingStorage(errno=12) + fb = FallbackStorage([a, b]) + with pytest.raises(OSError) as exc: + fb.store("k", b"data") + # Last error is propagated + assert exc.value.errno == 12 + + def test_raises_oserror_when_all_backends_stale(self, tmp_path) -> None: + # If every backend is stale, no store is attempted and last_err + # stays None — must raise an explicit OSError, not an + # AssertionError (or TypeError under `python -O`). + a = DiskStorage(base_dir=tmp_path / "a") + a.shutdown() + b = DiskStorage(base_dir=tmp_path / "b") + b.shutdown() + fb = FallbackStorage([a, b]) + with pytest.raises(OSError, match="stale"): + fb.store("k", b"data") + + def test_routing_directs_reads_and_removes(self) -> None: + primary = _RaisingStorage() + secondary = InMemoryStorage() + fb = FallbackStorage([primary, secondary]) + fb.store("k", b"hello") + assert fb.has("k") + assert b"".join(fb.read_chunked("k", 5, chunk_size=2)) == b"hello" + fb.remove("k") + assert not fb.has("k") + assert not secondary.has("k") + + def test_probe_path_for_unrouted_keys(self, tmp_path) -> None: + # Mirrors the cross-process reader scenario: write via one instance, + # read via a fresh FallbackStorage that has no routing entry. 
+ writer = DiskStorage(base_dir=tmp_path) + writer.store("shared", b"shared bytes") + reader = FallbackStorage( + [InMemoryStorage(), DiskStorage(base_dir=tmp_path)] + ) + assert reader.has("shared") + assert reader.read("shared", 12) == b"shared bytes" + chunks = list(reader.read_chunked("shared", 12, chunk_size=4)) + assert b"".join(chunks) == b"shared bytes" + writer.shutdown() + + def test_read_missing_raises_keyerror(self) -> None: + fb = FallbackStorage([InMemoryStorage(), InMemoryStorage()]) + with pytest.raises(KeyError, match="Virtual file not found"): + fb.read("missing", 10) + with pytest.raises(KeyError, match="Virtual file not found"): + list(fb.read_chunked("missing", 10)) + + def test_skips_stale_backends(self, tmp_path) -> None: + stale_disk = DiskStorage(base_dir=tmp_path) + stale_disk.shutdown() # marks _stale = True + assert stale_disk.stale + ok = InMemoryStorage() + fb = FallbackStorage([stale_disk, ok]) + fb.store("k", b"hi") + assert ok.has("k") + + def test_shutdown_propagates_to_all_backends(self) -> None: + a = InMemoryStorage() + b = InMemoryStorage() + fb = FallbackStorage([a, b]) + fb.store("k1", b"d1") + # Manually populate b to simulate independent state + b.store("k2", b"d2") + fb.shutdown() + assert not a.has("k1") + assert not b.has("k2") + + def test_shutdown_with_keys_propagates(self) -> None: + a = InMemoryStorage() + b = InMemoryStorage() + fb = FallbackStorage([a, b]) + fb.store("k1", b"d1") + b.store("k2", b"d2") + fb.shutdown(keys=["k1", "k2"]) + assert not a.has("k1") + assert not b.has("k2") + + def test_stale_iff_all_backends_stale(self, tmp_path) -> None: + a = DiskStorage(base_dir=str(tmp_path / "a")) + b = DiskStorage(base_dir=str(tmp_path / "b")) + fb = FallbackStorage([a, b]) + assert not fb.stale + a.shutdown() + assert not fb.stale + b.shutdown() + assert fb.stale + + +class _FailingReadStorage(InMemoryStorage): + """Stores normally but raises on read — simulates a storage segment that + was successfully created 
but becomes unreadable later (corrupted shared + memory, disk I/O error mid-stream, etc.). + """ + + def __init__(self, exc: BaseException) -> None: + super().__init__() + self._exc = exc + + def read(self, key: str, byte_length: int) -> bytes: # noqa: ARG002 + raise self._exc + + def read_chunked(self, key, byte_length, chunk_size=...): # type: ignore[override] # noqa: ARG002 + raise self._exc + + def has(self, key: str) -> bool: # noqa: ARG002 + return True + + +class _FailingShutdownStorage(InMemoryStorage): + """Raises on shutdown. Used to verify that one bad backend doesn't + block cleanup of the others. + """ + + def shutdown(self, keys=None) -> None: # noqa: ARG002 + raise OSError(5, "I/O error during shutdown") + + +class TestKernelResilience: + """Failure-mode tests: confirm that storage errors don't crash the + kernel or the HTTP server beyond the affected request/cell. + """ + + def test_read_oserror_returns_404_not_500(self) -> None: + """An OSError raised during read should surface as a 404 + HTTPException (treated as 'not found' from the client's POV) rather + than propagating to the server worker as a 500. Otherwise a single + flaky read crashes the streaming response. 
+ """ + manager = VirtualFileStorageManager() + original = manager.storage + try: + manager.storage = _FailingReadStorage( + OSError(5, "I/O error reading shared memory") + ) + with pytest.raises(HTTPException) as exc: + read_virtual_file("12345-AbCdEfGh.bin", 10) + assert exc.value.status_code == 404 + finally: + manager.storage = original + + def test_read_chunked_oserror_returns_404_not_500(self) -> None: + manager = VirtualFileStorageManager() + original = manager.storage + try: + manager.storage = _FailingReadStorage( + OSError(5, "I/O error reading shared memory") + ) + with pytest.raises(HTTPException) as exc: + list(read_virtual_file_chunked("12345-AbCdEfGh.bin", 10)) + assert exc.value.status_code == 404 + finally: + manager.storage = original + + def test_fallback_shutdown_tolerates_per_backend_failures(self) -> None: + """If one backend's shutdown raises, the others must still get + their shutdown called. Otherwise a transient cleanup error during + session end leaks shared memory and disk files. + """ + bad = _FailingShutdownStorage() + good = InMemoryStorage() + good.store("k", b"data") + fb = FallbackStorage([bad, good]) + # Should not raise; the bad backend's failure should be logged and + # the good backend should still be shut down. + fb.shutdown() + assert not good.has("k"), ( + "good backend was not shut down because bad backend raised" + ) + + def test_fallback_shutdown_with_keys_tolerates_failures(self) -> None: + bad = _FailingShutdownStorage() + good = InMemoryStorage() + good.store("k", b"data") + fb = FallbackStorage([bad, good]) + fb.shutdown(keys=["k"]) + assert not good.has("k") + + def test_read_endpoint_blocks_arbitrary_shared_memory_access( + self, + ) -> None: + """Demonstrates the cross-process SHM read attack: an authenticated + client requests /@file/{byte_length}-{segment_name}; without + filename validation, the server's cross-process fallback opens + that segment by name and serves its contents. 
With validation, + non-marimo-shaped filenames are rejected at the boundary. + """ + from multiprocessing import shared_memory + + # Simulate another process's shared memory segment. + name = "not_a_marimo_virtual_file_name" + shm = shared_memory.SharedMemory(name=name, create=True, size=64) + shm.buf[:6] = b"secret" + try: + manager = VirtualFileStorageManager() + original = manager.storage + manager.storage = None # force cross-process probe path + try: + with pytest.raises(HTTPException) as exc: + read_virtual_file(name, 6) + assert exc.value.status_code == 404 + with pytest.raises(HTTPException) as exc: + list(read_virtual_file_chunked(name, 6)) + assert exc.value.status_code == 404 + finally: + manager.storage = original + finally: + shm.close() + shm.unlink() + + @pytest.mark.parametrize( + "filename", + [ + "/etc/passwd", + "../../etc/passwd", + "with space.png", + "no-extension", + "..", + "", + "12345-abcdefgh.\npng", # control char + ], + ) + def test_read_endpoint_rejects_unmarimo_filenames( + self, filename: str + ) -> None: + with pytest.raises(HTTPException) as exc: + read_virtual_file(filename, 10) + assert exc.value.status_code == 404 + with pytest.raises(HTTPException) as exc: + list(read_virtual_file_chunked(filename, 10)) + assert exc.value.status_code == 404 + + def test_read_endpoint_accepts_marimo_filenames(self) -> None: + """Sanity: legitimate random_filename outputs pass validation + (then KeyError-out as 'not found' since nothing is stored). + Confirms validation isn't over-restrictive. 
+ """ + manager = VirtualFileStorageManager() + original = manager.storage + try: + manager.storage = InMemoryStorage() + with pytest.raises(HTTPException) as exc: + read_virtual_file("12345-AbCdEf12.png", 10) + assert exc.value.status_code == 404 + assert exc.value.detail == "File not found" + finally: + manager.storage = original + + def test_shared_memory_chunked_huge_byte_length_terminates(self) -> None: + """SharedMemoryStorage.read_chunked must bound iterations by + the actual segment size, not the URL-supplied byte_length. + Otherwise an attacker requesting /@file/{huge}-{key} causes the + loop to yield (huge / chunk_size) chunks — a DoS that can keep + a worker busy for arbitrarily long. + + Note: shm segments are page-allocated, so a 2-byte store actually + backs onto ~one page; we just need to confirm the response stays + bounded near the page size, not anywhere near byte_length. + """ + storage = SharedMemoryStorage() + try: + storage.store("marimo_dos_test", b"hi") + chunks = list( + storage.read_chunked( + "marimo_dos_test", + byte_length=10**10, + chunk_size=1024, + ) + ) + total = sum(len(c) for c in chunks) + # A 2-byte store should never expand into megabytes of + # response just because the URL asked for a huge length. + assert total < 10**6, ( + f"DoS regression: returned {total} bytes for a 2-byte file" + ) + assert len(chunks) > 0 + assert chunks[0].startswith(b"hi") + finally: + storage.shutdown() + + def test_disk_storage_rejects_symlinked_base_dir(self, tmp_path) -> None: + """DiskStorage must refuse to use a base_dir that exists as + a symlink — defends against the classic /tmp pre-creation attack + where a co-tenant pre-creates `/tmp/marimo-vfs` as a symlink to + an attacker-controlled directory. 
+ """ + target = tmp_path / "real" + target.mkdir() + link = tmp_path / "link" + link.symlink_to(target) + with pytest.raises(OSError, match="symlink"): + DiskStorage(base_dir=link) + + @pytest.mark.skipif( + sys.platform == "win32", + reason="POSIX inode semantics; Windows refuses to unlink open files", + ) + def test_disk_storage_read_chunked_survives_mid_stream_unlink( + self, tmp_path + ) -> None: + """DiskStorage.read_chunked is a generator. On POSIX, if the file + is removed mid-stream, the open fd survives (inode lives until + the fd is closed) and remaining chunks still come from the + file's contents. Captures this behaviour so any change is + intentional. + """ + storage = DiskStorage(base_dir=tmp_path) + storage.store("k", b"x" * 1000) + gen = storage.read_chunked("k", 1000, chunk_size=100) + first = next(gen) + assert len(first) == 100 + # Delete the file out from under the generator. + (tmp_path / "k").unlink() + # Remaining chunks should still come from the open file handle + # (POSIX semantics: the inode lives until the fd is closed). + rest = b"".join(gen) + assert len(first) + len(rest) == 1000 + + class TestVirtualFileStorageManager: def test_singleton(self) -> None: manager1 = VirtualFileStorageManager() diff --git a/tests/_runtime/test_virtual_file.py b/tests/_runtime/test_virtual_file.py index f46a44613af..a5c8bbed28e 100644 --- a/tests/_runtime/test_virtual_file.py +++ b/tests/_runtime/test_virtual_file.py @@ -1,8 +1,6 @@ # Copyright 2026 Marimo. All rights reserved. 
from __future__ import annotations -import uuid - from marimo._runtime.commands import DeleteCellCommand from marimo._runtime.context import get_context from marimo._runtime.runtime import Kernel @@ -15,6 +13,7 @@ VirtualFile, VirtualFileLifecycleItem, VirtualFileRegistry, + random_filename, read_virtual_file, ) from tests.conftest import ExecReqProvider, MockedKernel @@ -346,8 +345,10 @@ def test_virtual_file_registry_shared_shared_memory_storage() -> None: """ manager = VirtualFileStorageManager() original_storage = manager.storage - key1 = f"{uuid.uuid4().hex[:8]}.txt" - key2 = f"{uuid.uuid4().hex[:8]}.txt" + # Use the real filename shape (random_filename output) so the + # validation in read_virtual_file accepts these keys. + key1 = random_filename("txt") + key2 = random_filename("txt") context = type("Context", (), {"virtual_files_supported": True})() registry1 = None diff --git a/tests/_server/api/endpoints/test_assets.py b/tests/_server/api/endpoints/test_assets.py index 4e48789eefa..ef947a0e688 100644 --- a/tests/_server/api/endpoints/test_assets.py +++ b/tests/_server/api/endpoints/test_assets.py @@ -317,9 +317,11 @@ def test_vfile_large_streaming(client: TestClient) -> None: manager.storage = storage try: - # ~2 MB file, similar to a large anywidget ESM bundle + # ~2 MB file, similar to a large anywidget ESM bundle. + # Filename matches the random_filename pattern enforced by + # read_virtual_file_chunked (digits-8 alphanumeric.ext). 
data = b"x" * (2 * 1024 * 1024) - filename = "test-large.js" + filename = "12345-AbCdEf12.js" storage.store(filename, data) byte_length = len(data) @@ -358,7 +360,7 @@ def test_vfile_download_query_param_sets_content_disposition( try: data = b'[{"a": 1}]' - filename = "data.json" + filename = "12345-xYzWaBcD.json" storage.store(filename, data) byte_length = len(data) @@ -379,7 +381,7 @@ def test_vfile_download_query_param_sets_content_disposition( assert response.content == data cd = response.headers.get("content-disposition", "") assert cd.startswith("attachment") - assert "data.json" in cd + assert filename in cd # Custom download filename via ?filename=... response = client.get(