Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backend/app/channels/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.config.paths import make_safe_user_id
from deerflow.runtime.user_context import get_effective_user_id

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -599,12 +600,20 @@ def _resolve_run_params(self, msg: InboundMessage, thread_id: str) -> tuple[str,
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id

# ``user_id`` drives user-scoped filesystem buckets that only accept
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
# under ``channel_user_id`` for platform-facing lookups.
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
if msg.user_id:
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
run_context_identity["channel_user_id"] = msg.user_id

run_context = _merge_dicts(
DEFAULT_RUN_CONTEXT,
self._default_session.get("context"),
channel_layer.get("context"),
user_layer.get("context"),
{"thread_id": thread_id},
run_context_identity,
)

# Custom agents are implemented as lead_agent + agent_name context.
Expand Down
3 changes: 2 additions & 1 deletion backend/app/gateway/internal_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

# Header name for the internal auth token (presumably the header that
# create_internal_auth_headers sets — usage is not visible here; confirm).
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
# Environment variable name for the internal auth token (presumably read by
# _load_internal_auth_token — confirm against its body).
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
# ``system_role`` value carried by the synthetic user from get_internal_user().
INTERNAL_SYSTEM_ROLE = "internal"


def _load_internal_auth_token() -> str:
Expand All @@ -34,4 +35,4 @@ def is_valid_internal_auth_token(token: str | None) -> bool:

def get_internal_user():
    """Return the synthetic user used for trusted internal channel calls.

    Returns:
        A ``SimpleNamespace`` with ``id`` set to ``DEFAULT_USER_ID`` and
        ``system_role`` set to ``INTERNAL_SYSTEM_ROLE``. The role comes from
        the shared constant (not a repeated literal) so code that compares
        against ``INTERNAL_SYSTEM_ROLE`` cannot drift from this value.
    """
    return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
15 changes: 14 additions & 1 deletion backend/app/gateway/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from langchain_core.messages.utils import convert_to_messages

from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config
from deerflow.runtime import (
Expand Down Expand Up @@ -140,7 +141,14 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
and ``config['context']`` so they are visible to legacy configurable readers and
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
see issue #2677)."""
see issue #2677).

``user_id`` is intentionally propagated into ``config['context']`` in addition to
the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in
``body.context`` keep it on ``ToolRuntime.context``. It is merged with
``setdefault`` so a server-authenticated id stamped by
:func:`inject_authenticated_user_context` always wins over the client-supplied one.
"""
if not context:
return
configurable = config.setdefault("configurable", {})
Expand All @@ -151,6 +159,8 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
configurable.setdefault(key, context[key])
if isinstance(runtime_context, dict):
runtime_context.setdefault(key, context[key])
if "user_id" in context and isinstance(runtime_context, dict):
runtime_context.setdefault("user_id", context["user_id"])
Comment on lines +162 to +163
Comment thread
zhongli-sz marked this conversation as resolved.


def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
Expand All @@ -166,6 +176,9 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request)
if user_id is None:
return

if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
return

runtime_context = config.setdefault("context", {})
if isinstance(runtime_context, dict):
runtime_context["user_id"] = str(user_id)
Expand Down
20 changes: 20 additions & 0 deletions backend/packages/harness/deerflow/config/paths.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import os
import re
import shutil
Expand All @@ -10,6 +11,8 @@

# Whole-string patterns: one or more ASCII letters, digits, ``_`` or ``-``.
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
# Matches a single character outside the safe user-id charset;
# make_safe_user_id() replaces every match with ``-``.
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
# Number of leading hex characters of the digest that make_safe_user_id()
# appends to ids it had to sanitize.
_SAFE_USER_ID_DIGEST_HEX_LEN = 16


def _default_local_base_dir() -> Path:
Expand All @@ -31,6 +34,23 @@ def _validate_user_id(user_id: str) -> str:
return user_id


def make_safe_user_id(raw: str) -> str:
    """Normalize an external identity into the user-id charset (``[A-Za-z0-9_-]``).

    IM channel ids (Feishu/Slack/Telegram) may contain characters that
    :func:`_validate_user_id` rejects. Already-safe ids pass through unchanged.
    Ids that need sanitizing have every unsafe character replaced by ``-`` and
    get a truncated SHA-1 digest of the raw value appended, so distinct unsafe
    inputs that sanitize to the same text still land in different buckets.

    The separation is best-effort rather than absolute: the digest is
    truncated, and an already-safe id that happens to equal another id's
    ``<sanitized>-<digest>`` form is returned unchanged.

    Args:
        raw: The external (channel-side) user identifier.

    Returns:
        ``raw`` itself when it only contains safe characters, otherwise the
        sanitized text followed by ``-`` and the digest prefix.

    Raises:
        ValueError: If ``raw`` is empty (or otherwise falsy).
    """
    if not raw:
        raise ValueError("user_id must be a non-empty string.")
    sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
    if sanitized == raw:
        return raw
    # The digest only disambiguates bucket names. Marking it as non-security
    # use keeps SHA-1 available on FIPS-restricted OpenSSL builds; the hex
    # output is identical to the plain ``hashlib.sha1(...)`` call.
    digest = hashlib.sha1(raw.encode("utf-8"), usedforsecurity=False).hexdigest()
    return f"{sanitized}-{digest[:_SAFE_USER_ID_DIGEST_HEX_LEN]}"


def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style.

Expand Down
11 changes: 10 additions & 1 deletion backend/packages/harness/deerflow/mcp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
from collections.abc import Mapping
from typing import Any

from langchain_core.tools import BaseTool, StructuredTool
Expand Down Expand Up @@ -137,7 +138,15 @@ async def call_with_persistent_session(
from langchain_mcp_adapters.interceptors import MCPToolCallRequest

async def base_handler(request: MCPToolCallRequest) -> Any:
    """Run the tool call on the pooled session, forwarding interceptor headers.

    Args:
        request: The (possibly interceptor-modified) MCP tool call request.

    Returns:
        Whatever ``session.call_tool`` returns for this request.
    """
    # Preserve interceptor-injected headers for stdio MCP calls by
    # forwarding them through MCP call meta.
    call_kwargs: dict[str, Any] = {}
    if request.headers:
        if isinstance(request.headers, Mapping):
            # Copy into a plain dict so the meta payload does not alias the
            # interceptor's own mapping object.
            call_kwargs["meta"] = {"headers": dict(request.headers)}
        else:
            logger.warning("Ignoring MCP interceptor headers with unsupported type: %s", type(request.headers).__name__)
    # With no usable headers this stays the plain two-argument call.
    return await session.call_tool(request.name, request.args, **call_kwargs)
Comment on lines +141 to +149

handler = base_handler
for interceptor in reversed(tool_interceptors):
Expand Down
45 changes: 45 additions & 0 deletions backend/tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,6 +1591,51 @@ async def capture(msg):
_run(go())


class TestResolveRunParamsUserId:
    """Regression for PR #3294: channel identity must reach ``run_context``
    while staying safe for user-scoped filesystem buckets.
    """

    def _manager(self, tmp_path):
        """Build a ``ChannelManager`` whose store file lives under ``tmp_path``.

        Using pytest's ``tmp_path`` fixture (instead of ``tempfile.mkdtemp()``)
        lets pytest clean the directory up rather than leaking one per test.
        """
        from app.channels.manager import ChannelManager

        bus = MessageBus()
        store = ChannelStore(path=tmp_path / "store.json")
        return ChannelManager(bus=bus, store=store)

    def test_safe_user_id_is_passed_through(self, tmp_path):
        """An id already inside the safe charset is used verbatim for both keys."""
        manager = self._manager(tmp_path)
        msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")

        _, _, run_context = manager._resolve_run_params(msg, "thread-1")

        assert run_context["user_id"] == "123456"
        assert run_context["channel_user_id"] == "123456"

    def test_unsafe_user_id_is_normalized_but_raw_preserved(self, tmp_path):
        """An unsafe id is normalized for ``user_id`` and kept raw in ``channel_user_id``."""
        from deerflow.config.paths import make_safe_user_id

        manager = self._manager(tmp_path)
        raw = "user@example.com"
        msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")

        _, _, run_context = manager._resolve_run_params(msg, "thread-1")

        assert run_context["user_id"] == make_safe_user_id(raw)
        assert run_context["user_id"] != raw
        assert run_context["channel_user_id"] == raw

    @pytest.mark.parametrize("raw_user_id", ["", None])
    def test_empty_or_none_user_id_is_not_injected(self, raw_user_id, tmp_path):
        """A missing or empty channel id adds neither identity key to the context."""
        manager = self._manager(tmp_path)
        msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")

        _, _, run_context = manager._resolve_run_params(msg, "thread-1")

        assert "user_id" not in run_context
        assert "channel_user_id" not in run_context


# ---------------------------------------------------------------------------
# ChannelService tests
# ---------------------------------------------------------------------------
Expand Down
43 changes: 43 additions & 0 deletions backend/tests/test_gateway_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,49 @@ def test_inject_authenticated_user_context_overrides_client_user_id():
assert config["context"]["user_id"] == "auth-user-42"


def test_merge_run_context_overrides_propagates_user_id():
    """Regression for PR #3294: a ``user_id`` supplied through ``body.context``
    has to end up in ``config['context']`` so that non-web callers (such as IM
    channels) retain their identity on ``ToolRuntime.context``.
    """
    from app.gateway.services import build_run_config, merge_run_context_overrides

    channel_context = {"user_id": "channel-user-7"}
    run_config = build_run_config("thread-1", None, None)

    merge_run_context_overrides(run_config, channel_context)

    propagated_user_id = run_config["context"]["user_id"]
    assert propagated_user_id == "channel-user-7"


def test_merge_run_context_overrides_does_not_clobber_existing_user_id():
    """An authenticated ``context.user_id`` that is already stamped on the
    config must survive ``merge_run_context_overrides``; the client-supplied
    value is not allowed to replace it.
    """
    from app.gateway.services import build_run_config, merge_run_context_overrides

    stamped_request = {"context": {"user_id": "auth-user-42"}}
    client_context = {"user_id": "spoofed-client"}
    run_config = build_run_config("thread-1", stamped_request, None)

    merge_run_context_overrides(run_config, client_context)

    effective_user_id = run_config["context"]["user_id"]
    assert effective_user_id == "auth-user-42"


def test_inject_authenticated_user_context_skips_internal_role():
    """Regression for PR #3294: internal system-role callers must not overwrite an
    already-present ``context.user_id`` (e.g. a channel-supplied identity), so the
    real end user keeps owning the per-user storage bucket.
    """
    from types import SimpleNamespace

    # Use the shared role constant rather than repeating the literal, so this
    # test exercises the same value the production check compares against.
    from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
    from app.gateway.services import build_run_config, inject_authenticated_user_context

    config = build_run_config("thread-1", None, None)
    config["context"] = {"user_id": "channel-user-7"}
    internal_user = SimpleNamespace(id="internal-bot", system_role=INTERNAL_SYSTEM_ROLE)
    request = SimpleNamespace(state=SimpleNamespace(user=internal_user))

    inject_authenticated_user_context(config, request)

    assert config["context"]["user_id"] == "channel-user-7"


# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
Expand Down
130 changes: 130 additions & 0 deletions backend/tests/test_mcp_session_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,136 @@ class Args(BaseModel):
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})


@pytest.mark.asyncio
async def test_session_pool_tool_forwards_interceptor_headers():
    """Regression for PR #3294: headers placed on the request by an interceptor
    must reach the pooled stdio call as ``meta={"headers": ...}`` so that
    downstream MCP servers can read auth/context headers.
    """
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    async def header_interceptor(request, handler):
        # Attach a header the pooled call is expected to forward.
        with_headers = request.override(headers={"X-User-Id": "u-42"})
        return await handler(with_headers)

    # Fake MCP session plus the async context manager that yields it.
    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    session = AsyncMock()
    session.call_tool = AsyncMock(return_value=call_result)
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    base_tool = StructuredTool(
        name="srv_act",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )
    stdio_connection = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        pooled_tool = _make_session_pool_tool(
            base_tool,
            "srv",
            stdio_connection,
            tool_interceptors=[header_interceptor],
        )
        await pooled_tool.coroutine(runtime=None, x=1)

    session.call_tool.assert_awaited_once_with("act", {"x": 1}, meta={"headers": {"X-User-Id": "u-42"}})


@pytest.mark.asyncio
async def test_session_pool_tool_no_headers_omits_meta():
    """If no interceptor sets headers, the pooled call must not receive a
    ``meta`` kwarg; it stays the plain two-argument ``call_tool``.
    """
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    async def passthrough_interceptor(request, handler):
        # Hand the request on untouched.
        return await handler(request)

    # Fake MCP session plus the async context manager that yields it.
    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    session = AsyncMock()
    session.call_tool = AsyncMock(return_value=call_result)
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    base_tool = StructuredTool(
        name="srv_act",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )
    stdio_connection = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        pooled_tool = _make_session_pool_tool(
            base_tool,
            "srv",
            stdio_connection,
            tool_interceptors=[passthrough_interceptor],
        )
        await pooled_tool.coroutine(runtime=None, x=1)

    session.call_tool.assert_awaited_once_with("act", {"x": 1})


@pytest.mark.asyncio
async def test_session_pool_tool_ignores_unsupported_header_type(caplog):
    """Defensive path: truthy headers that are not a mapping are dropped (with
    a logged warning) instead of being forwarded.
    """
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    class TruthyHeaders:
        # Truthy, but deliberately not a Mapping.
        def __bool__(self) -> bool:
            return True

    async def invalid_header_interceptor(request, handler):
        bad_request = request.override(headers=TruthyHeaders())
        return await handler(bad_request)

    # Fake MCP session plus the async context manager that yields it.
    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    session = AsyncMock()
    session.call_tool = AsyncMock(return_value=call_result)
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    base_tool = StructuredTool(
        name="srv_act",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )
    stdio_connection = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        pooled_tool = _make_session_pool_tool(
            base_tool,
            "srv",
            stdio_connection,
            tool_interceptors=[invalid_header_interceptor],
        )
        await pooled_tool.coroutine(runtime=None, x=1)

    session.call_tool.assert_awaited_once_with("act", {"x": 1})
    assert "unsupported type" in caplog.text


@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
"""Thread ID is extracted from runtime.config when not in context."""
Expand Down
Loading
Loading