Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,3 @@ Fixes #
Backend: cd backend && make lint && make test
Frontend: cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test
Frontend E2E (if you touched frontend/): cd frontend && make test-e2e -->

52 changes: 38 additions & 14 deletions backend/packages/harness/deerflow/mcp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import annotations

import logging
from collections import Counter
from collections.abc import Iterable
from typing import Any

from langchain_core.tools import BaseTool, StructuredTool
Expand All @@ -19,6 +21,18 @@
logger = logging.getLogger(__name__)


def _get_prefixed_tool_server(tool_name: str, server_names: Iterable[str]) -> str | None:
for server_name in server_names:
if tool_name.startswith(f"{server_name}_"):
return server_name
return None
Comment thread
LittleChenLiya marked this conversation as resolved.


def _strip_mcp_tool_prefix(tool_name: str, server_name: str) -> str:
prefix = f"{server_name}_"
return tool_name[len(prefix) :] if tool_name.startswith(prefix) else tool_name


def _extract_thread_id(runtime: Runtime | None) -> str:
"""Extract thread_id from the injected tool runtime or LangGraph config."""
if runtime is not None:
Expand Down Expand Up @@ -108,6 +122,7 @@ def _make_session_pool_tool(
server_name: str,
connection: dict[str, Any],
tool_interceptors: list[Any] | None = None,
exposed_name: str | None = None,
) -> BaseTool:
"""Wrap an MCP tool so it reuses a persistent session from the pool.

Expand All @@ -118,11 +133,7 @@ def _make_session_pool_tool(
The configured ``tool_interceptors`` (OAuth, custom) are preserved and
applied on every call before invoking the pooled session.
"""
# Strip the server-name prefix to recover the original MCP tool name.
original_name = tool.name
prefix = f"{server_name}_"
if original_name.startswith(prefix):
original_name = original_name[len(prefix) :]
original_name = _strip_mcp_tool_prefix(tool.name, server_name)

pool = get_session_pool()

Expand Down Expand Up @@ -161,7 +172,7 @@ async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any:
return _convert_call_tool_result(call_tool_result)

return StructuredTool(
name=tool.name,
name=exposed_name or tool.name,
description=tool.description,
args_schema=tool.args_schema,
coroutine=call_with_persistent_session,
Expand Down Expand Up @@ -248,7 +259,7 @@ async def get_mcp_tools() -> list[BaseTool]:
)

# Get all tools from all servers (discovers tool definitions via
# temporary sessions the persistent-session wrapping is applied below).
# temporary sessions; the persistent-session wrapping is applied below).
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")

Expand All @@ -257,17 +268,30 @@ async def get_mcp_tools() -> list[BaseTool]:
# internally which cannot be closed from a different async task, so
# pooling them causes RuntimeError on cleanup (see #3203).
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
for name in servers_config:
if tool.name.startswith(f"{name}_"):
tool_server = name
break
server_names_by_prefix_length = tuple(sorted(servers_config, key=len, reverse=True))
tool_servers = [_get_prefixed_tool_server(tool.name, server_names_by_prefix_length) for tool in tools]
exposed_name_candidates: list[str] = []
for tool, tool_server in zip(tools, tool_servers, strict=True):
if tool_server is None:
exposed_name_candidates.append(tool.name)
else:
exposed_name_candidates.append(_strip_mcp_tool_prefix(tool.name, tool_server))
exposed_name_counts = Counter(exposed_name_candidates)

for tool, tool_server, exposed_name_candidate in zip(tools, tool_servers, exposed_name_candidates, strict=True):
if tool_server is not None:
transport = servers_config[tool_server].get("transport", "stdio")
if transport == "stdio":
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
exposed_name = exposed_name_candidate if exposed_name_counts[exposed_name_candidate] == 1 else tool.name
wrapped_tools.append(
_make_session_pool_tool(
tool,
tool_server,
servers_config[tool_server],
tool_interceptors,
exposed_name=exposed_name,
)
)
else:
wrapped_tools.append(tool)
else:
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/test_mcp_session_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ class Args(BaseModel):
assert len(http_tools) == 1
assert http_tools[0].coroutine is http_tool.coroutine

# Verify the stdio tool WAS wrapped with the pool.
stdio_tools = [t for t in tools if t.name == "playwright_navigate"]
# Verify the stdio tool WAS wrapped with the pool and exposed without the server prefix.
stdio_tools = [t for t in tools if t.name == "navigate"]
assert len(stdio_tools) == 1
assert stdio_tools[0].coroutine is not stdio_tool.coroutine
113 changes: 112 additions & 1 deletion backend/tests/test_mcp_sync_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
from deerflow.tools.sync import make_sync_tool_wrapper


class MockCallToolResult:
def __init__(self):
self.content = []
self.isError = False
self.structuredContent = None


class MockArgs(BaseModel):
x: int = Field(..., description="test param")

Expand All @@ -34,7 +41,7 @@ async def mock_coro(x: int):
mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool])

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance) as mock_client,
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
Expand All @@ -51,6 +58,110 @@ async def mock_coro(x: int):
# Verify it works (sync call)
result = patched_tool.func(x=42)
assert result == "result: 42"
assert mock_client.call_args.kwargs["tool_name_prefix"] is True


def test_mcp_tools_restore_unique_original_names_after_session_wrapping():
"""Unique MCP tool names are exposed without the adapter's server prefix."""

async def mock_coro(x: int):
return f"result: {x}"

mock_tool = StructuredTool(
name="test-server_test_tool",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
)

mock_client_instance = MagicMock()
mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool])

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
patch("deerflow.mcp.tools.get_session_pool", return_value=MagicMock()),
):
tools = asyncio.run(get_mcp_tools())

assert [tool.name for tool in tools] == ["test_tool"]


def test_mcp_session_wrapper_calls_original_tool_name():
"""Restored tool names still call MCP sessions with the original tool name."""

async def mock_coro(x: int):
return f"result: {x}"

mock_tool = StructuredTool(
name="test-server_test_tool",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
)

mock_client_instance = MagicMock()
mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool])
mock_session = MagicMock()
mock_session.call_tool = AsyncMock(return_value=MockCallToolResult())
mock_pool = MagicMock()
mock_pool.get_session = AsyncMock(return_value=mock_session)

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
patch("deerflow.mcp.tools.get_session_pool", return_value=mock_pool),
):
tools = asyncio.run(get_mcp_tools())
result = asyncio.run(tools[0].coroutine(x=42))

assert tools[0].name == "test_tool"
assert result == ([], None)
mock_session.call_tool.assert_awaited_once_with("test_tool", {"x": 42})


def test_mcp_tools_keep_server_prefix_for_duplicate_original_names():
"""Duplicate original MCP tool names keep prefixes to avoid collisions."""

async def mock_coro(x: int):
return f"result: {x}"

tools_from_client = [
StructuredTool(
name="server-a_search",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
),
StructuredTool(
name="server-b_search",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
),
]

mock_client_instance = MagicMock()
mock_client_instance.get_tools = AsyncMock(return_value=tools_from_client)

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"server-a": {}, "server-b": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
patch("deerflow.mcp.tools.get_session_pool", return_value=MagicMock()),
):
tools = asyncio.run(get_mcp_tools())

assert [tool.name for tool in tools] == ["server-a_search", "server-b_search"]


def test_mcp_tool_sync_wrapper_in_running_loop():
Expand Down
Loading