Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import json
import logging
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable
from typing import override

Expand Down Expand Up @@ -109,10 +110,10 @@ def _build_patched_messages(self, messages: list) -> list | None:
This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged.
"""
tool_messages_by_id: dict[str, ToolMessage] = {}
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
for msg in messages:
if isinstance(msg, ToolMessage):
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
tool_messages_by_id[msg.tool_call_id].append(msg)

tool_call_ids: set[str] = set()
for msg in messages:
Expand All @@ -124,7 +125,7 @@ def _build_patched_messages(self, messages: list) -> list | None:
tool_call_ids.add(tc_id)

patched: list = []
consumed_tool_msg_ids: set[str] = set()
consumed_tool_msg_objects: set[int] = set()
patch_count = 0
for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
Expand All @@ -136,13 +137,17 @@ def _build_patched_messages(self, messages: list) -> list | None:

for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if not tc_id or tc_id in consumed_tool_msg_ids:
if not tc_id:
continue

existing_tool_msg = tool_messages_by_id.get(tc_id)
tool_msg_queue = tool_messages_by_id.get(tc_id)
while tool_msg_queue and id(tool_msg_queue[0]) in consumed_tool_msg_objects:
tool_msg_queue.popleft()

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The deque already guarantees FIFO consumption once a ToolMessage is removed
with popleft(), so the extra object-id tracking is unnecessary here. I removed
consumed_tool_msg_objects and the cleanup loop, and kept the matching logic based on queue
occurrence order only.

existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
if existing_tool_msg is not None:
patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
consumed_tool_msg_objects.add(id(existing_tool_msg))
else:
patched.append(
ToolMessage(
Expand All @@ -152,7 +157,6 @@ def _build_patched_messages(self, messages: list) -> list | None:
status="error",
)
)
consumed_tool_msg_ids.add(tc_id)
patch_count += 1

if patched == messages:
Expand Down
28 changes: 28 additions & 0 deletions backend/tests/test_dangling_tool_call_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,34 @@ def test_valid_adjacent_tool_results_are_unchanged(self):

assert mw._build_patched_messages(msgs) is None

def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]

assert mw._build_patched_messages(msgs) is None

def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [
Expand Down
Loading