Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions backend/tests/test_lead_agent_model_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,33 @@ def _raise_get_app_config():
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])


def test_create_summarization_middleware_preserves_frontend_update_key_contract(monkeypatch):
"""LangGraph update keys use the middleware class name plus hook name.

The frontend treats any ``*.SummarizationMiddleware.before_model`` update as
the summarization state reset signal, so the lead agent's runtime middleware
must keep that suffix stable even when using a DeerFlow-specific subclass.

Temporary regression guard for issue#2965. Remove this test once the
frontend and backend have a stronger explicit contract than matching
middleware-generated update keys.
"""

app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)])
app_config.summarization = SummarizationConfig(enabled=True)
app_config.memory = MemoryConfig(enabled=False)

fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
monkeypatch.setattr(lead_agent_module, "create_chat_model", lambda **kwargs: fake_model)

middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)

assert middleware is not None
update_key = f"{type(middleware).__name__}.before_model"
assert update_key.endswith("SummarizationMiddleware.before_model")


def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
Expand Down
64 changes: 52 additions & 12 deletions frontend/src/core/threads/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { getAPIClient } from "../api";
import { fetch } from "../api/fetcher";
import { getBackendBaseURL } from "../config";
import { useI18n } from "../i18n/hooks";
import { isHiddenFromUIMessage } from "../messages/utils";
import type { FileInMessage } from "../messages/utils";
import type { LocalSettings } from "../settings";
import { useUpdateSubtask } from "../tasks/context";
Expand Down Expand Up @@ -65,17 +66,28 @@ function messageIdentity(message: Message): string | undefined {

function dedupeMessagesByIdentity(messages: Message[]): Message[] {
const lastIndexByIdentity = new Map<string, number>();
const lastVisibleIndexByIdentity = new Map<string, number>();

messages.forEach((message, index) => {
const identity = messageIdentity(message);
if (identity) {
lastIndexByIdentity.set(identity, index);
if (!isHiddenFromUIMessage(message)) {
lastVisibleIndexByIdentity.set(identity, index);
}
}
});

return messages.filter((message, index) => {
const identity = messageIdentity(message);
return !identity || lastIndexByIdentity.get(identity) === index;
if (!identity) {
return true;
}
const visibleIndex = lastVisibleIndexByIdentity.get(identity);
if (visibleIndex !== undefined) {
return visibleIndex === index;
}
return lastIndexByIdentity.get(identity) === index;
});
}

Expand All @@ -98,7 +110,10 @@ export function mergeMessages(
optimisticMessages: Message[],
): Message[] {
const threadMessageIds = new Set(
threadMessages.map(messageIdentity).filter(isNonEmptyString),
threadMessages
.filter((message) => !isHiddenFromUIMessage(message))
.map(messageIdentity)
.filter(isNonEmptyString),
);

// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
Expand Down Expand Up @@ -149,6 +164,30 @@ export function getVisibleOptimisticMessages(
return optimisticMessages;
}

export function getSummarizationMiddlewareMessages(
data: unknown,
): Message[] | undefined {
if (typeof data !== "object" || data === null) {
return undefined;
}

for (const [key, update] of Object.entries(data)) {
if (!key.endsWith("SummarizationMiddleware.before_model")) {
continue;
}
if (typeof update !== "object" || update === null) {
continue;
}

const messages = Reflect.get(update, "messages");
if (Array.isArray(messages)) {
return [...messages] as Message[];
}
}

return undefined;
}

function getStreamErrorMessage(error: unknown): string {
if (typeof error === "string" && error.trim()) {
return error;
Expand Down Expand Up @@ -258,24 +297,25 @@ export function useThreadStream({
}
},
onUpdateEvent(data) {
if (data["SummarizationMiddleware.before_model"]) {
const _messages = [
...(data["SummarizationMiddleware.before_model"].messages ?? []),
];

if (_messages.length < 2) {
return;
}
const _messages = getSummarizationMiddlewareMessages(data);
if (_messages && _messages.length >= 2) {
for (const m of _messages) {
if (m.name === "summary" && m.type === "human") {
summarizedRef.current?.add(m.id ?? "");
}
}
const _lastKeepMessage = _messages[2];
const firstRetainedVisibleIdentity = _messages
.filter((message) => message.type !== "remove")
.filter((message) => !isHiddenFromUIMessage(message))
.map(messageIdentity)
.find(isNonEmptyString);
const _currentMessages = [...messagesRef.current];
const _movedMessages: Message[] = [];
for (const m of _currentMessages) {
if (m.id !== undefined && m.id === _lastKeepMessage?.id) {
if (
firstRetainedVisibleIdentity &&
messageIdentity(m) === firstRetainedVisibleIdentity
) {
break;
}
if (!summarizedRef.current?.has(m.id ?? "")) {
Expand Down
47 changes: 47 additions & 0 deletions frontend/tests/unit/core/threads/message-merge.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";

import {
getSummarizationMiddlewareMessages,
getVisibleOptimisticMessages,
mergeMessages,
} from "@/core/threads/hooks";
Expand Down Expand Up @@ -66,6 +67,52 @@ test("mergeMessages deduplicates tool messages by tool_call_id", () => {
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
});

test("mergeMessages keeps a visible history message when a hidden live message reuses its id", () => {
const historyHuman = {
id: "human-1",
type: "human",
content: "visible user prompt",
} as Message;
const hiddenReminder = {
id: "human-1",
type: "human",
content: "<system-reminder>hidden</system-reminder>",
additional_kwargs: { hide_from_ui: true },
} as Message;
const liveAi = {
id: "ai-1",
type: "ai",
content: "live answer",
} as Message;

expect(mergeMessages([historyHuman], [hiddenReminder, liveAi], [])).toEqual([
historyHuman,
liveAi,
]);
});

test("getSummarizationMiddlewareMessages matches DeerFlow summarization update keys", () => {
const removeAll = {
id: "__remove_all__",
type: "remove",
content: "",
} as Message;
const summary = {
id: "summary-1",
type: "human",
name: "summary",
content: "summary",
} as Message;

expect(
getSummarizationMiddlewareMessages({
"DeerFlowSummarizationMiddleware.before_model": {
messages: [removeAll, summary],
},
}),
).toEqual([removeAll, summary]);
});

test("getVisibleOptimisticMessages hides optimistic user input after server human arrives", () => {
const optimisticHuman = {
id: "opt-human-1",
Expand Down
Loading