feat: refine token usage display modes (#2329)

* feat: refine token usage display modes

* docs: clarify token usage accounting semantics

* fix: avoid duplicate subtask debug keys

* style: format token usage tests

* chore: address token attribution review feedback

* Update test_token_usage_middleware.py

* Update test_token_usage_middleware.py

* chore: simplify token attribution fallback

* fix: handle token usage metadata follow-ups

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
YuJitang
2026-05-04 09:56:16 +08:00
committed by GitHub
parent 82e7936d36
commit d02f762ab0
20 changed files with 2346 additions and 222 deletions
@@ -1,31 +1,270 @@
"""Middleware for logging LLM token usage.""" """Middleware for logging token usage and annotating step attribution."""
from __future__ import annotations
import logging import logging
from typing import override from collections import defaultdict
from typing import Any, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
def _string_arg(value: Any) -> str | None:
if isinstance(value, str):
normalized = value.strip()
return normalized or None
return None
def _normalize_todos(value: Any) -> list[Todo]:
    """Coerce a raw ``todos`` tool argument into well-formed ``Todo`` dicts.

    Anything that is not a list yields ``[]`` and non-dict entries are
    dropped. ``content`` is kept only when it is a non-empty string after
    stripping, and ``status`` only when it is one of the known todo states.
    """
    if not isinstance(value, list):
        return []
    result: list[Todo] = []
    for entry in value:
        if not isinstance(entry, dict):
            continue
        cleaned: Todo = {}
        text = _string_arg(entry.get("content"))
        if text is not None:
            cleaned["content"] = text
        state = entry.get("status")
        if state in ("pending", "in_progress", "completed"):
            cleaned["status"] = state
        result.append(cleaned)
    return result
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
status = current.get("status")
previous_content = previous.get("content") if previous else None
current_content = current.get("content")
if previous is None:
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
if previous_content != current_content:
return "todo_update"
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
    # This is the single source of truth for precise write_todos token
    # attribution. The frontend intentionally falls back to a generic
    # "Update to-do list" label when this metadata is missing or malformed.
    #
    # Matching strategy: pair each next todo with a previous todo by exact
    # content first (each previous entry consumed at most once), then fall
    # back to pairing by list position; any previous todo left unmatched at
    # the end is reported as removed.
    previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
    matched_previous_indices: set[int] = set()
    for index, todo in enumerate(previous_todos):
        content = todo.get("content")
        if isinstance(content, str) and content:
            previous_by_content[content].append((index, todo))
    actions: list[dict[str, Any]] = []
    for index, todo in enumerate(next_todos):
        content = todo.get("content")
        if not isinstance(content, str) or not content:
            # Todos without usable content cannot be attributed.
            continue
        previous_match: Todo | None = None
        content_matches = previous_by_content.get(content)
        if content_matches:
            # Skip candidates already claimed by an earlier positional match.
            while content_matches and content_matches[0][0] in matched_previous_indices:
                content_matches.pop(0)
            if content_matches:
                previous_index, previous_match = content_matches.pop(0)
                matched_previous_indices.add(previous_index)
        if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
            # Positional fallback: same slot in the previous list (covers
            # in-place content edits, which surface as "todo_update").
            previous_match = previous_todos[index]
            matched_previous_indices.add(index)
        if previous_match is not None:
            previous_content = previous_match.get("content")
            previous_status = previous_match.get("status")
            if previous_content == content and previous_status == todo.get("status"):
                # Unchanged todo: emit no action for it.
                continue
        actions.append(
            {
                "kind": _todo_action_kind(previous_match, todo),
                "content": content,
            }
        )
    # Anything in the previous list that was never matched was removed.
    for index, todo in enumerate(previous_todos):
        if index in matched_previous_indices:
            continue
        content = todo.get("content")
        if not isinstance(content, str) or not content:
            continue
        actions.append(
            {
                "kind": "todo_remove",
                "content": content,
            }
        )
    return actions
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
    """Translate one tool call into a list of attribution action dicts.

    ``todos`` is the todo snapshot in effect before this call; it is only
    consulted for ``write_todos`` calls, where the diff against the new
    todo list drives the emitted actions.
    """
    raw_args = tool_call.get("args")
    args = raw_args if isinstance(raw_args, dict) else {}
    name = _string_arg(tool_call.get("name")) or "unknown"
    call_id = _string_arg(tool_call.get("id"))

    if name == "write_todos":
        todo_actions = _build_todo_actions(todos, _normalize_todos(args.get("todos")))
        if todo_actions:
            return [{**action, "tool_call_id": call_id} for action in todo_actions]
        # No effective todo change: fall back to a generic tool action.
        return [{"kind": "tool", "tool_name": name, "tool_call_id": call_id}]

    if name == "task":
        return [
            {
                "kind": "subagent",
                "description": _string_arg(args.get("description")),
                "subagent_type": _string_arg(args.get("subagent_type")),
                "tool_call_id": call_id,
            }
        ]

    if name in ("web_search", "image_search"):
        return [
            {
                "kind": "search",
                "tool_name": name,
                "query": _string_arg(args.get("query")),
                "tool_call_id": call_id,
            }
        ]

    if name == "present_files":
        return [{"kind": "present_files", "tool_call_id": call_id}]

    if name == "ask_clarification":
        return [{"kind": "clarification", "tool_call_id": call_id}]

    # Any other tool gets a generic action carrying its name/description.
    return [
        {
            "kind": "tool",
            "tool_name": name,
            "description": _string_arg(args.get("description")),
            "tool_call_id": call_id,
        }
    ]
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
if actions:
first_kind = actions[0].get("kind")
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
return "todo_update"
if len(actions) == 1 and first_kind == "subagent":
return "subagent_dispatch"
return "tool_batch"
if message.content:
return "final_answer"
return "thinking"
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
    """Build the attribution payload attached to an AI message.

    Walks the message's tool calls exactly once, describing each call
    relative to the todo snapshot as it evolves (a ``write_todos`` call
    replaces the working snapshot for subsequent calls) and collecting the
    valid tool-call ids in the same pass. The original implementation
    iterated ``tool_calls`` a second time just to gather ids, repeating the
    ``isinstance(dict)`` filtering; the merged loop is equivalent.
    """
    tool_calls = getattr(message, "tool_calls", None) or []
    actions: list[dict[str, Any]] = []
    tool_call_ids: list[str] = []
    current_todos = list(todos)
    for raw_tool_call in tool_calls:
        if not isinstance(raw_tool_call, dict):
            continue
        actions.extend(_describe_tool_call(raw_tool_call, current_todos))
        if raw_tool_call.get("name") == "write_todos":
            raw_args = raw_tool_call.get("args")
            args = raw_args if isinstance(raw_args, dict) else {}
            current_todos = _normalize_todos(args.get("todos"))
        tool_call_id = _string_arg(raw_tool_call.get("id"))
        if tool_call_id is not None:
            tool_call_ids.append(tool_call_id)
    return {
        # Schema changes should remain additive where possible so older
        # frontends can ignore unknown fields and fall back safely.
        "version": 1,
        "kind": _infer_step_kind(message, actions),
        "shared_attribution": len(actions) > 1,
        "tool_call_ids": tool_call_ids,
        "actions": actions,
    }
class TokenUsageMiddleware(AgentMiddleware): class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model response usage_metadata.""" """Logs token usage from model responses and annotates the AI step."""
@override def _apply(self, state: AgentState) -> dict | None:
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
def _log_usage(self, state: AgentState) -> None:
messages = state.get("messages", []) messages = state.get("messages", [])
if not messages: if not messages:
return None return None
last = messages[-1] last = messages[-1]
if not isinstance(last, AIMessage):
return None
usage = getattr(last, "usage_metadata", None) usage = getattr(last, "usage_metadata", None)
if usage: if usage:
logger.info( logger.info(
@@ -34,4 +273,22 @@ class TokenUsageMiddleware(AgentMiddleware):
usage.get("output_tokens", "?"), usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"), usage.get("total_tokens", "?"),
) )
return None
todos = state.get("todos") or []
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
+98 -19
View File
@@ -264,25 +264,35 @@ class DeerFlowClient:
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls] return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
@staticmethod @staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent": def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
"""Build a ``messages-tuple`` AI text event, attaching usage when present.""" """Copy message additional_kwargs when present."""
additional_kwargs = getattr(msg, "additional_kwargs", None)
if isinstance(additional_kwargs, dict) and additional_kwargs:
return dict(additional_kwargs)
return None
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event."""
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id} data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage: if usage:
data["usage_metadata"] = usage data["usage_metadata"] = usage
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data) return StreamEvent(type="messages-tuple", data=data)
@staticmethod @staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent": def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI tool-calls event.""" """Build a ``messages-tuple`` AI tool-calls event."""
return StreamEvent( data: dict[str, Any] = {
type="messages-tuple", "type": "ai",
data={ "content": "",
"type": "ai", "id": msg_id,
"content": "", "tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
"id": msg_id, }
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls), if additional_kwargs:
}, data["additional_kwargs"] = additional_kwargs
) return StreamEvent(type="messages-tuple", data=data)
@staticmethod @staticmethod
def _tool_message_event(msg: ToolMessage) -> "StreamEvent": def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
@@ -307,19 +317,30 @@ class DeerFlowClient:
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls) d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
if getattr(msg, "usage_metadata", None): if getattr(msg, "usage_metadata", None):
d["usage_metadata"] = msg.usage_metadata d["usage_metadata"] = msg.usage_metadata
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d return d
if isinstance(msg, ToolMessage): if isinstance(msg, ToolMessage):
return { d = {
"type": "tool", "type": "tool",
"content": DeerFlowClient._extract_text(msg.content), "content": DeerFlowClient._extract_text(msg.content),
"name": getattr(msg, "name", None), "name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None), "tool_call_id": getattr(msg, "tool_call_id", None),
"id": getattr(msg, "id", None), "id": getattr(msg, "id", None),
} }
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, HumanMessage): if isinstance(msg, HumanMessage):
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)} d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, SystemMessage): if isinstance(msg, SystemMessage):
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)} d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)} return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
@staticmethod @staticmethod
@@ -542,6 +563,7 @@ class DeerFlowClient:
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str} - type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}} - type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]} - type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str} - type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}} - type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
""" """
@@ -564,6 +586,7 @@ class DeerFlowClient:
# in both the final ``messages`` chunk and the values snapshot — # in both the final ``messages`` chunk and the values snapshot —
# count it only on whichever arrives first. # count it only on whichever arrives first.
counted_usage_ids: set[str] = set() counted_usage_ids: set[str] = set()
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _account_usage(msg_id: str | None, usage: Any) -> dict | None: def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
@@ -593,6 +616,20 @@ class DeerFlowClient:
"total_tokens": total_tokens, "total_tokens": total_tokens,
} }
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
if not additional_kwargs:
return None
if not msg_id:
return additional_kwargs
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
if not delta:
return None
sent.update(delta)
return delta
for item in self._agent.stream( for item in self._agent.stream(
state, state,
config=config, config=config,
@@ -620,17 +657,31 @@ class DeerFlowClient:
if isinstance(msg_chunk, AIMessage): if isinstance(msg_chunk, AIMessage):
text = self._extract_text(msg_chunk.content) text = self._extract_text(msg_chunk.content)
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata) counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
sent_additional_kwargs = False
if text: if text:
if msg_id: if msg_id:
streamed_ids.add(msg_id) streamed_ids.add(msg_id)
yield self._ai_text_event(msg_id, text, counted_usage) additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
if msg_chunk.tool_calls: if msg_chunk.tool_calls:
if msg_id: if msg_id:
streamed_ids.add(msg_id) streamed_ids.add(msg_id)
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls) additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg_chunk.tool_calls,
additional_kwargs_delta,
)
elif isinstance(msg_chunk, ToolMessage): elif isinstance(msg_chunk, ToolMessage):
if msg_id: if msg_id:
@@ -653,17 +704,45 @@ class DeerFlowClient:
if msg_id and msg_id in streamed_ids: if msg_id and msg_id in streamed_ids:
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
_account_usage(msg_id, getattr(msg, "usage_metadata", None)) _account_usage(msg_id, getattr(msg, "usage_metadata", None))
additional_kwargs = self._serialize_additional_kwargs(msg)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
if additional_kwargs_delta:
# Metadata-only follow-up: ``messages-tuple`` has no
# dedicated attribution event, so clients should
# merge this empty-content AI event by message id
# and ignore it for text rendering.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
continue continue
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
counted_usage = _account_usage(msg_id, msg.usage_metadata) counted_usage = _account_usage(msg_id, msg.usage_metadata)
additional_kwargs = self._serialize_additional_kwargs(msg)
sent_additional_kwargs = False
if msg.tool_calls: if msg.tool_calls:
yield self._ai_tool_calls_event(msg_id, msg.tool_calls) additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg.tool_calls,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
text = self._extract_text(msg.content) text = self._extract_text(msg.content)
if text: if text:
yield self._ai_text_event(msg_id, text, counted_usage) additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
elif msg_id:
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
if not additional_kwargs_delta:
continue
# See the metadata-only follow-up convention above.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
elif isinstance(msg, ToolMessage): elif isinstance(msg, ToolMessage):
yield self._tool_message_event(msg) yield self._tool_message_event(msg)
+79
View File
@@ -437,6 +437,85 @@ class TestStream:
call_kwargs = agent.stream.call_args.kwargs call_kwargs = agent.stream.call_args.kwargs
assert "messages" in call_kwargs["stream_mode"] assert "messages" in call_kwargs["stream_mode"]
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
assert any(event.data.get("content") == "Hello!" for event in ai_events)
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
attribution = {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"reasoning_content": "Thinking first.",
"token_usage_attribution": attribution,
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
(
"messages",
(
AIMessageChunk(
content="Hello!",
id="ai-1",
additional_kwargs={"reasoning_content": "Thinking first."},
),
{},
),
),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
assert metadata_events[1].data["content"] == ""
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
def test_chat_accumulates_streamed_deltas(self, client): def test_chat_accumulates_streamed_deltas(self, client):
"""chat() concatenates per-id deltas from messages mode.""" """chat() concatenates per-id deltas from messages mode."""
agent = MagicMock() agent = MagicMock()
@@ -0,0 +1,53 @@
"""Tests for DeerFlowClient message serialization helpers."""
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.client import DeerFlowClient
def test_serialize_ai_message_preserves_additional_kwargs():
message = AIMessage(
content="done",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized["type"] == "ai"
assert serialized["usage_metadata"] == {
"input_tokens": 12,
"output_tokens": 3,
"total_tokens": 15,
}
assert serialized["additional_kwargs"] == {
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
}
def test_serialize_human_message_preserves_additional_kwargs():
message = HumanMessage(
content="hello",
additional_kwargs={"files": [{"name": "diagram.png"}]},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized == {
"type": "human",
"content": "hello",
"id": None,
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
}
+149 -24
View File
@@ -1,32 +1,157 @@
from unittest.mock import MagicMock, patch """Tests for TokenUsageMiddleware attribution annotations."""
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
TokenUsageMiddleware,
)
def test_after_model_logs_usage_metadata_counts(): def _make_runtime():
middleware = TokenUsageMiddleware() runtime = MagicMock()
state = { runtime.context = {"thread_id": "test-thread"}
"messages": [ return runtime
AIMessage(
content="done",
usage_metadata={ class TestTokenUsageMiddleware:
"input_tokens": 10, def test_annotates_todo_updates_with_structured_actions(self):
"output_tokens": 5, middleware = TokenUsageMiddleware()
"total_tokens": 15, message = AIMessage(
}, content="",
) tool_calls=[
{
"id": "write_todos:1",
"name": "write_todos",
"args": {
"todos": [
{"content": "Inspect streaming path", "status": "completed"},
{"content": "Design token attribution schema", "status": "in_progress"},
]
},
}
],
usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
)
state = {
"messages": [message],
"todos": [
{"content": "Inspect streaming path", "status": "in_progress"},
{"content": "Design token attribution schema", "status": "pending"},
],
}
result = middleware.after_model(state, _make_runtime())
assert result is not None
updated_message = result["messages"][0]
attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
assert attribution["kind"] == "tool_batch"
assert attribution["shared_attribution"] is True
assert attribution["tool_call_ids"] == ["write_todos:1"]
assert attribution["actions"] == [
{
"kind": "todo_complete",
"content": "Inspect streaming path",
"tool_call_id": "write_todos:1",
},
{
"kind": "todo_start",
"content": "Design token attribution schema",
"tool_call_id": "write_todos:1",
},
] ]
}
with patch("deerflow.agents.middlewares.token_usage_middleware.logger.info") as info_mock: def test_annotates_subagent_and_search_steps(self):
result = middleware.after_model(state=state, runtime=MagicMock()) middleware = TokenUsageMiddleware()
message = AIMessage(
content="",
tool_calls=[
{
"id": "task:1",
"name": "task",
"args": {
"description": "spec-coder patch message grouping",
"subagent_type": "general-purpose",
},
},
{
"id": "web_search:1",
"name": "web_search",
"args": {"query": "LangGraph useStream messages tuple"},
},
],
)
assert result is None result = middleware.after_model({"messages": [message]}, _make_runtime())
info_mock.assert_called_once_with(
"LLM token usage: input=%s output=%s total=%s", assert result is not None
10, attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
5, assert attribution["kind"] == "tool_batch"
15, assert attribution["shared_attribution"] is True
) assert attribution["actions"] == [
{
"kind": "subagent",
"description": "spec-coder patch message grouping",
"subagent_type": "general-purpose",
"tool_call_id": "task:1",
},
{
"kind": "search",
"tool_name": "web_search",
"query": "LangGraph useStream messages tuple",
"tool_call_id": "web_search:1",
},
]
def test_marks_final_answer_when_no_tools(self):
middleware = TokenUsageMiddleware()
message = AIMessage(content="Here is the final answer.")
result = middleware.after_model({"messages": [message]}, _make_runtime())
assert result is not None
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
assert attribution["kind"] == "final_answer"
assert attribution["shared_attribution"] is False
assert attribution["actions"] == []
def test_annotates_removed_todos(self):
middleware = TokenUsageMiddleware()
message = AIMessage(
content="",
tool_calls=[
{
"id": "write_todos:remove",
"name": "write_todos",
"args": {
"todos": [],
},
}
],
)
result = middleware.after_model(
{
"messages": [message],
"todos": [
{"content": "Archive obsolete plan", "status": "pending"},
],
},
_make_runtime(),
)
assert result is not None
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
assert attribution["kind"] == "todo_update"
assert attribution["shared_attribution"] is False
assert attribution["actions"] == [
{
"kind": "todo_remove",
"content": "Archive obsolete plan",
"tool_call_id": "write_todos:remove",
}
]
@@ -25,7 +25,7 @@ import { useAgent } from "@/core/agents";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { useModels } from "@/core/models/hooks"; import { useModels } from "@/core/models/hooks";
import { useNotification } from "@/core/notification/hooks"; import { useNotification } from "@/core/notification/hooks";
import { useThreadSettings } from "@/core/settings"; import { useLocalSettings, useThreadSettings } from "@/core/settings";
import { useThreadStream } from "@/core/threads/hooks"; import { useThreadStream } from "@/core/threads/hooks";
import { textOfMessage } from "@/core/threads/utils"; import { textOfMessage } from "@/core/threads/utils";
import { env } from "@/env"; import { env } from "@/env";
@@ -45,6 +45,7 @@ export default function AgentChatPage() {
const { threadId, setThreadId, isNewThread, setIsNewThread } = const { threadId, setThreadId, isNewThread, setIsNewThread } =
useThreadChat(); useThreadChat();
const [settings, setSettings] = useThreadSettings(threadId); const [settings, setSettings] = useThreadSettings(threadId);
const [localSettings, setLocalSettings] = useLocalSettings();
const { tokenUsageEnabled } = useModels(); const { tokenUsageEnabled } = useModels();
const { showNotification } = useNotification(); const { showNotification } = useNotification();
@@ -100,6 +101,9 @@ export default function AgentChatPage() {
? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM + ? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM +
MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM
: undefined; : undefined;
const tokenUsageInlineMode = tokenUsageEnabled
? localSettings.tokenUsage.inlineMode
: "off";
return ( return (
<ThreadContext.Provider value={{ thread }}> <ThreadContext.Provider value={{ thread }}>
@@ -139,6 +143,10 @@ export default function AgentChatPage() {
<TokenUsageIndicator <TokenUsageIndicator
enabled={tokenUsageEnabled} enabled={tokenUsageEnabled}
messages={thread.messages} messages={thread.messages}
preferences={localSettings.tokenUsage}
onPreferencesChange={(preferences) =>
setLocalSettings("tokenUsage", preferences)
}
/> />
<ExportTrigger threadId={threadId} /> <ExportTrigger threadId={threadId} />
<ArtifactTrigger /> <ArtifactTrigger />
@@ -152,10 +160,10 @@ export default function AgentChatPage() {
threadId={threadId} threadId={threadId}
thread={thread} thread={thread}
paddingBottom={messageListPaddingBottom} paddingBottom={messageListPaddingBottom}
tokenUsageEnabled={tokenUsageEnabled}
hasMoreHistory={hasMoreHistory} hasMoreHistory={hasMoreHistory}
loadMoreHistory={loadMoreHistory} loadMoreHistory={loadMoreHistory}
isHistoryLoading={isHistoryLoading} isHistoryLoading={isHistoryLoading}
tokenUsageInlineMode={tokenUsageInlineMode}
/> />
</div> </div>
@@ -24,7 +24,7 @@ import { Welcome } from "@/components/workspace/welcome";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { useModels } from "@/core/models/hooks"; import { useModels } from "@/core/models/hooks";
import { useNotification } from "@/core/notification/hooks"; import { useNotification } from "@/core/notification/hooks";
import { useThreadSettings } from "@/core/settings"; import { useLocalSettings, useThreadSettings } from "@/core/settings";
import { useThreadStream } from "@/core/threads/hooks"; import { useThreadStream } from "@/core/threads/hooks";
import { textOfMessage } from "@/core/threads/utils"; import { textOfMessage } from "@/core/threads/utils";
import { env } from "@/env"; import { env } from "@/env";
@@ -36,6 +36,7 @@ export default function ChatPage() {
const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } = const { threadId, setThreadId, isNewThread, setIsNewThread, isMock } =
useThreadChat(); useThreadChat();
const [settings, setSettings] = useThreadSettings(threadId); const [settings, setSettings] = useThreadSettings(threadId);
const [localSettings, setLocalSettings] = useLocalSettings();
const { tokenUsageEnabled } = useModels(); const { tokenUsageEnabled } = useModels();
const mountedRef = useRef(false); const mountedRef = useRef(false);
useSpecificChatMode(); useSpecificChatMode();
@@ -99,6 +100,9 @@ export default function ChatPage() {
? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM + ? MESSAGE_LIST_DEFAULT_PADDING_BOTTOM +
MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM MESSAGE_LIST_FOLLOWUPS_EXTRA_PADDING_BOTTOM
: undefined; : undefined;
const tokenUsageInlineMode = tokenUsageEnabled
? localSettings.tokenUsage.inlineMode
: "off";
return ( return (
<ThreadContext.Provider value={{ thread, isMock }}> <ThreadContext.Provider value={{ thread, isMock }}>
@@ -119,6 +123,10 @@ export default function ChatPage() {
<TokenUsageIndicator <TokenUsageIndicator
enabled={tokenUsageEnabled} enabled={tokenUsageEnabled}
messages={thread.messages} messages={thread.messages}
preferences={localSettings.tokenUsage}
onPreferencesChange={(preferences) =>
setLocalSettings("tokenUsage", preferences)
}
/> />
<ExportTrigger threadId={threadId} /> <ExportTrigger threadId={threadId} />
<ArtifactTrigger /> <ArtifactTrigger />
@@ -131,10 +139,10 @@ export default function ChatPage() {
threadId={threadId} threadId={threadId}
thread={thread} thread={thread}
paddingBottom={messageListPaddingBottom} paddingBottom={messageListPaddingBottom}
tokenUsageEnabled={tokenUsageEnabled}
hasMoreHistory={hasMoreHistory} hasMoreHistory={hasMoreHistory}
loadMoreHistory={loadMoreHistory} loadMoreHistory={loadMoreHistory}
isHistoryLoading={isHistoryLoading} isHistoryLoading={isHistoryLoading}
tokenUsageInlineMode={tokenUsageInlineMode}
/> />
</div> </div>
<div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4"> <div className="absolute right-0 bottom-0 left-0 z-30 flex justify-center px-4">
@@ -2,6 +2,7 @@ import type { Message } from "@langchain/langgraph-sdk";
import { import {
BookOpenTextIcon, BookOpenTextIcon,
ChevronUp, ChevronUp,
CoinsIcon,
FolderOpenIcon, FolderOpenIcon,
GlobeIcon, GlobeIcon,
LightbulbIcon, LightbulbIcon,
@@ -24,6 +25,8 @@ import {
import { CodeBlock } from "@/components/ai-elements/code-block"; import { CodeBlock } from "@/components/ai-elements/code-block";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { formatTokenCount } from "@/core/messages/usage";
import type { TokenDebugStep } from "@/core/messages/usage-model";
import { import {
extractReasoningContentFromMessage, extractReasoningContentFromMessage,
findToolCallResult, findToolCallResult,
@@ -43,10 +46,14 @@ export function MessageGroup({
className, className,
messages, messages,
isLoading = false, isLoading = false,
tokenDebugSteps = [],
showTokenDebugSummaries = false,
}: { }: {
className?: string; className?: string;
messages: Message[]; messages: Message[];
isLoading?: boolean; isLoading?: boolean;
tokenDebugSteps?: TokenDebugStep[];
showTokenDebugSummaries?: boolean;
}) { }) {
const { t } = useI18n(); const { t } = useI18n();
const [showAbove, setShowAbove] = useState( const [showAbove, setShowAbove] = useState(
@@ -56,6 +63,28 @@ export function MessageGroup({
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true", env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true",
); );
const steps = useMemo(() => convertToSteps(messages), [messages]); const steps = useMemo(() => convertToSteps(messages), [messages]);
const debugStepByMessageId = useMemo(
() =>
new Map(
tokenDebugSteps.map(
(step) => [step.messageId || step.id, step] as const,
),
),
[tokenDebugSteps],
);
const toolCallCountByMessageId = useMemo(() => {
const counts = new Map<string, number>();
for (const step of steps) {
if (step.type !== "toolCall" || !step.messageId) {
continue;
}
counts.set(step.messageId, (counts.get(step.messageId) ?? 0) + 1);
}
return counts;
}, [steps]);
const lastToolCallStep = useMemo(() => { const lastToolCallStep = useMemo(() => {
const filteredSteps = steps.filter((step) => step.type === "toolCall"); const filteredSteps = steps.filter((step) => step.type === "toolCall");
return filteredSteps[filteredSteps.length - 1]; return filteredSteps[filteredSteps.length - 1];
@@ -77,6 +106,125 @@ export function MessageGroup({
} }
}, [lastToolCallStep, steps]); }, [lastToolCallStep, steps]);
const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading); const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading);
const firstEligibleDebugSummaryStepIndexByMessageId = useMemo(() => {
const firstIndices = new Map<string, number>();
if (!showTokenDebugSummaries) {
return firstIndices;
}
for (const [index, step] of steps.entries()) {
const messageId = step.messageId;
if (!messageId || firstIndices.has(messageId)) {
continue;
}
const debugStep = debugStepByMessageId.get(messageId);
if (!debugStep) {
continue;
}
const toolCallCount = toolCallCountByMessageId.get(messageId) ?? 0;
if (!debugStep.sharedAttribution && toolCallCount > 0) {
continue;
}
if (
!debugStep.sharedAttribution &&
toolCallCount === 0 &&
debugStep.label === t.common.thinking &&
debugStep.secondaryLabels.length === 0
) {
continue;
}
firstIndices.set(messageId, index);
}
return firstIndices;
}, [
debugStepByMessageId,
showTokenDebugSummaries,
steps,
t.common.thinking,
toolCallCountByMessageId,
]);
const renderDebugSummary = (
messageId: string | undefined,
stepIndex: number,
) => {
if (!showTokenDebugSummaries || !messageId) {
return null;
}
const debugStep = debugStepByMessageId.get(messageId);
if (!debugStep) {
return null;
}
if (
firstEligibleDebugSummaryStepIndexByMessageId.get(messageId) !== stepIndex
) {
return null;
}
return (
<ChainOfThoughtStep
key={`token-debug-${messageId}`}
icon={CoinsIcon}
label={
<DebugStepLabel
label={debugStep.label}
token={formatDebugToken(debugStep, t)}
/>
}
description={
debugStep.sharedAttribution
? t.tokenUsage.sharedAttribution
: undefined
}
>
{debugStep.secondaryLabels.length > 0 && (
<ChainOfThoughtSearchResults>
{debugStep.secondaryLabels.map((label, index) => (
<ChainOfThoughtSearchResult
key={`${debugStep.id}-${index}-${label}`}
>
{label}
</ChainOfThoughtSearchResult>
))}
</ChainOfThoughtSearchResults>
)}
</ChainOfThoughtStep>
);
};
const renderToolCall = (
step: CoTToolCallStep,
options?: { isLast?: boolean },
) => {
const debugStep =
showTokenDebugSummaries && step.messageId
? debugStepByMessageId.get(step.messageId)
: undefined;
return (
<ToolCall
key={step.id}
{...step}
isLast={options?.isLast}
isLoading={isLoading}
tokenDebugStep={
debugStep && !debugStep.sharedAttribution ? debugStep : undefined
}
/>
);
};
const lastReasoningDebugStep =
showTokenDebugSummaries && lastReasoningStep?.messageId
? debugStepByMessageId.get(lastReasoningStep.messageId)
: undefined;
return ( return (
<ChainOfThought <ChainOfThought
className={cn("w-full gap-2 rounded-lg border p-0.5", className)} className={cn("w-full gap-2 rounded-lg border p-0.5", className)}
@@ -111,36 +259,46 @@ export function MessageGroup({
{lastToolCallStep && ( {lastToolCallStep && (
<ChainOfThoughtContent className="px-4 pb-2"> <ChainOfThoughtContent className="px-4 pb-2">
{showAbove && {showAbove &&
aboveLastToolCallSteps.map((step) => aboveLastToolCallSteps.flatMap((step) => {
step.type === "reasoning" ? ( const stepIndex = steps.indexOf(step);
<ChainOfThoughtStep if (step.type === "reasoning") {
key={step.id} return [
label={ renderDebugSummary(step.messageId, stepIndex),
<MarkdownContent <ChainOfThoughtStep
content={step.reasoning ?? ""} key={step.id}
isLoading={isLoading} label={
rehypePlugins={rehypePlugins} <MarkdownContent
/> content={step.reasoning ?? ""}
} isLoading={isLoading}
></ChainOfThoughtStep> rehypePlugins={rehypePlugins}
) : ( />
<ToolCall key={step.id} {...step} isLoading={isLoading} /> }
), ></ChainOfThoughtStep>,
)} ];
}
return [
renderDebugSummary(step.messageId, stepIndex),
renderToolCall(step),
];
})}
{renderDebugSummary(
lastToolCallStep.messageId,
steps.indexOf(lastToolCallStep),
)}
{lastToolCallStep && ( {lastToolCallStep && (
<FlipDisplay uniqueKey={lastToolCallStep.id ?? ""}> <FlipDisplay uniqueKey={lastToolCallStep.id ?? ""}>
<ToolCall {renderToolCall(lastToolCallStep, { isLast: true })}
key={lastToolCallStep.id}
{...lastToolCallStep}
isLast={true}
isLoading={isLoading}
/>
</FlipDisplay> </FlipDisplay>
)} )}
</ChainOfThoughtContent> </ChainOfThoughtContent>
)} )}
{lastReasoningStep && ( {lastReasoningStep && (
<> <>
{renderDebugSummary(
lastReasoningStep.messageId,
steps.indexOf(lastReasoningStep),
)}
<Button <Button
key={lastReasoningStep.id} key={lastReasoningStep.id}
className="w-full items-start justify-start text-left" className="w-full items-start justify-start text-left"
@@ -150,7 +308,22 @@ export function MessageGroup({
<div className="flex w-full items-center justify-between"> <div className="flex w-full items-center justify-between">
<ChainOfThoughtStep <ChainOfThoughtStep
className="font-normal" className="font-normal"
label={t.common.thinking} label={
<DebugStepLabel
label={t.common.thinking}
token={shouldInlineThinkingToken({
debugStep: lastReasoningDebugStep,
toolCallCount: lastReasoningStep.messageId
? (toolCallCountByMessageId.get(
lastReasoningStep.messageId,
) ?? 0)
: 0,
enabled: showTokenDebugSummaries,
thinkingLabel: t.common.thinking,
t,
})}
/>
}
icon={LightbulbIcon} icon={LightbulbIcon}
></ChainOfThoughtStep> ></ChainOfThoughtStep>
<div> <div>
@@ -183,6 +356,60 @@ export function MessageGroup({
); );
} }
function formatDebugToken(
debugStep: TokenDebugStep,
t: ReturnType<typeof useI18n>["t"],
) {
return debugStep.usage
? `${formatTokenCount(debugStep.usage.totalTokens)} ${t.tokenUsage.label}`
: t.tokenUsage.unavailableShort;
}
function shouldInlineThinkingToken({
debugStep,
toolCallCount,
enabled,
thinkingLabel,
t,
}: {
debugStep?: TokenDebugStep;
toolCallCount: number;
enabled: boolean;
thinkingLabel: string;
t: ReturnType<typeof useI18n>["t"];
}) {
if (
!enabled ||
!debugStep ||
debugStep.sharedAttribution ||
toolCallCount > 0 ||
debugStep.label !== thinkingLabel
) {
return null;
}
return formatDebugToken(debugStep, t);
}
function DebugStepLabel({
label,
token,
}: {
label: React.ReactNode;
token?: string | null;
}) {
return (
<div className="flex items-center justify-between gap-3">
<div className="min-w-0 flex-1">{label}</div>
{token ? (
<div className="text-muted-foreground shrink-0 font-mono text-[11px]">
{token}
</div>
) : null}
</div>
);
}
function ToolCall({ function ToolCall({
id, id,
messageId, messageId,
@@ -191,6 +418,7 @@ function ToolCall({
result, result,
isLast = false, isLast = false,
isLoading = false, isLoading = false,
tokenDebugStep,
}: { }: {
id?: string; id?: string;
messageId?: string; messageId?: string;
@@ -199,10 +427,20 @@ function ToolCall({
result?: string | Record<string, unknown>; result?: string | Record<string, unknown>;
isLast?: boolean; isLast?: boolean;
isLoading?: boolean; isLoading?: boolean;
tokenDebugStep?: TokenDebugStep;
}) { }) {
const { t } = useI18n(); const { t } = useI18n();
const { setOpen, autoOpen, autoSelect, selectedArtifact, select } = const { setOpen, autoOpen, autoSelect, selectedArtifact, select } =
useArtifacts(); useArtifacts();
const tokenLabel = tokenDebugStep
? formatDebugToken(tokenDebugStep, t)
: null;
const resolveLabel = (fallback: React.ReactNode) =>
tokenDebugStep ? (
<DebugStepLabel label={tokenDebugStep.label} token={tokenLabel} />
) : (
fallback
);
if (name === "web_search") { if (name === "web_search") {
let label: React.ReactNode = t.toolCalls.searchForRelatedInfo; let label: React.ReactNode = t.toolCalls.searchForRelatedInfo;
@@ -210,7 +448,11 @@ function ToolCall({
label = t.toolCalls.searchOnWebFor(args.query); label = t.toolCalls.searchOnWebFor(args.query);
} }
return ( return (
<ChainOfThoughtStep key={id} label={label} icon={SearchIcon}> <ChainOfThoughtStep
key={id}
label={resolveLabel(label)}
icon={SearchIcon}
>
{Array.isArray(result) && ( {Array.isArray(result) && (
<ChainOfThoughtSearchResults> <ChainOfThoughtSearchResults>
{result.map((item) => ( {result.map((item) => (
@@ -240,7 +482,11 @@ function ToolCall({
} }
)?.results; )?.results;
return ( return (
<ChainOfThoughtStep key={id} label={label} icon={SearchIcon}> <ChainOfThoughtStep
key={id}
label={resolveLabel(label)}
icon={SearchIcon}
>
{Array.isArray(results) && ( {Array.isArray(results) && (
<ChainOfThoughtSearchResults> <ChainOfThoughtSearchResults>
{Array.isArray(results) && {Array.isArray(results) &&
@@ -280,7 +526,7 @@ function ToolCall({
return ( return (
<ChainOfThoughtStep <ChainOfThoughtStep
key={id} key={id}
label={t.toolCalls.viewWebPage} label={resolveLabel(t.toolCalls.viewWebPage)}
icon={GlobeIcon} icon={GlobeIcon}
> >
<ChainOfThoughtSearchResult> <ChainOfThoughtSearchResult>
@@ -305,7 +551,11 @@ function ToolCall({
} }
const path: string | undefined = (args as { path: string })?.path; const path: string | undefined = (args as { path: string })?.path;
return ( return (
<ChainOfThoughtStep key={id} label={description} icon={FolderOpenIcon}> <ChainOfThoughtStep
key={id}
label={resolveLabel(description)}
icon={FolderOpenIcon}
>
{path && ( {path && (
<ChainOfThoughtSearchResult className="cursor-pointer"> <ChainOfThoughtSearchResult className="cursor-pointer">
{path} {path}
@@ -321,7 +571,11 @@ function ToolCall({
} }
const { path } = args as { path: string; content: string }; const { path } = args as { path: string; content: string };
return ( return (
<ChainOfThoughtStep key={id} label={description} icon={BookOpenTextIcon}> <ChainOfThoughtStep
key={id}
label={resolveLabel(description)}
icon={BookOpenTextIcon}
>
{path && ( {path && (
<ChainOfThoughtSearchResult className="cursor-pointer"> <ChainOfThoughtSearchResult className="cursor-pointer">
{path} {path}
@@ -353,7 +607,7 @@ function ToolCall({
<ChainOfThoughtStep <ChainOfThoughtStep
key={id} key={id}
className="cursor-pointer" className="cursor-pointer"
label={description} label={resolveLabel(description)}
icon={NotebookPenIcon} icon={NotebookPenIcon}
onClick={() => { onClick={() => {
select( select(
@@ -375,13 +629,19 @@ function ToolCall({
const description: string | undefined = (args as { description: string }) const description: string | undefined = (args as { description: string })
?.description; ?.description;
if (!description) { if (!description) {
return t.toolCalls.executeCommand; return (
<ChainOfThoughtStep
key={id}
label={resolveLabel(t.toolCalls.executeCommand)}
icon={SquareTerminalIcon}
/>
);
} }
const command: string | undefined = (args as { command: string })?.command; const command: string | undefined = (args as { command: string })?.command;
return ( return (
<ChainOfThoughtStep <ChainOfThoughtStep
key={id} key={id}
label={description} label={resolveLabel(description)}
icon={SquareTerminalIcon} icon={SquareTerminalIcon}
> >
{command && ( {command && (
@@ -398,7 +658,7 @@ function ToolCall({
return ( return (
<ChainOfThoughtStep <ChainOfThoughtStep
key={id} key={id}
label={t.toolCalls.needYourHelp} label={resolveLabel(t.toolCalls.needYourHelp)}
icon={MessageCircleQuestionMarkIcon} icon={MessageCircleQuestionMarkIcon}
></ChainOfThoughtStep> ></ChainOfThoughtStep>
); );
@@ -406,7 +666,7 @@ function ToolCall({
return ( return (
<ChainOfThoughtStep <ChainOfThoughtStep
key={id} key={id}
label={t.toolCalls.writeTodos} label={resolveLabel(t.toolCalls.writeTodos)}
icon={ListTodoIcon} icon={ListTodoIcon}
></ChainOfThoughtStep> ></ChainOfThoughtStep>
); );
@@ -416,7 +676,7 @@ function ToolCall({
return ( return (
<ChainOfThoughtStep <ChainOfThoughtStep
key={id} key={id}
label={description ?? t.toolCalls.useTool(name)} label={resolveLabel(description ?? t.toolCalls.useTool(name))}
icon={WrenchIcon} icon={WrenchIcon}
></ChainOfThoughtStep> ></ChainOfThoughtStep>
); );
@@ -50,7 +50,6 @@ import { cn } from "@/lib/utils";
import { CopyButton } from "../copy-button"; import { CopyButton } from "../copy-button";
import { MarkdownContent } from "./markdown-content"; import { MarkdownContent } from "./markdown-content";
import { MessageTokenUsage } from "./message-token-usage";
function FeedbackButtons({ function FeedbackButtons({
threadId, threadId,
@@ -121,20 +120,20 @@ function FeedbackButtons({
export function MessageListItem({ export function MessageListItem({
className, className,
threadId,
message, message,
isLoading, isLoading,
tokenUsageEnabled = false,
feedback, feedback,
runId, runId,
threadId,
showCopyButton = true,
}: { }: {
className?: string; className?: string;
message: Message; message: Message;
isLoading?: boolean; isLoading?: boolean;
threadId: string; threadId: string;
tokenUsageEnabled?: boolean;
feedback?: FeedbackData | null; feedback?: FeedbackData | null;
runId?: string; runId?: string;
showCopyButton?: boolean;
}) { }) {
const isHuman = message.type === "human"; const isHuman = message.type === "human";
return ( return (
@@ -147,16 +146,17 @@ export function MessageListItem({
message={message} message={message}
isLoading={isLoading} isLoading={isLoading}
threadId={threadId} threadId={threadId}
tokenUsageEnabled={tokenUsageEnabled}
/> />
{!isLoading && ( {!isLoading && showCopyButton && (
<MessageToolbar <MessageToolbar
className={cn( className={cn(
isHuman ? "-bottom-9 justify-end" : "-bottom-8", isHuman
"absolute right-0 left-0 z-20", ? "absolute right-0 -bottom-9 left-0 justify-end"
: "absolute right-0 bottom-0 left-0",
"z-20 opacity-0 transition-opacity delay-200 duration-300 group-hover/conversation-message:opacity-100",
)} )}
> >
<div className="flex gap-1"> <div className="pointer-events-auto flex gap-1">
<CopyButton <CopyButton
clipboardData={ clipboardData={
extractContentFromMessage(message) ?? extractContentFromMessage(message) ??
@@ -213,13 +213,11 @@ function MessageContent_({
message, message,
isLoading = false, isLoading = false,
threadId, threadId,
tokenUsageEnabled = false,
}: { }: {
className?: string; className?: string;
message: Message; message: Message;
isLoading?: boolean; isLoading?: boolean;
threadId: string; threadId: string;
tokenUsageEnabled?: boolean;
}) { }) {
const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading); const rehypePlugins = useRehypeSplitWordsIntoSpans(isLoading);
const isHuman = message.type === "human"; const isHuman = message.type === "human";
@@ -297,11 +295,6 @@ function MessageContent_({
<ReasoningTrigger /> <ReasoningTrigger />
<ReasoningContent>{reasoningContent}</ReasoningContent> <ReasoningContent>{reasoningContent}</ReasoningContent>
</Reasoning> </Reasoning>
<MessageTokenUsage
enabled={tokenUsageEnabled}
isLoading={isLoading}
message={message}
/>
</AIElementMessageContent> </AIElementMessageContent>
); );
} }
@@ -339,11 +332,6 @@ function MessageContent_({
className="my-3" className="my-3"
components={components} components={components}
/> />
<MessageTokenUsage
enabled={tokenUsageEnabled}
isLoading={isLoading}
message={message}
/>
</AIElementMessageContent> </AIElementMessageContent>
); );
} }
@@ -1,6 +1,7 @@
import type { Message } from "@langchain/langgraph-sdk";
import type { BaseStream } from "@langchain/langgraph-sdk/react"; import type { BaseStream } from "@langchain/langgraph-sdk/react";
import { ChevronUpIcon, Loader2Icon } from "lucide-react"; import { ChevronUpIcon, Loader2Icon } from "lucide-react";
import { useCallback, useEffect, useRef } from "react"; import { useCallback, useEffect, useMemo, useRef } from "react";
import { import {
Conversation, Conversation,
@@ -8,15 +9,20 @@ import {
} from "@/components/ai-elements/conversation"; } from "@/components/ai-elements/conversation";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import {
buildTokenDebugSteps,
type TokenUsageInlineMode,
} from "@/core/messages/usage-model";
import { import {
extractContentFromMessage, extractContentFromMessage,
extractPresentFilesFromMessage, extractPresentFilesFromMessage,
extractReasoningContentFromMessage,
extractTextFromMessage, extractTextFromMessage,
groupMessages, getAssistantTurnUsageMessages,
getMessageGroups,
hasContent, hasContent,
hasPresentFiles, hasPresentFiles,
hasReasoning, hasReasoning,
hasToolCalls,
} from "@/core/messages/utils"; } from "@/core/messages/utils";
import { useRehypeSplitWordsIntoSpans } from "@/core/rehype"; import { useRehypeSplitWordsIntoSpans } from "@/core/rehype";
import type { Subtask } from "@/core/tasks"; import type { Subtask } from "@/core/tasks";
@@ -25,12 +31,16 @@ import type { AgentThreadState } from "@/core/threads";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { ArtifactFileList } from "../artifacts/artifact-file-list"; import { ArtifactFileList } from "../artifacts/artifact-file-list";
import { CopyButton } from "../copy-button";
import { StreamingIndicator } from "../streaming-indicator"; import { StreamingIndicator } from "../streaming-indicator";
import { MarkdownContent } from "./markdown-content"; import { MarkdownContent } from "./markdown-content";
import { MessageGroup } from "./message-group"; import { MessageGroup } from "./message-group";
import { MessageListItem } from "./message-list-item"; import { MessageListItem } from "./message-list-item";
import { MessageTokenUsageList } from "./message-token-usage"; import {
MessageTokenUsageDebugList,
MessageTokenUsageList,
} from "./message-token-usage";
import { MessageListSkeleton } from "./skeleton"; import { MessageListSkeleton } from "./skeleton";
import { SubtaskCard } from "./subtask-card"; import { SubtaskCard } from "./subtask-card";
@@ -149,7 +159,7 @@ export function MessageList({
threadId, threadId,
thread, thread,
paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM, paddingBottom = MESSAGE_LIST_DEFAULT_PADDING_BOTTOM,
tokenUsageEnabled = false, tokenUsageInlineMode = "off",
hasMoreHistory, hasMoreHistory,
loadMoreHistory, loadMoreHistory,
isHistoryLoading, isHistoryLoading,
@@ -158,7 +168,7 @@ export function MessageList({
threadId: string; threadId: string;
thread: BaseStream<AgentThreadState>; thread: BaseStream<AgentThreadState>;
paddingBottom?: number; paddingBottom?: number;
tokenUsageEnabled?: boolean; tokenUsageInlineMode?: TokenUsageInlineMode;
hasMoreHistory?: boolean; hasMoreHistory?: boolean;
loadMoreHistory?: () => void; loadMoreHistory?: () => void;
isHistoryLoading?: boolean; isHistoryLoading?: boolean;
@@ -167,10 +177,85 @@ export function MessageList({
const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading); const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading);
const updateSubtask = useUpdateSubtask(); const updateSubtask = useUpdateSubtask();
const messages = thread.messages; const messages = thread.messages;
const groupedMessages = getMessageGroups(messages);
const turnUsageMessagesByGroupIndex =
getAssistantTurnUsageMessages(groupedMessages);
const tokenDebugSteps = useMemo(
() => buildTokenDebugSteps(messages, t),
[messages, t],
);
const renderAssistantCopyButton = useCallback((messages: Message[]) => {
const clipboardData = [...messages]
.reverse()
.filter((message) => message.type === "ai")
.map((message) => {
const content = extractContentFromMessage(message);
return content ?? extractReasoningContentFromMessage(message) ?? "";
})
.find((content) => content.length > 0);
if (!clipboardData) {
return null;
}
return (
<div className="mt-2 flex justify-start opacity-0 transition-opacity delay-200 duration-300 group-hover/assistant-turn:opacity-100">
<CopyButton clipboardData={clipboardData} />
</div>
);
}, []);
const renderTokenUsage = useCallback(
({
messages,
turnUsageMessages,
inlineDebug = true,
debugMessageIds,
}: {
messages: Message[];
turnUsageMessages?: Message[] | null;
inlineDebug?: boolean;
debugMessageIds?: string[];
}) => {
if (tokenUsageInlineMode === "per_turn") {
return (
<MessageTokenUsageList
enabled={true}
isLoading={thread.isLoading}
messages={turnUsageMessages ?? []}
/>
);
}
if (tokenUsageInlineMode === "step_debug" && inlineDebug) {
const messageIds = new Set(
debugMessageIds ??
messages
.filter((message) => message.type === "ai")
.map((message) => message.id)
.filter((id): id is string => typeof id === "string"),
);
return (
<MessageTokenUsageDebugList
enabled={true}
isLoading={thread.isLoading}
steps={tokenDebugSteps.filter((step) =>
messageIds.has(step.messageId),
)}
/>
);
}
return null;
},
[thread.isLoading, tokenDebugSteps, tokenUsageInlineMode],
);
if (thread.isThreadLoading && messages.length === 0) { if (thread.isThreadLoading && messages.length === 0) {
return <MessageListSkeleton />; return <MessageListSkeleton />;
} }
return ( return (
<Conversation <Conversation
className={cn("flex size-full flex-col justify-center", className)} className={cn("flex size-full flex-col justify-center", className)}
@@ -181,19 +266,37 @@ export function MessageList({
hasMore={hasMoreHistory} hasMore={hasMoreHistory}
loadMore={loadMoreHistory} loadMore={loadMoreHistory}
/> />
{groupMessages(messages, (group) => { {groupedMessages.map((group, groupIndex) => {
const turnUsageMessages = turnUsageMessagesByGroupIndex[groupIndex];
if (group.type === "human" || group.type === "assistant") { if (group.type === "human" || group.type === "assistant") {
return group.messages.map((msg) => { return (
return ( <div
<MessageListItem key={group.id}
key={`${group.id}/${msg.id}`} className={cn(
threadId={threadId} "w-full",
message={msg} group.type === "assistant" && "group/assistant-turn",
isLoading={thread.isLoading} )}
tokenUsageEnabled={tokenUsageEnabled} >
/> {group.messages.map((msg) => {
); return (
}); <MessageListItem
key={`${group.id}/${msg.id}`}
message={msg}
isLoading={thread.isLoading}
threadId={threadId}
showCopyButton={group.type !== "assistant"}
/>
);
})}
{renderTokenUsage({
messages: group.messages,
turnUsageMessages,
})}
{group.type === "assistant" &&
renderAssistantCopyButton(group.messages)}
</div>
);
} else if (group.type === "assistant:clarification") { } else if (group.type === "assistant:clarification") {
const message = group.messages[0]; const message = group.messages[0];
if (message && hasContent(message)) { if (message && hasContent(message)) {
@@ -204,11 +307,10 @@ export function MessageList({
isLoading={thread.isLoading} isLoading={thread.isLoading}
rehypePlugins={rehypePlugins} rehypePlugins={rehypePlugins}
/> />
<MessageTokenUsageList {renderTokenUsage({
enabled={tokenUsageEnabled} messages: group.messages,
isLoading={thread.isLoading} turnUsageMessages,
messages={group.messages} })}
/>
</div> </div>
); );
} }
@@ -232,11 +334,10 @@ export function MessageList({
/> />
)} )}
<ArtifactFileList files={files} threadId={threadId} /> <ArtifactFileList files={files} threadId={threadId} />
<MessageTokenUsageList {renderTokenUsage({
enabled={tokenUsageEnabled} messages: group.messages,
isLoading={thread.isLoading} turnUsageMessages,
messages={group.messages} })}
/>
</div> </div>
); );
} else if (group.type === "assistant:subagent") { } else if (group.type === "assistant:subagent") {
@@ -289,7 +390,19 @@ export function MessageList({
} }
} }
} }
const results: React.ReactNode[] = []; const results: React.ReactNode[] = [];
const subagentDebugMessageIds: string[] = [];
if (tasks.size > 0) {
results.push(
<div
key="subtask-count"
className="text-muted-foreground pt-2 text-sm font-normal"
>
{t.subtasks.executing(tasks.size)}
</div>,
);
}
for (const message of group.messages.filter( for (const message of group.messages.filter(
(message) => message.type === "ai", (message) => message.type === "ai",
)) { )) {
@@ -299,17 +412,17 @@ export function MessageList({
key={"thinking-group-" + message.id} key={"thinking-group-" + message.id}
messages={[message]} messages={[message]}
isLoading={thread.isLoading} isLoading={thread.isLoading}
tokenDebugSteps={tokenDebugSteps.filter(
(step) => step.messageId === message.id,
)}
showTokenDebugSummaries={
tokenUsageInlineMode === "step_debug"
}
/>, />,
); );
} else if (message.id) {
subagentDebugMessageIds.push(message.id);
} }
results.push(
<div
key="subtask-count"
className="text-muted-foreground font-norma pt-2 text-sm"
>
{t.subtasks.executing(tasks.size)}
</div>,
);
const taskIds = message.tool_calls const taskIds = message.tool_calls
?.filter((toolCall) => toolCall.name === "task") ?.filter((toolCall) => toolCall.name === "task")
.map((toolCall) => toolCall.id); .map((toolCall) => toolCall.id);
@@ -329,30 +442,31 @@ export function MessageList({
className="relative z-1 flex flex-col gap-2" className="relative z-1 flex flex-col gap-2"
> >
{results} {results}
<MessageTokenUsageList {renderTokenUsage({
enabled={tokenUsageEnabled} messages: group.messages,
isLoading={thread.isLoading} turnUsageMessages,
messages={group.messages} debugMessageIds: subagentDebugMessageIds,
/> })}
</div> </div>
); );
} }
const tokenUsageMessages = group.messages.filter(
(message) =>
message.type === "ai" &&
(hasToolCalls(message) ? true : !hasContent(message)),
);
return ( return (
<div key={"group-" + group.id} className="w-full"> <div key={"group-" + group.id} className="w-full">
<MessageGroup <MessageGroup
messages={group.messages} messages={group.messages}
isLoading={thread.isLoading} isLoading={thread.isLoading}
tokenDebugSteps={tokenDebugSteps.filter((step) =>
group.messages.some(
(message) => message.id === step.messageId,
),
)}
showTokenDebugSummaries={tokenUsageInlineMode === "step_debug"}
/> />
<MessageTokenUsageList {renderTokenUsage({
enabled={tokenUsageEnabled} messages: group.messages,
isLoading={thread.isLoading} turnUsageMessages,
messages={tokenUsageMessages} inlineDebug: false,
/> })}
</div> </div>
); );
})} })}
@@ -1,29 +1,27 @@
import type { Message } from "@langchain/langgraph-sdk"; import type { Message } from "@langchain/langgraph-sdk";
import { CoinsIcon } from "lucide-react"; import { CoinsIcon } from "lucide-react";
import { Badge } from "@/components/ui/badge";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { formatTokenCount, getUsageMetadata } from "@/core/messages/usage"; import { accumulateUsage, formatTokenCount } from "@/core/messages/usage";
import type { TokenDebugStep } from "@/core/messages/usage-model";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
export function MessageTokenUsage({ function TokenUsageSummary({
className, className,
enabled = false, inputTokens,
isLoading = false, outputTokens,
message, totalTokens,
unavailable = false,
}: { }: {
className?: string; className?: string;
enabled?: boolean; inputTokens?: number;
isLoading?: boolean; outputTokens?: number;
message: Message; totalTokens?: number;
unavailable?: boolean;
}) { }) {
const { t } = useI18n(); const { t } = useI18n();
if (!enabled || isLoading || message.type !== "ai") {
return null;
}
const usage = getUsageMetadata(message);
return ( return (
<div <div
className={cn( className={cn(
@@ -35,16 +33,16 @@ export function MessageTokenUsage({
<CoinsIcon className="size-3" /> <CoinsIcon className="size-3" />
{t.tokenUsage.label} {t.tokenUsage.label}
</span> </span>
{usage ? ( {!unavailable ? (
<> <>
<span> <span>
{t.tokenUsage.input}: {formatTokenCount(usage.inputTokens)} {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
</span> </span>
<span> <span>
{t.tokenUsage.output}: {formatTokenCount(usage.outputTokens)} {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
</span> </span>
<span className="font-medium"> <span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(usage.totalTokens)} {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span> </span>
</> </>
) : ( ) : (
@@ -75,17 +73,93 @@ export function MessageTokenUsageList({
return null; return null;
} }
const usage = accumulateUsage(aiMessages);
return ( return (
<> <TokenUsageSummary
{aiMessages.map((message, index) => ( className={className}
<MessageTokenUsage inputTokens={usage?.inputTokens}
className={className} outputTokens={usage?.outputTokens}
enabled={enabled} totalTokens={usage?.totalTokens}
isLoading={isLoading} unavailable={!usage}
key={message.id ?? index} />
message={message} );
/> }
))}
</> export function MessageTokenUsageDebugList({
className,
enabled = false,
isLoading = false,
steps,
}: {
className?: string;
enabled?: boolean;
isLoading?: boolean;
steps: TokenDebugStep[];
}) {
const { t } = useI18n();
if (!enabled || isLoading) {
return null;
}
if (steps.length === 0) {
return null;
}
return (
<div className={cn("border-border/60 mt-1 border-t pt-2", className)}>
<div className="space-y-2">
{steps.map((step) => (
<div
key={step.id}
className="bg-muted/30 border-border/50 flex items-start justify-between gap-3 rounded-md border px-3 py-2"
>
<div className="min-w-0 flex-1 space-y-1">
<div className="text-foreground flex items-center gap-2 text-xs font-medium">
<CoinsIcon className="text-muted-foreground size-3" />
<span className="truncate">{step.label}</span>
</div>
{step.secondaryLabels.length > 0 && (
<div className="flex flex-wrap gap-1.5">
{step.secondaryLabels.map((label, index) => (
<Badge
key={`${step.id}-${index}-${label}`}
className="px-1.5 py-0 text-[10px] font-normal"
variant="secondary"
>
{label}
</Badge>
))}
</div>
)}
{step.sharedAttribution && (
<div className="text-muted-foreground text-[11px]">
{t.tokenUsage.sharedAttribution}
</div>
)}
<div className="text-muted-foreground text-[11px]">
{step.usage ? (
<>
{t.tokenUsage.input}:{" "}
{formatTokenCount(step.usage.inputTokens)}
{" · "}
{t.tokenUsage.output}:{" "}
{formatTokenCount(step.usage.outputTokens)}
</>
) : (
t.tokenUsage.unavailableShort
)}
</div>
</div>
<Badge className="shrink-0 font-mono" variant="outline">
{step.usage
? `${formatTokenCount(step.usage.totalTokens)} ${t.tokenUsage.label}`
: t.tokenUsage.unavailableShort}
</Badge>
</div>
))}
</div>
</div>
); );
} }
@@ -1,60 +1,81 @@
"use client"; "use client";
import type { Message } from "@langchain/langgraph-sdk"; import type { Message } from "@langchain/langgraph-sdk";
import { CoinsIcon } from "lucide-react"; import { ChevronDownIcon, CoinsIcon } from "lucide-react";
import { useMemo } from "react"; import { useMemo } from "react";
import { Button } from "@/components/ui/button";
import { import {
Tooltip, DropdownMenu,
TooltipContent, DropdownMenuContent,
TooltipTrigger, DropdownMenuLabel,
} from "@/components/ui/tooltip"; DropdownMenuRadioGroup,
DropdownMenuRadioItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { accumulateUsage, formatTokenCount } from "@/core/messages/usage"; import { accumulateUsage, formatTokenCount } from "@/core/messages/usage";
import {
getTokenUsageViewPreset,
tokenUsagePreferencesFromPreset,
type TokenUsagePreferences,
type TokenUsageViewPreset,
} from "@/core/messages/usage-model";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
interface TokenUsageIndicatorProps { interface TokenUsageIndicatorProps {
messages: Message[]; messages: Message[];
enabled?: boolean; enabled?: boolean;
preferences: TokenUsagePreferences;
onPreferencesChange: (preferences: TokenUsagePreferences) => void;
className?: string; className?: string;
} }
export function TokenUsageIndicator({ export function TokenUsageIndicator({
messages, messages,
enabled = false, enabled = false,
preferences,
onPreferencesChange,
className, className,
}: TokenUsageIndicatorProps) { }: TokenUsageIndicatorProps) {
const { t } = useI18n(); const { t } = useI18n();
const usage = useMemo(() => accumulateUsage(messages), [messages]); const usage = useMemo(() => accumulateUsage(messages), [messages]);
const preset = getTokenUsageViewPreset(preferences);
if (!enabled) { if (!enabled) {
return null; return null;
} }
return ( return (
<Tooltip delayDuration={200}> <DropdownMenu>
<TooltipTrigger asChild> <DropdownMenuTrigger asChild>
<button <Button
type="button" type="button"
variant="ghost"
className={cn( className={cn(
"text-muted-foreground bg-background/70 flex cursor-default items-center gap-1.5 rounded-full border px-2 py-1 text-xs", "text-muted-foreground bg-background/70 hover:bg-background/90 flex h-auto items-center gap-1.5 rounded-full border px-2 py-1 text-xs font-normal",
!usage && "opacity-60",
className, className,
)} )}
> >
<CoinsIcon size={14} /> <CoinsIcon size={14} />
<span>{t.tokenUsage.label}</span> <span>{t.tokenUsage.label}</span>
<span className="font-mono"> <span className="font-mono">
{usage ? formatTokenCount(usage.totalTokens) : "-"} {preferences.headerTotal
? usage
? formatTokenCount(usage.totalTokens)
: "-"
: t.tokenUsage.presets[presetKeyToTranslationKey(preset)]}
</span> </span>
</button> <ChevronDownIcon className="size-3" />
</TooltipTrigger> </Button>
<TooltipContent side="bottom" align="end"> </DropdownMenuTrigger>
<div className="space-y-1 text-xs"> <DropdownMenuContent side="bottom" align="end" className="w-80">
<div className="font-medium">{t.tokenUsage.title}</div> <DropdownMenuLabel>{t.tokenUsage.title}</DropdownMenuLabel>
<div className="px-2 py-1 text-xs">
{usage ? ( {usage ? (
<> <div className="space-y-1">
<div className="flex justify-between gap-4"> <div className="flex justify-between gap-4">
<span>{t.tokenUsage.input}</span> <span>{t.tokenUsage.input}</span>
<span className="font-mono"> <span className="font-mono">
@@ -75,14 +96,53 @@ export function TokenUsageIndicator({
</span> </span>
</div> </div>
</div> </div>
</> </div>
) : ( ) : (
<div className="text-muted-foreground max-w-56"> <div className="text-muted-foreground">
{t.tokenUsage.unavailable} {t.tokenUsage.unavailable}
</div> </div>
)} )}
</div> </div>
</TooltipContent> <DropdownMenuSeparator />
</Tooltip> <DropdownMenuLabel>{t.tokenUsage.view}</DropdownMenuLabel>
<DropdownMenuRadioGroup
value={preset}
onValueChange={(value) =>
onPreferencesChange(
tokenUsagePreferencesFromPreset(value as TokenUsageViewPreset),
)
}
>
{(
["off", "summary", "per_turn", "debug"] as TokenUsageViewPreset[]
).map((value) => {
const translationKey = presetKeyToTranslationKey(value);
return (
<DropdownMenuRadioItem key={value} value={value}>
<div className="grid gap-0.5">
<span>{t.tokenUsage.presets[translationKey]}</span>
<span className="text-muted-foreground text-xs">
{t.tokenUsage.presetDescriptions[translationKey]}
</span>
</div>
</DropdownMenuRadioItem>
);
})}
</DropdownMenuRadioGroup>
<DropdownMenuSeparator />
<div className="text-muted-foreground px-2 py-2 text-xs leading-relaxed">
{t.tokenUsage.note}
</div>
</DropdownMenuContent>
</DropdownMenu>
); );
} }
// Maps a view preset key to its i18n translation key. Only "per_turn"
// differs (translations use camelCase "perTurn"); every other preset key
// is used verbatim.
function presetKeyToTranslationKey(preset: TokenUsageViewPreset) {
  return preset === "per_turn" ? ("perTurn" as const) : preset;
}
+23
View File
@@ -306,9 +306,32 @@ export const enUS: Translations = {
input: "Input", input: "Input",
output: "Output", output: "Output",
total: "Total", total: "Total",
view: "Display",
unavailable: unavailable:
"No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.", "No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.",
unavailableShort: "No usage returned", unavailableShort: "No usage returned",
note: "Shown from provider-returned usage_metadata. Totals are best-effort conversation totals and may differ from provider billing pages.",
presets: {
off: "Off",
summary: "Summary",
perTurn: "Per turn",
debug: "Debug",
},
presetDescriptions: {
off: "Hide token usage in the header and conversation.",
summary: "Show only the current conversation total in the header.",
perTurn:
"Show the header total and one token summary per assistant turn.",
debug: "Show the header total and step-level token debugging details.",
},
finalAnswer: "Final answer",
stepTotal: "Step total",
sharedAttribution: "Shared across multiple actions in this step",
subagent: (description: string) => `Subagent: ${description}`,
startTodo: (content: string) => `Start To-do: ${content}`,
completeTodo: (content: string) => `Complete To-do: ${content}`,
updateTodo: (content: string) => `Update To-do: ${content}`,
removeTodo: (content: string) => `Remove To-do: ${content}`,
}, },
// Shortcuts // Shortcuts
+22
View File
@@ -236,8 +236,30 @@ export interface Translations {
input: string; input: string;
output: string; output: string;
total: string; total: string;
view: string;
unavailable: string; unavailable: string;
unavailableShort: string; unavailableShort: string;
note: string;
presets: {
off: string;
summary: string;
perTurn: string;
debug: string;
};
presetDescriptions: {
off: string;
summary: string;
perTurn: string;
debug: string;
};
finalAnswer: string;
stepTotal: string;
sharedAttribution: string;
subagent: (description: string) => string;
startTodo: (content: string) => string;
completeTodo: (content: string) => string;
updateTodo: (content: string) => string;
removeTodo: (content: string) => string;
}; };
// Shortcuts // Shortcuts
+22
View File
@@ -292,9 +292,31 @@ export const zhCN: Translations = {
input: "输入", input: "输入",
output: "输出", output: "输出",
total: "总计", total: "总计",
view: "显示方式",
unavailable: unavailable:
"暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。", "暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。",
unavailableShort: "未返回用量", unavailableShort: "未返回用量",
note: "基于供应商返回的 usage_metadata 展示。当前总量是 best-effort 的会话参考值,可能与平台账单页不完全一致。",
presets: {
off: "关闭",
summary: "总览",
perTurn: "每轮",
debug: "调试",
},
presetDescriptions: {
off: "隐藏顶部和会话内的 token 展示。",
summary: "只在顶部显示当前对话累计 token。",
perTurn: "显示顶部累计,并为每轮 assistant 回复显示一条汇总 token。",
debug: "显示顶部累计,并展示按步骤归类的 token 调试信息。",
},
finalAnswer: "最终回复",
stepTotal: "步骤总计",
sharedAttribution: "该 token 由此步骤中的多个动作共同消耗",
subagent: (description: string) => `子任务:${description}`,
startTodo: (content: string) => `开始 To-do${content}`,
completeTodo: (content: string) => `完成 To-do${content}`,
updateTodo: (content: string) => `更新 To-do${content}`,
removeTodo: (content: string) => `移除 To-do${content}`,
}, },
// Shortcuts // Shortcuts
+440
View File
@@ -0,0 +1,440 @@
import type { Message } from "@langchain/langgraph-sdk";
import type { Translations } from "@/core/i18n/locales/types";
import { getUsageMetadata, type TokenUsage } from "./usage";
import { hasContent } from "./utils";
// How token usage is rendered inline in the conversation body:
// hidden, one summary per assistant turn, or step-by-step debug detail.
export type TokenUsageInlineMode = "off" | "per_turn" | "step_debug";
// Persisted user preferences that control token usage display.
export interface TokenUsagePreferences {
  // Whether the running conversation total is shown in the header.
  headerTotal: boolean;
  // Inline rendering mode for usage inside the conversation.
  inlineMode: TokenUsageInlineMode;
}
// The four user-facing display presets derived from TokenUsagePreferences.
export type TokenUsageViewPreset = "off" | "summary" | "per_turn" | "debug";
// One row in the step-level token usage debug view.
export interface TokenDebugStep {
  // Render key; falls back to a synthetic index-based id when the
  // message has no id.
  id: string;
  messageId: string;
  // Primary label: a single action description, or a "step total" heading.
  label: string;
  // Individual action labels shown when several actions share one usage figure.
  secondaryLabels: string[];
  // Provider-reported usage for this step, if any was returned.
  usage: TokenUsage | null;
  // True when the usage figure is shared across multiple actions.
  sharedAttribution: boolean;
}
// One attributed action inside a step, as emitted by the backend
// token-usage middleware. Discriminated on `kind`.
type TokenUsageAttributionAction =
  | {
      kind: "todo_start" | "todo_complete" | "todo_update" | "todo_remove";
      content?: string;
      tool_call_id?: string;
    }
  | {
      kind: "subagent";
      description?: string | null;
      subagent_type?: string | null;
      tool_call_id?: string;
    }
  | {
      kind: "search";
      query?: string | null;
      tool_name?: string | null;
      tool_call_id?: string;
    }
  | {
      kind: "present_files" | "clarification";
      tool_call_id?: string;
    }
  | {
      kind: "tool";
      tool_name?: string | null;
      description?: string | null;
      tool_call_id?: string;
    };
// Backend-provided attribution payload, read from
// additional_kwargs.token_usage_attribution on an AI message. All fields
// are optional because the payload is untrusted and validated at runtime.
interface TokenUsageAttribution {
  version?: number;
  kind?:
    | "thinking"
    | "final_answer"
    | "tool_batch"
    | "todo_update"
    | "subagent_dispatch";
  shared_attribution?: boolean;
  tool_call_ids?: string[];
  actions?: TokenUsageAttributionAction[];
}
// Precise write_todos labels come from the backend attribution payload.
// The frontend fallback intentionally stays generic so that we do not
// duplicate the todo-diffing logic in
// backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py::_build_todo_actions
// and risk the two diffing algorithms drifting apart.
/**
 * Derives the active display preset from persisted preferences.
 *
 * When inline display is off, the header flag decides between "off" and
 * "summary"; otherwise the inline mode alone picks "debug" or "per_turn"
 * (the header flag is not consulted for inline modes).
 */
export function getTokenUsageViewPreset(
  preferences: TokenUsagePreferences,
): TokenUsageViewPreset {
  const { headerTotal, inlineMode } = preferences;
  if (inlineMode === "off") {
    return headerTotal ? "summary" : "off";
  }
  return inlineMode === "step_debug" ? "debug" : "per_turn";
}
export function tokenUsagePreferencesFromPreset(
preset: TokenUsageViewPreset,
): TokenUsagePreferences {
switch (preset) {
case "off":
return { headerTotal: false, inlineMode: "off" };
case "summary":
return { headerTotal: true, inlineMode: "off" };
case "debug":
return { headerTotal: true, inlineMode: "step_debug" };
case "per_turn":
default:
return { headerTotal: true, inlineMode: "per_turn" };
}
}
/**
 * Builds one debug row per AI message for the step-level token view.
 *
 * Labeling strategy, in priority order:
 * 1. Backend attribution payload (additional_kwargs.token_usage_attribution)
 *    when present and it yields at least one label.
 * 2. Frontend fallback derived from the message's raw tool calls.
 * 3. "Final answer" when the message has content, otherwise "Thinking".
 *
 * Non-AI messages are skipped entirely; `index` only seeds a synthetic
 * id for messages without one.
 */
export function buildTokenDebugSteps(
  messages: Message[],
  t: Translations,
): TokenDebugStep[] {
  const steps: TokenDebugStep[] = [];
  for (const [index, message] of messages.entries()) {
    if (message.type !== "ai") {
      continue;
    }
    const usage = getUsageMetadata(message);
    const attribution = getTokenUsageAttribution(message);
    const actionLabels: string[] = [];
    if (attribution) {
      actionLabels.push(...buildActionLabelsFromAttribution(attribution, t));
      // No per-action labels: fall back to the attribution's own kind.
      if (actionLabels.length === 0) {
        if (attribution.kind === "final_answer") {
          actionLabels.push(t.tokenUsage.finalAnswer);
        } else if (attribution.kind === "thinking") {
          actionLabels.push(t.common.thinking);
        }
      }
      if (actionLabels.length > 0) {
        // The backend may state sharing explicitly; otherwise infer it
        // from having more than one labeled action on this message.
        const sharedAttribution =
          attribution.shared_attribution ?? actionLabels.length > 1;
        steps.push({
          id: message.id ?? `token-step-${index}`,
          messageId: message.id ?? `token-step-${index}`,
          label:
            sharedAttribution && actionLabels.length > 1
              ? t.tokenUsage.stepTotal
              : actionLabels[0]!,
          secondaryLabels:
            sharedAttribution && actionLabels.length > 1 ? actionLabels : [],
          usage,
          sharedAttribution,
        });
        continue;
      }
      // Attribution produced nothing usable: fall through to the
      // tool-call based fallback below.
    }
    for (const toolCall of message.tool_calls ?? []) {
      const toolArgs = (toolCall.args ?? {}) as Record<string, unknown>;
      // write_todos deliberately keeps a generic label here; precise todo
      // diff labels come only from the backend attribution payload.
      if (toolCall.name === "write_todos") {
        actionLabels.push(t.toolCalls.writeTodos);
        continue;
      }
      actionLabels.push(
        describeToolCall(
          {
            name: toolCall.name,
            args: toolArgs,
          },
          t,
        ),
      );
    }
    if (actionLabels.length === 0) {
      if (hasContent(message)) {
        actionLabels.push(t.tokenUsage.finalAnswer);
      } else {
        actionLabels.push(t.common.thinking);
      }
    }
    // Fallback semantics differ from the attribution branch on purpose:
    // multiple fallback labels always collapse under "step total".
    steps.push({
      id: message.id ?? `token-step-${index}`,
      messageId: message.id ?? `token-step-${index}`,
      label:
        actionLabels.length === 1 ? actionLabels[0]! : t.tokenUsage.stepTotal,
      secondaryLabels: actionLabels.length > 1 ? actionLabels : [],
      usage,
      sharedAttribution: actionLabels.length > 1,
    });
  }
  return steps;
}
/**
 * Reads and validates the backend attribution payload attached to an AI
 * message under additional_kwargs.token_usage_attribution. Returns null
 * for non-AI messages, missing kwargs, or payloads that fail validation.
 */
function getTokenUsageAttribution(
  message: Message,
): TokenUsageAttribution | null {
  if (message.type !== "ai") {
    return null;
  }
  const kwargs = message.additional_kwargs;
  if (!kwargs || typeof kwargs !== "object") {
    return null;
  }
  const raw = (kwargs as Record<string, unknown>).token_usage_attribution;
  return normalizeTokenUsageAttribution(raw);
}
// Turns every recognizable attribution action into a display label,
// dropping actions that produce no label.
function buildActionLabelsFromAttribution(
  attribution: TokenUsageAttribution,
  t: Translations,
): string[] {
  const labels: string[] = [];
  for (const action of attribution.actions ?? []) {
    const label = describeAttributionAction(action, t);
    if (label) {
      labels.push(label);
    }
  }
  return labels;
}
/**
 * Produces a human-readable label for a single attribution action, or
 * null when the action kind is unknown. Todo actions without content fall
 * back to the generic "write todos" label.
 */
function describeAttributionAction(
  action: TokenUsageAttributionAction,
  t: Translations,
): string | null {
  const { kind } = action;
  if (
    kind === "todo_start" ||
    kind === "todo_complete" ||
    kind === "todo_update" ||
    kind === "todo_remove"
  ) {
    const content = action.content;
    if (!content) {
      return t.toolCalls.writeTodos;
    }
    if (kind === "todo_start") {
      return t.tokenUsage.startTodo(content);
    }
    if (kind === "todo_complete") {
      return t.tokenUsage.completeTodo(content);
    }
    if (kind === "todo_update") {
      return t.tokenUsage.updateTodo(content);
    }
    return t.tokenUsage.removeTodo(content);
  }
  if (kind === "subagent") {
    return t.tokenUsage.subagent(action.description ?? t.subtasks.subtask);
  }
  if (kind === "search") {
    return action.query
      ? t.toolCalls.searchFor(action.query)
      : t.toolCalls.useTool(action.tool_name ?? "search");
  }
  if (kind === "present_files") {
    return t.toolCalls.presentFiles;
  }
  if (kind === "clarification") {
    return t.toolCalls.needYourHelp;
  }
  if (kind === "tool") {
    return describeToolCall(
      {
        name: action.tool_name ?? "tool",
        args: action.description ? { description: action.description } : {},
      },
      t,
    );
  }
  return null;
}
/**
 * Fallback labeler for a raw tool call when no backend attribution is
 * available. Known tool names get tailored labels; anything else uses the
 * call's string `description` arg if present, else a generic label.
 */
function describeToolCall(
  toolCall: {
    name: string;
    args: Record<string, unknown>;
  },
  t: Translations,
): string {
  const { name, args } = toolCall;
  switch (name) {
    case "task": {
      const description =
        typeof args.description === "string"
          ? args.description
          : t.subtasks.subtask;
      return t.tokenUsage.subagent(description);
    }
    case "web_search":
    case "image_search":
      if (typeof args.query === "string") {
        return t.toolCalls.searchFor(args.query);
      }
      // Search call without a string query: use the generic fallback.
      break;
    case "web_fetch":
      return t.toolCalls.viewWebPage;
    case "present_files":
      return t.toolCalls.presentFiles;
    case "ask_clarification":
      return t.toolCalls.needYourHelp;
  }
  return typeof args.description === "string"
    ? args.description
    : t.toolCalls.useTool(name);
}
/**
 * Validates an untrusted attribution payload. Returns null when the value
 * is not a plain object or when `actions` is present but not an array;
 * unknown fields are dropped and invalid entries are filtered out.
 */
function normalizeTokenUsageAttribution(
  value: unknown,
): TokenUsageAttribution | null {
  const record = asRecord(value);
  if (!record) {
    return null;
  }
  const rawActions = record.actions;
  if (rawActions !== undefined && !Array.isArray(rawActions)) {
    return null;
  }
  const actions = Array.isArray(rawActions)
    ? rawActions
        .map((action) => normalizeTokenUsageAttributionAction(action))
        .filter(
          (action): action is TokenUsageAttributionAction => action !== null,
        )
    : undefined;
  const toolCallIds = Array.isArray(record.tool_call_ids)
    ? record.tool_call_ids.filter(
        (toolCallId): toolCallId is string =>
          typeof toolCallId === "string" && toolCallId.trim().length > 0,
      )
    : undefined;
  // Versioning is additive for now: unknown fields are ignored and the
  // caller falls back when required fields become incompatible.
  return {
    version: typeof record.version === "number" ? record.version : undefined,
    kind: isTokenUsageAttributionKind(record.kind) ? record.kind : undefined,
    shared_attribution:
      typeof record.shared_attribution === "boolean"
        ? record.shared_attribution
        : undefined,
    tool_call_ids: toolCallIds,
    actions,
  };
}
/**
 * Validates a single raw action entry. Returns null for non-objects and
 * unrecognized kinds; otherwise keeps only the fields defined for that
 * kind, with strings trimmed and blanks dropped by readString.
 */
function normalizeTokenUsageAttributionAction(
  value: unknown,
): TokenUsageAttributionAction | null {
  const record = asRecord(value);
  if (!record) {
    return null;
  }
  const kind = record.kind;
  const toolCallId = readString(record.tool_call_id);
  if (
    kind === "todo_start" ||
    kind === "todo_complete" ||
    kind === "todo_update" ||
    kind === "todo_remove"
  ) {
    return {
      kind,
      content: readString(record.content),
      tool_call_id: toolCallId,
    };
  }
  if (kind === "subagent") {
    return {
      kind,
      description: readString(record.description),
      subagent_type: readString(record.subagent_type),
      tool_call_id: toolCallId,
    };
  }
  if (kind === "search") {
    return {
      kind,
      query: readString(record.query),
      tool_name: readString(record.tool_name),
      tool_call_id: toolCallId,
    };
  }
  if (kind === "present_files" || kind === "clarification") {
    return {
      kind,
      tool_call_id: toolCallId,
    };
  }
  if (kind === "tool") {
    return {
      kind,
      tool_name: readString(record.tool_name),
      description: readString(record.description),
      tool_call_id: toolCallId,
    };
  }
  return null;
}
// Narrows an unknown value to a plain object record; arrays, null, and
// primitives yield null.
function asRecord(value: unknown): Record<string, unknown> | null {
  const isPlainObject =
    typeof value === "object" && value !== null && !Array.isArray(value);
  return isPlainObject ? (value as Record<string, unknown>) : null;
}
// Trims a value when it is a string; blank or non-string input becomes
// undefined so optional fields stay absent after normalization.
function readString(value: unknown): string | undefined {
  if (typeof value !== "string") {
    return undefined;
  }
  const trimmed = value.trim();
  return trimmed ? trimmed : undefined;
}
// Type guard for the step-level attribution kinds the frontend understands.
function isTokenUsageAttributionKind(
  value: unknown,
): value is NonNullable<TokenUsageAttribution["kind"]> {
  const knownKinds = [
    "thinking",
    "final_answer",
    "tool_batch",
    "todo_update",
    "subagent_dispatch",
  ];
  return typeof value === "string" && knownKinds.includes(value);
}
+44 -6
View File
@@ -18,7 +18,7 @@ interface AssistantClarificationGroup extends GenericMessageGroup<"assistant:cla
interface AssistantSubagentGroup extends GenericMessageGroup<"assistant:subagent"> {} interface AssistantSubagentGroup extends GenericMessageGroup<"assistant:subagent"> {}
type MessageGroup = export type MessageGroup =
| HumanMessageGroup | HumanMessageGroup
| AssistantProcessingGroup | AssistantProcessingGroup
| AssistantMessageGroup | AssistantMessageGroup
@@ -26,10 +26,7 @@ type MessageGroup =
| AssistantClarificationGroup | AssistantClarificationGroup
| AssistantSubagentGroup; | AssistantSubagentGroup;
export function groupMessages<T>( export function getMessageGroups(messages: Message[]): MessageGroup[] {
messages: Message[],
mapper: (group: MessageGroup) => T,
): T[] {
if (messages.length === 0) { if (messages.length === 0) {
return []; return [];
} }
@@ -124,11 +121,52 @@ export function groupMessages<T>(
} }
} }
return groups return groups;
}
export function groupMessages<T>(
messages: Message[],
mapper: (group: MessageGroup) => T,
): T[] {
return getMessageGroups(messages)
.map(mapper) .map(mapper)
.filter((result) => result !== undefined && result !== null) as T[]; .filter((result) => result !== undefined && result !== null) as T[];
} }
/**
 * For each group index, collects the AI messages of the assistant turn
 * that ENDS at that group (a turn is a maximal run of consecutive
 * non-human groups). Every other index maps to null, so usage renders
 * exactly once per turn, on the turn's final group.
 */
export function getAssistantTurnUsageMessages(groups: MessageGroup[]) {
  const usageMessagesByGroupIndex: Array<Message[] | null> = groups.map(
    () => null,
  );
  let turnStart: number | null = null;
  for (let index = 0; index < groups.length; index += 1) {
    const group = groups[index]!;
    if (group.type === "human") {
      // A human group terminates any in-progress assistant turn.
      turnStart = null;
      continue;
    }
    if (turnStart === null) {
      turnStart = index;
    }
    const next = groups[index + 1];
    if (next && next.type !== "human") {
      // The turn continues into the next group; defer aggregation.
      continue;
    }
    const aiMessages: Message[] = [];
    for (const turnGroup of groups.slice(turnStart, index + 1)) {
      for (const message of turnGroup.messages) {
        if (message.type === "ai") {
          aiMessages.push(message);
        }
      }
    }
    usageMessagesByGroupIndex[index] = aiMessages;
    turnStart = null;
  }
  return usageMessagesByGroupIndex;
}
export function extractTextFromMessage(message: Message) { export function extractTextFromMessage(message: Message) {
if (typeof message.content === "string") { if (typeof message.content === "string") {
return ( return (
+13
View File
@@ -1,9 +1,14 @@
import type { TokenUsageInlineMode } from "../messages/usage-model";
import type { AgentThreadContext } from "../threads"; import type { AgentThreadContext } from "../threads";
export const DEFAULT_LOCAL_SETTINGS: LocalSettings = { export const DEFAULT_LOCAL_SETTINGS: LocalSettings = {
notification: { notification: {
enabled: true, enabled: true,
}, },
tokenUsage: {
headerTotal: true,
inlineMode: "per_turn",
},
context: { context: {
model_name: undefined, model_name: undefined,
mode: undefined, mode: undefined,
@@ -22,6 +27,10 @@ export interface LocalSettings {
notification: { notification: {
enabled: boolean; enabled: boolean;
}; };
tokenUsage: {
headerTotal: boolean;
inlineMode: TokenUsageInlineMode;
};
context: Omit< context: Omit<
AgentThreadContext, AgentThreadContext,
| "thread_id" | "thread_id"
@@ -44,6 +53,10 @@ function mergeLocalSettings(settings?: Partial<LocalSettings>): LocalSettings {
...DEFAULT_LOCAL_SETTINGS.context, ...DEFAULT_LOCAL_SETTINGS.context,
...settings?.context, ...settings?.context,
}, },
tokenUsage: {
...DEFAULT_LOCAL_SETTINGS.tokenUsage,
...settings?.tokenUsage,
},
notification: { notification: {
...DEFAULT_LOCAL_SETTINGS.notification, ...DEFAULT_LOCAL_SETTINGS.notification,
...settings?.notification, ...settings?.notification,
@@ -0,0 +1,396 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import { enUS } from "@/core/i18n";
import {
buildTokenDebugSteps,
getTokenUsageViewPreset,
tokenUsagePreferencesFromPreset,
} from "@/core/messages/usage-model";
// Preset -> preferences is the write path of the display dropdown: each
// preset must expand to the exact persisted preference pair.
test("maps token usage presets to persisted preferences", () => {
  expect(tokenUsagePreferencesFromPreset("off")).toEqual({
    headerTotal: false,
    inlineMode: "off",
  });
  expect(tokenUsagePreferencesFromPreset("summary")).toEqual({
    headerTotal: true,
    inlineMode: "off",
  });
  expect(tokenUsagePreferencesFromPreset("per_turn")).toEqual({
    headerTotal: true,
    inlineMode: "per_turn",
  });
  expect(tokenUsagePreferencesFromPreset("debug")).toEqual({
    headerTotal: true,
    inlineMode: "step_debug",
  });
});
// Preferences -> preset is the read path; together with the previous test
// this pins the two mappings as inverses of each other.
test("derives the active preset from persisted preferences", () => {
  expect(
    getTokenUsageViewPreset({
      headerTotal: false,
      inlineMode: "off",
    }),
  ).toBe("off");
  expect(
    getTokenUsageViewPreset({
      headerTotal: true,
      inlineMode: "off",
    }),
  ).toBe("summary");
  expect(
    getTokenUsageViewPreset({
      headerTotal: true,
      inlineMode: "per_turn",
    }),
  ).toBe("per_turn");
  expect(
    getTokenUsageViewPreset({
      headerTotal: true,
      inlineMode: "step_debug",
    }),
  ).toBe("debug");
});
// Without backend attribution, write_todos calls must keep the generic
// "Update to-do list" label (no frontend todo diffing), and a content-only
// AI message is labeled as the final answer.
test("uses generic todo labels when backend attribution is absent", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [
        {
          id: "write_todos:1",
          name: "write_todos",
          args: {
            todos: [{ content: "Draft the plan", status: "in_progress" }],
          },
        },
      ],
      usage_metadata: {
        input_tokens: 100,
        output_tokens: 20,
        total_tokens: 120,
      },
    },
    {
      id: "tool-1",
      type: "tool",
      name: "write_todos",
      tool_call_id: "write_todos:1",
      content: "ok",
    },
    {
      id: "ai-2",
      type: "ai",
      content: "",
      tool_calls: [
        {
          id: "write_todos:2",
          name: "write_todos",
          args: {
            todos: [{ content: "Draft the plan", status: "completed" }],
          },
        },
      ],
      usage_metadata: { input_tokens: 50, output_tokens: 10, total_tokens: 60 },
    },
    {
      id: "ai-3",
      type: "ai",
      content: "Here is the result",
      usage_metadata: { input_tokens: 40, output_tokens: 15, total_tokens: 55 },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: "Update to-do list",
      sharedAttribution: false,
    }),
    expect.objectContaining({
      messageId: "ai-2",
      label: "Update to-do list",
      sharedAttribution: false,
    }),
    expect.objectContaining({
      messageId: "ai-3",
      label: "Final answer",
      sharedAttribution: false,
    }),
  ]);
});
// A single AI message with several tool calls collapses under "Step total"
// with each action listed as a secondary label and sharedAttribution set.
test("marks multi-action AI steps as shared attribution", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [
        {
          id: "web_search:1",
          name: "web_search",
          args: { query: "LangGraph stream mode" },
        },
        {
          id: "write_todos:1",
          name: "write_todos",
          args: {
            todos: [
              {
                content: "Inspect stream mode handling",
                status: "in_progress",
              },
            ],
          },
        },
      ],
      usage_metadata: {
        input_tokens: 120,
        output_tokens: 30,
        total_tokens: 150,
      },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: "Step total",
      sharedAttribution: true,
      secondaryLabels: [
        'Search for "LangGraph stream mode"',
        "Update to-do list",
      ],
    }),
  ]);
});
// When the backend attribution payload is present and valid, its precise
// labels win over the generic frontend tool-call fallback.
test("prefers backend attribution metadata when available", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [
        {
          id: "write_todos:1",
          name: "write_todos",
          args: {
            todos: [
              {
                content: "Fallback label should not win",
                status: "in_progress",
              },
            ],
          },
        },
      ],
      additional_kwargs: {
        token_usage_attribution: {
          version: 1,
          kind: "todo_update",
          shared_attribution: false,
          actions: [{ kind: "todo_start", content: "Use backend attribution" }],
        },
      },
      usage_metadata: { input_tokens: 25, output_tokens: 5, total_tokens: 30 },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: "Start To-do: Use backend attribution",
      sharedAttribution: false,
    }),
  ]);
});
// A payload whose `actions` field is not an array is rejected wholesale,
// so labeling falls back to the raw tool calls.
test("falls back safely when attribution payload is malformed", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [
        {
          id: "web_search:1",
          name: "web_search",
          args: { query: "LangGraph stream mode" },
        },
      ],
      additional_kwargs: {
        token_usage_attribution: {
          version: 1,
          kind: "tool_batch",
          actions: { broken: true },
        },
      },
      usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: 'Search for "LangGraph stream mode"',
      sharedAttribution: false,
    }),
  ]);
});
// Non-object entries inside a valid `actions` array are filtered out
// individually; the remaining valid action still produces a label.
test("ignores attribution actions that are not objects", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [],
      additional_kwargs: {
        token_usage_attribution: {
          version: 1,
          kind: "tool_batch",
          shared_attribution: true,
          actions: [
            null,
            "bad-action",
            { kind: "search", query: "valid search", ignored: "extra-field" },
          ],
        },
      },
      usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: 'Search for "valid search"',
    }),
  ]);
});
// Malformed top-level fields (null kind, null shared flag, mixed-type
// tool_call_ids, kind-less actions) are all dropped, and a message with
// real content is then labeled as the final answer.
test("ignores malformed attribution fields and falls back to message content", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "Real final answer",
      tool_calls: [],
      additional_kwargs: {
        token_usage_attribution: {
          version: 1,
          kind: null,
          shared_attribution: null,
          tool_call_ids: [null, "tool-1", 123],
          actions: [{ query: "missing kind" }],
        },
      },
      usage_metadata: { input_tokens: 9, output_tokens: 3, total_tokens: 12 },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: "Final answer",
      sharedAttribution: false,
    }),
  ]);
});
// Attribution versioning is additive: unknown top-level fields must be
// ignored without invalidating the rest of the payload.
test("ignores unknown top-level attribution fields", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [],
      additional_kwargs: {
        token_usage_attribution: {
          version: 1,
          kind: "tool_batch",
          shared_attribution: false,
          unknown_field: "ignored",
          actions: [{ kind: "subagent", description: "Inspect the fix" }],
        },
      },
      usage_metadata: { input_tokens: 12, output_tokens: 4, total_tokens: 16 },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: "Subagent: Inspect the fix",
      sharedAttribution: false,
    }),
  ]);
});
// An attribution payload with an empty `actions` array yields no labels
// on its own, so the generic tool-call fallback applies — both with and
// without a payload the label is the generic todo label.
test("falls back to generic todo labels when backend attribution has no actions", () => {
  const messages = [
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [
        {
          id: "write_todos:1",
          name: "write_todos",
          args: {
            todos: [{ content: "Clean up stale tasks", status: "in_progress" }],
          },
        },
      ],
      usage_metadata: {
        input_tokens: 100,
        output_tokens: 20,
        total_tokens: 120,
      },
    },
    {
      id: "ai-2",
      type: "ai",
      content: "",
      tool_calls: [
        {
          id: "write_todos:2",
          name: "write_todos",
          args: {
            todos: [],
          },
        },
      ],
      additional_kwargs: {
        token_usage_attribution: {
          version: 1,
          kind: "todo_update",
          shared_attribution: false,
          actions: [],
        },
      },
      usage_metadata: { input_tokens: 30, output_tokens: 8, total_tokens: 38 },
    },
  ] as Message[];
  expect(buildTokenDebugSteps(messages, enUS)).toEqual([
    expect.objectContaining({
      messageId: "ai-1",
      label: "Update to-do list",
    }),
    expect.objectContaining({
      messageId: "ai-2",
      label: "Update to-do list",
      sharedAttribution: false,
    }),
  ]);
});
@@ -0,0 +1,65 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import {
getAssistantTurnUsageMessages,
getMessageGroups,
} from "@/core/messages/utils";
// End-to-end check over grouping + turn aggregation: usage for a
// multi-group assistant turn attaches only to the turn's final group, and
// a following human message starts a fresh turn.
test("aggregates token usage messages once per assistant turn", () => {
  const messages = [
    {
      id: "human-1",
      type: "human",
      content: "Plan a trip",
    },
    {
      id: "ai-1",
      type: "ai",
      content: "",
      tool_calls: [{ id: "tool-1", name: "web_search", args: {} }],
      usage_metadata: { input_tokens: 10, output_tokens: 5, total_tokens: 15 },
    },
    {
      id: "tool-1-result",
      type: "tool",
      name: "web_search",
      tool_call_id: "tool-1",
      content: "[]",
    },
    {
      id: "ai-2",
      type: "ai",
      content: "Here is the itinerary",
      usage_metadata: { input_tokens: 2, output_tokens: 8, total_tokens: 10 },
    },
    {
      id: "human-2",
      type: "human",
      content: "Make it shorter",
    },
    {
      id: "ai-3",
      type: "ai",
      content: "Short version",
      usage_metadata: { input_tokens: 1, output_tokens: 1, total_tokens: 2 },
    },
  ] as Message[];
  const groups = getMessageGroups(messages);
  const usageMessagesByGroupIndex = getAssistantTurnUsageMessages(groups);
  expect(groups.map((group) => group.type)).toEqual([
    "human",
    "assistant:processing",
    "assistant",
    "human",
    "assistant",
  ]);
  expect(
    usageMessagesByGroupIndex.map(
      (groupMessages) => groupMessages?.map((message) => message.id) ?? null,
    ),
  ).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
});