d02f762ab0
* feat: refine token usage display modes * docs: clarify token usage accounting semantics * fix: avoid duplicate subtask debug keys * style: format token usage tests * chore: address token attribution review feedback * Update test_token_usage_middleware.py * Update test_token_usage_middleware.py * chore: simplify token attribution fallback * fix token usage metadata follow-up handling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
158 lines
5.3 KiB
Python
158 lines
5.3 KiB
Python
"""Tests for TokenUsageMiddleware attribution annotations."""
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
from langchain_core.messages import AIMessage
|
|
|
|
from deerflow.agents.middlewares.token_usage_middleware import (
|
|
TOKEN_USAGE_ATTRIBUTION_KEY,
|
|
TokenUsageMiddleware,
|
|
)
|
|
|
|
|
|
def _make_runtime():
|
|
runtime = MagicMock()
|
|
runtime.context = {"thread_id": "test-thread"}
|
|
return runtime
|
|
|
|
|
|
class TestTokenUsageMiddleware:
|
|
def test_annotates_todo_updates_with_structured_actions(self):
|
|
middleware = TokenUsageMiddleware()
|
|
message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"id": "write_todos:1",
|
|
"name": "write_todos",
|
|
"args": {
|
|
"todos": [
|
|
{"content": "Inspect streaming path", "status": "completed"},
|
|
{"content": "Design token attribution schema", "status": "in_progress"},
|
|
]
|
|
},
|
|
}
|
|
],
|
|
usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
|
|
)
|
|
|
|
state = {
|
|
"messages": [message],
|
|
"todos": [
|
|
{"content": "Inspect streaming path", "status": "in_progress"},
|
|
{"content": "Design token attribution schema", "status": "pending"},
|
|
],
|
|
}
|
|
|
|
result = middleware.after_model(state, _make_runtime())
|
|
|
|
assert result is not None
|
|
updated_message = result["messages"][0]
|
|
attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
|
assert attribution["kind"] == "tool_batch"
|
|
assert attribution["shared_attribution"] is True
|
|
assert attribution["tool_call_ids"] == ["write_todos:1"]
|
|
assert attribution["actions"] == [
|
|
{
|
|
"kind": "todo_complete",
|
|
"content": "Inspect streaming path",
|
|
"tool_call_id": "write_todos:1",
|
|
},
|
|
{
|
|
"kind": "todo_start",
|
|
"content": "Design token attribution schema",
|
|
"tool_call_id": "write_todos:1",
|
|
},
|
|
]
|
|
|
|
def test_annotates_subagent_and_search_steps(self):
|
|
middleware = TokenUsageMiddleware()
|
|
message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"id": "task:1",
|
|
"name": "task",
|
|
"args": {
|
|
"description": "spec-coder patch message grouping",
|
|
"subagent_type": "general-purpose",
|
|
},
|
|
},
|
|
{
|
|
"id": "web_search:1",
|
|
"name": "web_search",
|
|
"args": {"query": "LangGraph useStream messages tuple"},
|
|
},
|
|
],
|
|
)
|
|
|
|
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
|
|
|
assert result is not None
|
|
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
|
assert attribution["kind"] == "tool_batch"
|
|
assert attribution["shared_attribution"] is True
|
|
assert attribution["actions"] == [
|
|
{
|
|
"kind": "subagent",
|
|
"description": "spec-coder patch message grouping",
|
|
"subagent_type": "general-purpose",
|
|
"tool_call_id": "task:1",
|
|
},
|
|
{
|
|
"kind": "search",
|
|
"tool_name": "web_search",
|
|
"query": "LangGraph useStream messages tuple",
|
|
"tool_call_id": "web_search:1",
|
|
},
|
|
]
|
|
|
|
def test_marks_final_answer_when_no_tools(self):
|
|
middleware = TokenUsageMiddleware()
|
|
message = AIMessage(content="Here is the final answer.")
|
|
|
|
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
|
|
|
assert result is not None
|
|
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
|
assert attribution["kind"] == "final_answer"
|
|
assert attribution["shared_attribution"] is False
|
|
assert attribution["actions"] == []
|
|
|
|
def test_annotates_removed_todos(self):
|
|
middleware = TokenUsageMiddleware()
|
|
message = AIMessage(
|
|
content="",
|
|
tool_calls=[
|
|
{
|
|
"id": "write_todos:remove",
|
|
"name": "write_todos",
|
|
"args": {
|
|
"todos": [],
|
|
},
|
|
}
|
|
],
|
|
)
|
|
|
|
result = middleware.after_model(
|
|
{
|
|
"messages": [message],
|
|
"todos": [
|
|
{"content": "Archive obsolete plan", "status": "pending"},
|
|
],
|
|
},
|
|
_make_runtime(),
|
|
)
|
|
|
|
assert result is not None
|
|
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
|
assert attribution["kind"] == "todo_update"
|
|
assert attribution["shared_attribution"] is False
|
|
assert attribution["actions"] == [
|
|
{
|
|
"kind": "todo_remove",
|
|
"content": "Archive obsolete plan",
|
|
"tool_call_id": "write_todos:remove",
|
|
}
|
|
]
|