Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions backend/packages/harness/deerflow/mcp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from __future__ import annotations

import logging
from collections import Counter
from collections.abc import Iterable
from typing import Any

from langchain_core.tools import BaseTool, StructuredTool
Expand All @@ -19,6 +21,18 @@
logger = logging.getLogger(__name__)


def _get_prefixed_tool_server(tool_name: str, server_names: Iterable[str]) -> str | None:
for server_name in sorted(server_names, key=len, reverse=True):
if tool_name.startswith(f"{server_name}_"):
return server_name
return None
Comment thread
LittleChenLiya marked this conversation as resolved.


def _strip_mcp_tool_prefix(tool_name: str, server_name: str) -> str:
prefix = f"{server_name}_"
return tool_name[len(prefix) :] if tool_name.startswith(prefix) else tool_name


def _extract_thread_id(runtime: Runtime | None) -> str:
"""Extract thread_id from the injected tool runtime or LangGraph config."""
if runtime is not None:
Expand Down Expand Up @@ -108,6 +122,7 @@ def _make_session_pool_tool(
server_name: str,
connection: dict[str, Any],
tool_interceptors: list[Any] | None = None,
exposed_name: str | None = None,
) -> BaseTool:
"""Wrap an MCP tool so it reuses a persistent session from the pool.

Expand All @@ -118,11 +133,7 @@ def _make_session_pool_tool(
The configured ``tool_interceptors`` (OAuth, custom) are preserved and
applied on every call before invoking the pooled session.
"""
# Strip the server-name prefix to recover the original MCP tool name.
original_name = tool.name
prefix = f"{server_name}_"
if original_name.startswith(prefix):
original_name = original_name[len(prefix) :]
original_name = _strip_mcp_tool_prefix(tool.name, server_name)

pool = get_session_pool()

Expand Down Expand Up @@ -161,7 +172,7 @@ async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any:
return _convert_call_tool_result(call_tool_result)

return StructuredTool(
name=tool.name,
name=exposed_name or tool.name,
description=tool.description,
args_schema=tool.args_schema,
coroutine=call_with_persistent_session,
Expand Down Expand Up @@ -246,21 +257,28 @@ async def get_mcp_tools() -> list[BaseTool]:
)

# Get all tools from all servers (discovers tool definitions via
# temporary sessions the persistent-session wrapping is applied below).
# temporary sessions; the persistent-session wrapping is applied below).
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")

# Wrap each tool with persistent-session logic.
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
for name in servers_config:
if tool.name.startswith(f"{name}_"):
tool_server = name
break
tool_servers = [_get_prefixed_tool_server(tool.name, servers_config) for tool in tools]
exposed_name_candidates = [_strip_mcp_tool_prefix(tool.name, tool_server) if tool_server is not None else tool.name for tool, tool_server in zip(tools, tool_servers, strict=True)]
Comment thread
LittleChenLiya marked this conversation as resolved.
Outdated
exposed_name_counts = Counter(exposed_name_candidates)

for tool, tool_server, exposed_name_candidate in zip(tools, tool_servers, exposed_name_candidates, strict=True):
if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
exposed_name = exposed_name_candidate if exposed_name_counts[exposed_name_candidate] == 1 else tool.name
wrapped_tools.append(
_make_session_pool_tool(
tool,
tool_server,
servers_config[tool_server],
tool_interceptors,
exposed_name=exposed_name,
)
)
else:
wrapped_tools.append(tool)

Expand Down
112 changes: 111 additions & 1 deletion backend/tests/test_mcp_sync_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from deerflow.tools.sync import make_sync_tool_wrapper


class MockCallToolResult:
content = []
isError = False
structuredContent = None
Comment thread
LittleChenLiya marked this conversation as resolved.
Outdated


class MockArgs(BaseModel):
x: int = Field(..., description="test param")

Expand All @@ -34,7 +40,7 @@ async def mock_coro(x: int):
mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool])

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance) as mock_client,
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
Expand All @@ -51,6 +57,110 @@ async def mock_coro(x: int):
# Verify it works (sync call)
result = patched_tool.func(x=42)
assert result == "result: 42"
assert mock_client.call_args.kwargs["tool_name_prefix"] is True


def test_mcp_tools_restore_unique_original_names_after_session_wrapping():
"""Unique MCP tool names are exposed without the adapter's server prefix."""

async def mock_coro(x: int):
return f"result: {x}"

mock_tool = StructuredTool(
name="test-server_test_tool",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
)

mock_client_instance = MagicMock()
mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool])

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
patch("deerflow.mcp.tools.get_session_pool", return_value=MagicMock()),
):
tools = asyncio.run(get_mcp_tools())

assert [tool.name for tool in tools] == ["test_tool"]


def test_mcp_session_wrapper_calls_original_tool_name():
"""Restored tool names still call MCP sessions with the original tool name."""

async def mock_coro(x: int):
return f"result: {x}"

mock_tool = StructuredTool(
name="test-server_test_tool",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
)

mock_client_instance = MagicMock()
mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool])
mock_session = MagicMock()
mock_session.call_tool = AsyncMock(return_value=MockCallToolResult())
mock_pool = MagicMock()
mock_pool.get_session = AsyncMock(return_value=mock_session)

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
patch("deerflow.mcp.tools.get_session_pool", return_value=mock_pool),
):
tools = asyncio.run(get_mcp_tools())
result = asyncio.run(tools[0].coroutine(x=42))

assert tools[0].name == "test_tool"
assert result == ([], None)
mock_session.call_tool.assert_awaited_once_with("test_tool", {"x": 42})


def test_mcp_tools_keep_server_prefix_for_duplicate_original_names():
"""Duplicate original MCP tool names keep prefixes to avoid collisions."""

async def mock_coro(x: int):
return f"result: {x}"

tools_from_client = [
StructuredTool(
name="server-a_search",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
),
StructuredTool(
name="server-b_search",
description="test description",
args_schema=MockArgs,
func=None,
coroutine=mock_coro,
),
]

mock_client_instance = MagicMock()
mock_client_instance.get_tools = AsyncMock(return_value=tools_from_client)

with (
patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance),
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"),
patch("deerflow.mcp.tools.build_servers_config", return_value={"server-a": {}, "server-b": {}}),
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}),
patch("deerflow.mcp.tools.get_session_pool", return_value=MagicMock()),
):
tools = asyncio.run(get_mcp_tools())

assert [tool.name for tool in tools] == ["server-a_search", "server-b_search"]


def test_mcp_tool_sync_wrapper_in_running_loop():
Expand Down