"""Tests for NodeConversation, Message, ConversationStore, and FileConversationStore.""" from __future__ import annotations import json from typing import Any import pytest from framework.graph.conversation import ( Message, NodeConversation, extract_tool_call_history, ) from framework.storage.conversation_store import FileConversationStore # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- class MockConversationStore: """In-memory dict-based store for testing.""" def __init__(self) -> None: self._parts: dict[int, dict] = {} self._meta: dict | None = None self._cursor: dict | None = None async def write_part(self, seq: int, data: dict[str, Any]) -> None: self._parts[seq] = data async def read_parts(self) -> list[dict[str, Any]]: return [self._parts[k] for k in sorted(self._parts)] async def write_meta(self, data: dict[str, Any]) -> None: self._meta = data async def read_meta(self) -> dict[str, Any] | None: return self._meta async def write_cursor(self, data: dict[str, Any]) -> None: self._cursor = data async def read_cursor(self) -> dict[str, Any] | None: return self._cursor async def delete_parts_before(self, seq: int, run_id: str | None = None) -> None: self._parts = {k: v for k, v in self._parts.items() if k >= seq} async def close(self) -> None: pass async def destroy(self) -> None: pass SAMPLE_TOOL_CALLS = [ { "id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": '{"city":"SF"}'}, } ] # =================================================================== # Message serialization # =================================================================== class TestMessage: def test_user_and_assistant_to_llm_dict(self): """User and assistant (no tools) produce simple role+content dicts.""" assert Message(seq=0, role="user", content="hi").to_llm_dict() == { "role": "user", "content": "hi", } assert Message(seq=0, 
role="assistant", content="hello").to_llm_dict() == { "role": "assistant", "content": "hello", } def test_assistant_to_llm_dict_with_tools(self): m = Message(seq=0, role="assistant", content="", tool_calls=SAMPLE_TOOL_CALLS) d = m.to_llm_dict() assert d["role"] == "assistant" assert d["tool_calls"] == SAMPLE_TOOL_CALLS def test_tool_to_llm_dict(self): m = Message(seq=0, role="tool", content="sunny", tool_use_id="call_1") d = m.to_llm_dict() assert d == {"role": "tool", "tool_call_id": "call_1", "content": "sunny"} def test_tool_error_to_llm_dict(self): m = Message(seq=0, role="tool", content="not found", tool_use_id="call_1", is_error=True) d = m.to_llm_dict() assert d["content"] == "ERROR: not found" assert d["tool_call_id"] == "call_1" def test_storage_roundtrip(self): m = Message(seq=5, role="assistant", content="ok", tool_calls=SAMPLE_TOOL_CALLS) restored = Message.from_storage_dict(m.to_storage_dict()) assert restored.seq == m.seq assert restored.role == m.role assert restored.content == m.content assert restored.tool_calls == m.tool_calls def test_storage_dict_edge_cases(self): """is_error is preserved; None/False fields are omitted.""" m = Message(seq=1, role="tool", content="fail", tool_use_id="c1", is_error=True) d = m.to_storage_dict() assert d["is_error"] is True assert Message.from_storage_dict(d).is_error is True d2 = Message(seq=0, role="user", content="hi").to_storage_dict() assert "tool_use_id" not in d2 assert "tool_calls" not in d2 assert "is_error" not in d2 # =================================================================== # NodeConversation (in-memory) # =================================================================== class TestNodeConversation: @pytest.mark.asyncio async def test_multi_turn_build_and_export(self): conv = NodeConversation(system_prompt="You are helpful.") await conv.add_user_message("hello") await conv.add_assistant_message("hi there") await conv.add_user_message("weather?") await conv.add_assistant_message("", 
tool_calls=SAMPLE_TOOL_CALLS) await conv.add_tool_result("call_1", "sunny") await conv.add_assistant_message("It's sunny!") assert conv.turn_count == 2 assert conv.message_count == 6 llm = conv.to_llm_messages() assert len(llm) == 6 assert llm[0]["role"] == "user" assert llm[3]["tool_calls"] == SAMPLE_TOOL_CALLS summary = conv.export_summary() assert "turns: 2" in summary assert "messages: 6" in summary @pytest.mark.asyncio async def test_system_prompt_excluded_from_messages(self): conv = NodeConversation(system_prompt="secret") await conv.add_user_message("hi") llm = conv.to_llm_messages() assert len(llm) == 1 assert all("secret" not in str(m) for m in llm) @pytest.mark.asyncio async def test_turn_and_seq_counting(self): """turn_count tracks user messages; next_seq increments on every add.""" conv = NodeConversation() assert conv.turn_count == 0 assert conv.next_seq == 0 await conv.add_user_message("a") assert conv.turn_count == 1 assert conv.next_seq == 1 await conv.add_assistant_message("b") assert conv.turn_count == 1 assert conv.next_seq == 2 @pytest.mark.asyncio async def test_token_estimation(self): conv = NodeConversation() await conv.add_user_message("a" * 400) # chars // 3 (4/3 safety margin over chars/4 base) assert conv.estimate_tokens() == 400 // 3 @pytest.mark.asyncio async def test_update_token_count_overrides_estimate(self): """When actual API token count is provided, estimate_tokens uses it.""" conv = NodeConversation() await conv.add_user_message("a" * 400) assert conv.estimate_tokens() == 400 // 3 # char-based fallback with safety margin conv.update_token_count(500) assert conv.estimate_tokens() == 500 # actual API value @pytest.mark.asyncio async def test_compact_resets_token_count(self): """After compaction, actual token count is cleared (recalibrates on next LLM call).""" conv = NodeConversation() await conv.add_user_message("a" * 400) conv.update_token_count(500) assert conv.estimate_tokens() == 500 await conv.compact("summary", 
keep_recent=0) # Falls back to char-based heuristic with 4/3 safety margin (chars // 3) assert conv.estimate_tokens() == len("summary") // 3 @pytest.mark.asyncio async def test_clear_resets_token_count(self): """clear() also resets the actual token count.""" conv = NodeConversation() await conv.add_user_message("hello") conv.update_token_count(1000) assert conv.estimate_tokens() == 1000 await conv.clear() assert conv.estimate_tokens() == 0 @pytest.mark.asyncio async def test_usage_ratio(self): """usage_ratio returns estimate / max_context_tokens.""" conv = NodeConversation(max_context_tokens=1000) await conv.add_user_message("a" * 400) # 400 // 3 = 133 tokens (with safety margin), so 133/1000 assert conv.usage_ratio() == pytest.approx(400 // 3 / 1000) conv.update_token_count(800) assert conv.usage_ratio() == pytest.approx(0.8) # 800/1000 @pytest.mark.asyncio async def test_usage_ratio_zero_budget(self): """usage_ratio returns 0 when max_context_tokens is 0 (unlimited).""" conv = NodeConversation(max_context_tokens=0) await conv.add_user_message("a" * 400) assert conv.usage_ratio() == 0.0 @pytest.mark.asyncio async def test_needs_compaction_with_actual_tokens(self): """needs_compaction uses actual API token count when available.""" conv = NodeConversation(max_context_tokens=1000, compaction_threshold=0.8) await conv.add_user_message("a" * 100) # chars/4 = 25, well under 800 assert conv.needs_compaction() is False # Simulate API reporting much higher actual token usage conv.update_token_count(850) assert conv.needs_compaction() is True @pytest.mark.asyncio async def test_needs_compaction(self): conv = NodeConversation(max_context_tokens=100, compaction_threshold=0.8) await conv.add_user_message("x" * 320) assert conv.needs_compaction() is True @pytest.mark.asyncio async def test_compact_replaces_with_summary(self): """keep_recent=0 replaces all messages; empty conversation is a no-op.""" conv = NodeConversation() await conv.compact("summary") assert conv.turn_count 
== 0 conv2 = NodeConversation() await conv2.add_user_message("one") await conv2.add_assistant_message("two") seq_before = conv2.next_seq await conv2.compact("summary of conversation", keep_recent=0) assert conv2.turn_count == 1 assert conv2.message_count == 1 assert conv2.messages[0].content == "summary of conversation" assert conv2.messages[0].role == "user" assert conv2.messages[0].seq == seq_before assert conv2.next_seq == seq_before + 1 @pytest.mark.asyncio async def test_compact_keep_recent_default(self): """Default keep_recent=2 keeps last 2 messages.""" conv = NodeConversation() await conv.add_user_message("m1") await conv.add_assistant_message("m2") await conv.add_user_message("m3") await conv.add_assistant_message("m4") await conv.add_user_message("m5") await conv.add_assistant_message("m6") await conv.compact("summary of early conversation") assert conv.message_count == 3 assert conv.messages[0].content == "summary of early conversation" assert conv.messages[0].role == "user" assert conv.messages[1].content == "m5" assert conv.messages[2].content == "m6" @pytest.mark.asyncio async def test_compact_keep_recent_clamped(self): """keep_recent larger than len-1 gets clamped.""" conv = NodeConversation() await conv.add_user_message("a") await conv.add_assistant_message("b") await conv.compact("summary", keep_recent=5) assert conv.message_count == 2 assert conv.messages[0].content == "summary" assert conv.messages[1].content == "b" @pytest.mark.asyncio async def test_compact_preserves_output_keys(self): """PRESERVED VALUES block appears in summary when output_keys match.""" conv = NodeConversation(output_keys=["score", "status"]) await conv.add_user_message("process this") await conv.add_assistant_message("score: 87") await conv.add_assistant_message("status = complete") await conv.add_user_message("next question") await conv.compact("conversation summary", keep_recent=1) summary_content = conv.messages[0].content assert "PRESERVED VALUES" in summary_content 
assert "score: 87" in summary_content assert "status: complete" in summary_content assert "CONVERSATION SUMMARY:" in summary_content assert "conversation summary" in summary_content @pytest.mark.asyncio async def test_compact_seq_arithmetic_with_keep_recent(self): """Summary seq = recent[0].seq - 1 when keeping recent messages.""" conv = NodeConversation() await conv.add_user_message("m1") # seq=0 await conv.add_assistant_message("m2") # seq=1 await conv.add_user_message("m3") # seq=2 await conv.add_assistant_message("m4") # seq=3 await conv.compact("summary", keep_recent=2) assert conv.messages[0].seq == 1 # summary assert conv.messages[1].seq == 2 # m3 assert conv.messages[2].seq == 3 # m4 assert conv.next_seq == 4 @pytest.mark.asyncio async def test_clear(self): """Clear removes messages, keeps system prompt, preserves next_seq.""" conv = NodeConversation(system_prompt="keep me") await conv.add_user_message("a") await conv.add_user_message("b") seq_before = conv.next_seq await conv.clear() assert conv.turn_count == 0 assert conv.system_prompt == "keep me" assert conv.next_seq == seq_before @pytest.mark.asyncio async def test_export_summary(self): conv = NodeConversation(system_prompt="Be helpful") await conv.add_user_message("q1") await conv.add_assistant_message("a1") s = conv.export_summary() assert "[STATS]" in s assert "turns: 1" in s assert "messages: 2" in s assert "[CONFIG]" in s assert "Be helpful" in s assert "[RECENT_MESSAGES]" in s assert "[user]" in s assert "[assistant]" in s @pytest.mark.asyncio async def test_export_summary_output_keys(self): """output_keys appear in CONFIG when set, absent when None.""" conv = NodeConversation( system_prompt="test", output_keys=["confirmed_meetings", "lead_score"], ) await conv.add_user_message("hi") assert "output_keys: confirmed_meetings, lead_score" in conv.export_summary() conv2 = NodeConversation(system_prompt="test") await conv2.add_user_message("hi") assert "output_keys" not in conv2.export_summary() # 
# ===================================================================
# Output-key extraction
# ===================================================================


class TestExtractProtectedValues:
    @pytest.mark.asyncio
    async def test_extract_colon_format(self):
        conv = NodeConversation(output_keys=["score"])
        await conv.add_assistant_message("The score: 87")
        assert conv._extract_protected_values(conv.messages) == {"score": "87"}

    @pytest.mark.asyncio
    async def test_extract_json_format(self):
        conv = NodeConversation(output_keys=["meetings"])
        await conv.add_assistant_message('{"meetings": ["standup", "retro"]}')
        assert conv._extract_protected_values(conv.messages) == {"meetings": '["standup", "retro"]'}

    @pytest.mark.asyncio
    async def test_extract_equals_format(self):
        conv = NodeConversation(output_keys=["status"])
        await conv.add_assistant_message("status = done")
        assert conv._extract_protected_values(conv.messages) == {"status": "done"}

    @pytest.mark.asyncio
    async def test_extract_most_recent_wins(self):
        conv = NodeConversation(output_keys=["score"])
        await conv.add_assistant_message("score: 50")
        await conv.add_assistant_message("score: 99")
        assert conv._extract_protected_values(conv.messages) == {"score": "99"}

    @pytest.mark.asyncio
    async def test_extract_embedded_json(self):
        conv = NodeConversation(output_keys=["lead_score"])
        await conv.add_assistant_message(
            'Based on my analysis, here are the results: {"lead_score": 87, "status": "hot"}'
        )
        assert conv._extract_protected_values(conv.messages) == {"lead_score": "87"}

    @pytest.mark.asyncio
    async def test_extract_no_match_cases(self):
        """Nothing extracted from user messages, missing output_keys, or absent keys."""
        conv = NodeConversation(output_keys=["score"])
        await conv.add_user_message("score: 42")
        assert conv._extract_protected_values(conv.messages) == {}
        conv2 = NodeConversation(output_keys=None)
        await conv2.add_assistant_message("score: 42")
        assert conv2._extract_protected_values(conv2.messages) == {}
        conv3 = NodeConversation(output_keys=["missing_key"])
        await conv3.add_assistant_message("nothing relevant here")
        assert conv3._extract_protected_values(conv3.messages) == {}


# ===================================================================
# Persistence (MockConversationStore)
# ===================================================================


class TestPersistence:
    @pytest.mark.asyncio
    async def test_write_through_each_add(self):
        store = MockConversationStore()
        conv = NodeConversation(store=store)
        await conv.add_user_message("a")
        await conv.add_assistant_message("b")
        parts = await store.read_parts()
        assert len(parts) == 2
        assert parts[0]["content"] == "a"
        assert parts[1]["content"] == "b"

    @pytest.mark.asyncio
    async def test_meta_and_cursor_persistence(self):
        """Meta is written lazily on the first add; the cursor on every add."""
        store = MockConversationStore()
        conv = NodeConversation(system_prompt="sys", store=store)
        assert store._meta is None
        await conv.add_user_message("trigger")
        assert store._meta is not None
        assert store._meta["system_prompt"] == "sys"
        assert store._cursor == {"next_seq": 1}
        await conv.add_user_message("b")
        assert store._cursor == {"next_seq": 2}

    @pytest.mark.asyncio
    async def test_restore_from_store(self):
        """Restore rebuilds the conversation; an empty store yields None."""
        store = MockConversationStore()
        assert await NodeConversation.restore(store) is None
        conv = NodeConversation(system_prompt="hello", max_context_tokens=500, store=store)
        await conv.add_user_message("u1")
        await conv.add_assistant_message("a1")
        restored = await NodeConversation.restore(store)
        assert restored is not None
        assert restored.system_prompt == "hello"
        assert restored.turn_count == 1
        assert restored.message_count == 2
        assert restored.next_seq == 2
        assert restored.messages[0].content == "u1"

    @pytest.mark.asyncio
    async def test_restore_filters_by_run_id_for_crash_recovery(self):
        """Restore with a non-legacy run_id only loads parts from that run.

        This ensures intentional restarts (new run_id) start fresh while
        crash recovery (same run_id) resumes correctly. Legacy parts (no
        run_id) and other runs' parts are excluded.
        """
        store = MockConversationStore()
        await store.write_meta({"system_prompt": "hello"})
        await store.write_part(0, {"seq": 0, "role": "user", "content": "legacy"})
        await store.write_part(1, {"seq": 1, "role": "user", "content": "run-a", "run_id": "run-a"})
        await store.write_part(
            2,
            {"seq": 2, "role": "assistant", "content": "run-b", "run_id": "run-b"},
        )
        await store.write_cursor({"next_seq": 3})
        restored = await NodeConversation.restore(store, run_id="run-a")
        assert restored is not None
        assert [m.content for m in restored.messages] == ["run-a"]
        assert restored.next_seq == 3

    @pytest.mark.asyncio
    async def test_clear_deletes_all_parts(self):
        store = MockConversationStore()
        conv_a = NodeConversation(system_prompt="hello", store=store, run_id="run-a")
        conv_b = NodeConversation(system_prompt="hello", store=store, run_id="run-b")
        await conv_a.add_user_message("a1")
        await conv_b.add_user_message("b1")
        await conv_a.clear()
        restored = await NodeConversation.restore(store)
        assert restored is not None
        assert [m.content for m in restored.messages] == []

    @pytest.mark.asyncio
    async def test_restore_preserves_tool_messages(self):
        store = MockConversationStore()
        conv = NodeConversation(store=store)
        await conv.add_assistant_message("", tool_calls=SAMPLE_TOOL_CALLS)
        await conv.add_tool_result("call_1", "result", is_error=True)
        restored = await NodeConversation.restore(store)
        assert restored is not None
        msgs = restored.messages
        assert msgs[0].tool_calls == SAMPLE_TOOL_CALLS
        assert msgs[1].tool_use_id == "call_1"
        assert msgs[1].is_error is True

    @pytest.mark.asyncio
    async def test_compact_deletes_old_parts(self):
        store = MockConversationStore()
        conv = NodeConversation(store=store)
        await conv.add_user_message("a")
        await conv.add_user_message("b")
        assert len(store._parts) == 2
        await conv.compact("summary", keep_recent=0)
        assert len(store._parts) == 1
        remaining = list(store._parts.values())
        assert remaining[0]["content"] == "summary"

    @pytest.mark.asyncio
    async def test_compact_then_restore(self):
        """Compaction with keep_recent persists to the store and restores back."""
        store = MockConversationStore()
        conv = NodeConversation(system_prompt="sp", store=store)
        await conv.add_user_message("m1")
        await conv.add_assistant_message("m2")
        await conv.add_user_message("m3")
        await conv.add_assistant_message("m4")
        await conv.compact("early summary", keep_recent=2)
        restored = await NodeConversation.restore(store)
        assert restored is not None
        assert restored.message_count == 3
        assert restored.messages[0].content == "early summary"
        assert restored.messages[1].content == "m3"
        assert restored.messages[2].content == "m4"

    @pytest.mark.asyncio
    async def test_clear_deletes_store_parts(self):
        store = MockConversationStore()
        conv = NodeConversation(store=store)
        await conv.add_user_message("a")
        await conv.add_user_message("b")
        await conv.clear()
        assert len(store._parts) == 0


# ===================================================================
# FileConversationStore
# ===================================================================


class TestFileConversationStore:
    @pytest.mark.asyncio
    async def test_meta_and_cursor_crud(self, tmp_path):
        """Meta and cursor write/read roundtrip; unwritten reads give None."""
        store = FileConversationStore(tmp_path / "conv")
        assert await store.read_meta() is None
        await store.write_meta({"system_prompt": "hi"})
        assert await store.read_meta() == {"system_prompt": "hi"}
        await store.write_cursor({"next_seq": 5})
        assert await store.read_cursor() == {"next_seq": 5}

    @pytest.mark.asyncio
    async def test_write_and_read_parts_in_order(self, tmp_path):
        store = FileConversationStore(tmp_path / "conv")
        await store.write_part(2, {"seq": 2, "content": "second"})
        await store.write_part(0, {"seq": 0, "content": "first"})
        await store.write_part(1, {"seq": 1, "content": "middle"})
        parts = await store.read_parts()
        assert [p["seq"] for p in parts] == [0, 1, 2]

    @pytest.mark.asyncio
    async def test_delete_parts_before(self, tmp_path):
        store = FileConversationStore(tmp_path / "conv")
        for i in range(5):
            await store.write_part(i, {"seq": i})
        await store.delete_parts_before(3)
        parts = await store.read_parts()
        assert [p["seq"] for p in parts] == [3, 4]

    @pytest.mark.asyncio
    async def test_idempotent_write_part(self, tmp_path):
        store = FileConversationStore(tmp_path / "conv")
        await store.write_part(0, {"seq": 0, "v": 1})
        await store.write_part(0, {"seq": 0, "v": 2})
        parts = await store.read_parts()
        assert len(parts) == 1
        assert parts[0]["v"] == 2

    @pytest.mark.asyncio
    async def test_integration_with_node_conversation(self, tmp_path):
        """Full cycle: create, add messages, restore from the file store."""
        store = FileConversationStore(tmp_path / "conv")
        conv = NodeConversation(system_prompt="test", store=store)
        await conv.add_user_message("u1")
        await conv.add_assistant_message("a1", tool_calls=SAMPLE_TOOL_CALLS)
        await conv.add_tool_result("call_1", "r1", is_error=True)
        restored = await NodeConversation.restore(store)
        assert restored is not None
        assert restored.system_prompt == "test"
        assert restored.turn_count == 1
        assert restored.message_count == 3
        assert restored.next_seq == 3
        msgs = restored.messages
        assert msgs[0].content == "u1"
        assert msgs[1].tool_calls == SAMPLE_TOOL_CALLS
        assert msgs[2].is_error is True
        llm = restored.to_llm_messages()
        assert llm[2]["content"] == "ERROR: r1"

    @pytest.mark.asyncio
    async def test_corrupt_part_skipped_on_read(self, tmp_path):
        """A corrupt JSON part file is skipped rather than aborting restore."""
        store = FileConversationStore(tmp_path / "conv")
        await store.write_part(0, {"seq": 0, "content": "ok"})
        await store.write_part(1, {"seq": 1, "content": "good"})
        # Simulate a crash mid-write by corrupting part 0
        corrupt_path = tmp_path / "conv" / "parts" / "0000000000.json"
        corrupt_path.write_text("{truncated", encoding="utf-8")
        parts = await store.read_parts()
        assert len(parts) == 1
        assert parts[0]["seq"] == 1

    @pytest.mark.asyncio
    async def test_directory_structure(self, tmp_path):
        """meta.json, cursor.json, and parts/*.json exist after writes."""
        store = FileConversationStore(tmp_path / "conv")
        await store.write_meta({"system_prompt": "hi"})
        await store.write_cursor({"next_seq": 2})
        await store.write_part(0, {"seq": 0, "content": "first"})
        await store.write_part(1, {"seq": 1, "content": "second"})
        base = tmp_path / "conv"
        assert (base / "meta.json").exists()
        assert (base / "cursor.json").exists()
        assert (base / "parts" / "0000000000.json").exists()
        assert (base / "parts" / "0000000001.json").exists()


# ===================================================================
# Integration tests — real FileConversationStore, no mocks
# ===================================================================


class TestConversationIntegration:
    """End-to-end tests using real FileConversationStore on disk.

    Every test creates a fresh directory, writes real JSON files, and
    restores from a *new* store instance (simulating process restart).
    """

    @pytest.mark.asyncio
    async def test_multi_turn_agent_conversation(self, tmp_path):
        """A realistic multi-turn agent run with tool calls and results,
        fully restored from disk afterwards."""
        base = tmp_path / "agent_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="You are a helpful travel agent.",
            max_context_tokens=16000,
            store=store,
        )
        # Turn 1: user asks, assistant responds with tool call
        await conv.add_user_message("Find me flights from NYC to London next Friday.")
        await conv.add_assistant_message(
            "Let me search for flights.",
            tool_calls=[
                {
                    "id": "call_flight_1",
                    "type": "function",
                    "function": {
                        "name": "search_flights",
                        "arguments": '{"origin":"JFK","destination":"LHR","date":"2025-06-13"}',
                    },
                }
            ],
        )
        await conv.add_tool_result(
            "call_flight_1",
            '{"flights":[{"airline":"BA","price":450,"departure":"08:00"},{"airline":"AA","price":520,"departure":"14:30"}]}',
        )
        # Turn 2: assistant presents results, user picks one
        await conv.add_assistant_message(
            "I found 2 flights:\n"
            "1. British Airways at $450, departing 08:00\n"
            "2. American Airlines at $520, departing 14:30\n"
            "Which one would you like?"
        )
        await conv.add_user_message("Book the British Airways one.")
        await conv.add_assistant_message(
            "Booking the BA flight now.",
            tool_calls=[
                {
                    "id": "call_book_1",
                    "type": "function",
                    "function": {
                        "name": "book_flight",
                        "arguments": '{"flight_id":"BA-JFK-LHR-0800","passenger":"user"}',
                    },
                }
            ],
        )
        await conv.add_tool_result(
            "call_book_1",
            '{"confirmation":"BA-12345","status":"confirmed"}',
        )
        await conv.add_assistant_message("Your flight is booked! Confirmation: BA-12345.")
        # Verify in-memory state
        assert conv.turn_count == 2
        assert conv.message_count == 8
        assert conv.next_seq == 8
        # --- Simulate process restart: new store, same path ---
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        assert restored.system_prompt == "You are a helpful travel agent."
        assert restored.turn_count == 2
        assert restored.message_count == 8
        assert restored.next_seq == 8
        # Verify message content integrity
        msgs = restored.messages
        assert msgs[0].role == "user"
        assert "NYC to London" in msgs[0].content
        assert msgs[1].role == "assistant"
        assert msgs[1].tool_calls[0]["id"] == "call_flight_1"
        assert msgs[2].role == "tool"
        assert msgs[2].tool_use_id == "call_flight_1"
        assert "BA" in msgs[2].content
        assert msgs[7].content == "Your flight is booked! Confirmation: BA-12345."
        # Verify LLM-format output
        llm_msgs = restored.to_llm_messages()
        assert llm_msgs[0] == {"role": "user", "content": msgs[0].content}
        assert llm_msgs[2]["role"] == "tool"
        assert llm_msgs[2]["tool_call_id"] == "call_flight_1"

    @pytest.mark.asyncio
    async def test_compaction_and_restore_preserves_continuity(self, tmp_path):
        """Grow a long conversation, compact, keep adding, then restore —
        checking seq continuity and content along the way."""
        base = tmp_path / "compact_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="research assistant",
            store=store,
        )
        # Build 10 messages (5 turns)
        for i in range(5):
            await conv.add_user_message(f"question {i}")
            await conv.add_assistant_message(f"answer {i}")
        assert conv.message_count == 10
        assert conv.next_seq == 10
        # Compact: keep last 2 messages (question 4, answer 4)
        await conv.compact("Summary of questions 0-3 and their answers.", keep_recent=2)
        assert conv.message_count == 3  # summary + 2 recent
        assert conv.messages[0].content == "Summary of questions 0-3 and their answers."
        assert conv.messages[1].content == "question 4"
        assert conv.messages[2].content == "answer 4"
        # Continue the conversation post-compaction
        await conv.add_user_message("question 5")
        await conv.add_assistant_message("answer 5")
        assert conv.next_seq == 12
        # Verify disk: old part files (seq 0-7) should be deleted
        parts_dir = base / "parts"
        part_files = sorted(parts_dir.glob("*.json"))
        part_seqs = [int(f.stem) for f in part_files]
        # Should have: summary (seq 7), question 4 (seq 8), answer 4 (seq 9),
        # question 5 (seq 10), answer 5 (seq 11)
        assert all(s >= 7 for s in part_seqs), f"Stale parts found: {part_seqs}"
        # Restore from fresh store
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        assert restored.next_seq == 12
        assert restored.message_count == 5
        assert "Summary of questions 0-3" in restored.messages[0].content
        assert restored.messages[-1].content == "answer 5"
        # Verify seq monotonicity across all restored messages
        seqs = [m.seq for m in restored.messages]
        assert seqs == sorted(seqs), f"Seqs not monotonic: {seqs}"

    @pytest.mark.asyncio
    async def test_output_key_preservation_through_compact_and_restore(self, tmp_path):
        """Output keys captured at compaction time survive disk persistence."""
        base = tmp_path / "output_key_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(
            system_prompt="classifier",
            output_keys=["classification", "confidence"],
            store=store,
        )
        await conv.add_user_message("Classify this email: 'You won a prize!'")
        await conv.add_assistant_message('{"classification": "spam", "confidence": "0.97"}')
        await conv.add_user_message("What about: 'Meeting at 3pm'")
        await conv.add_assistant_message('{"classification": "ham", "confidence": "0.99"}')
        await conv.add_user_message("And: 'Buy cheap meds now'")
        await conv.add_assistant_message('{"classification": "spam", "confidence": "0.95"}')
        # Compact keeping only the last 2 messages
        await conv.compact("Classified 3 emails.", keep_recent=2)
        # The summary should contain preserved output keys from discarded messages
        summary_content = conv.messages[0].content
        assert "PRESERVED VALUES" in summary_content
        # Most recent values from discarded messages (msgs 0-3) are "ham"/"0.99"
        assert "ham" in summary_content or "spam" in summary_content
        # Restore and verify the preserved values survived
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        assert "PRESERVED VALUES" in restored.messages[0].content

    @pytest.mark.asyncio
    async def test_tool_error_roundtrip(self, tmp_path):
        """Tool errors persist and come back with the ERROR: prefix in LLM form."""
        base = tmp_path / "error_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(store=store)
        await conv.add_user_message("Calculate 1/0")
        await conv.add_assistant_message(
            "Let me calculate that.",
            tool_calls=[
                {
                    "id": "call_calc",
                    "type": "function",
                    "function": {"name": "calculator", "arguments": '{"expr":"1/0"}'},
                }
            ],
        )
        await conv.add_tool_result(
            "call_calc", "ZeroDivisionError: division by zero", is_error=True
        )
        await conv.add_assistant_message("The calculation failed: division by zero is undefined.")
        # Restore
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        tool_msg = restored.messages[2]
        assert tool_msg.role == "tool"
        assert tool_msg.is_error is True
        assert tool_msg.tool_use_id == "call_calc"
        llm_dict = tool_msg.to_llm_dict()
        assert llm_dict["content"].startswith("ERROR: ")
        assert "ZeroDivisionError" in llm_dict["content"]
        assert llm_dict["tool_call_id"] == "call_calc"

    @pytest.mark.asyncio
    async def test_concurrent_conversations_isolated(self, tmp_path):
        """Conversations stored in separate directories never interfere."""
        store_a = FileConversationStore(tmp_path / "conv_a")
        store_b = FileConversationStore(tmp_path / "conv_b")
        conv_a = NodeConversation(system_prompt="Agent A", store=store_a)
        conv_b = NodeConversation(system_prompt="Agent B", store=store_b)
        await conv_a.add_user_message("Hello from A")
        await conv_b.add_user_message("Hello from B")
        await conv_a.add_assistant_message("Response A")
        await conv_b.add_assistant_message("Response B")
        await conv_b.add_user_message("Follow-up B")
        # Restore independently
        restored_a = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_a"))
        restored_b = await NodeConversation.restore(FileConversationStore(tmp_path / "conv_b"))
        assert restored_a.system_prompt == "Agent A"
        assert restored_b.system_prompt == "Agent B"
        assert restored_a.message_count == 2
        assert restored_b.message_count == 3
        assert restored_a.messages[0].content == "Hello from A"
        assert restored_b.messages[2].content == "Follow-up B"

    @pytest.mark.asyncio
    async def test_destroy_removes_all_files(self, tmp_path):
        """destroy() removes the whole conversation directory tree."""
        base = tmp_path / "doomed_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(system_prompt="temp", store=store)
        await conv.add_user_message("ephemeral")
        await conv.add_assistant_message("gone soon")
        assert base.exists()
        assert (base / "meta.json").exists()
        assert (base / "parts").exists()
        await store.destroy()
        assert not base.exists()

    @pytest.mark.asyncio
    async def test_restore_empty_store_returns_none(self, tmp_path):
        """Restoring from a never-written path yields None."""
        store = FileConversationStore(tmp_path / "empty")
        result = await NodeConversation.restore(store)
        assert result is None

    @pytest.mark.asyncio
    async def test_clear_then_continue_then_restore(self, tmp_path):
        """clear() drops messages but keeps the seq counter for what follows."""
        base = tmp_path / "clear_conv"
        store = FileConversationStore(base)
        conv = NodeConversation(system_prompt="s", store=store)
        await conv.add_user_message("old msg 0")
        await conv.add_assistant_message("old msg 1")
        assert conv.next_seq == 2
        await conv.clear()
        assert conv.message_count == 0
        assert conv.next_seq == 2  # seq counter preserved
        # Continue with new messages — seqs should start at 2
        await conv.add_user_message("new msg")
        await conv.add_assistant_message("new response")
        assert conv.next_seq == 4
        assert conv.messages[0].seq == 2
        assert conv.messages[1].seq == 3
        # Restore
        store2 = FileConversationStore(base)
        restored = await NodeConversation.restore(store2)
        assert restored is not None
        assert restored.message_count == 2
        assert restored.next_seq == 4
        assert restored.messages[0].content == "new msg"
        assert restored.messages[0].seq == 2


# ---------------------------------------------------------------------------
# Helpers for aggressive compaction tests
# ---------------------------------------------------------------------------


def _make_tool_call(call_id: str, name: str, args: dict) -> dict:
    """Build an OpenAI-style tool_call dict with JSON-encoded arguments."""
    return {
        "id": call_id,
        "type": "function",
        "function": {"name": name, "arguments": json.dumps(args)},
    }


async def _build_tool_heavy_conversation(
    store: MockConversationStore | None = None,
) -> NodeConversation:
    """Build a conversation with many tool call pairs.

    Layout: user msg, then 5x (assistant with append_data tool_call + tool
    result), then 1x (assistant with set_output tool_call + tool result),
    then user msg + assistant msg.
    """
    conv = NodeConversation(store=store)
    await conv.add_user_message("Process the data")  # seq 0
    for i in range(5):
        args = {"filename": "output.html", "content": "x" * 500}
        tc = [_make_tool_call(f"call_{i}", "append_data", args)]
        conv._messages.append(
            Message(
                seq=conv._next_seq,
                role="assistant",
                content=f"Appending part {i}",
                tool_calls=tc,
            )
        )
        if store:
            await store.write_part(conv._next_seq, conv._messages[-1].to_storage_dict())
        conv._next_seq += 1
        conv._messages.append(
            Message(
                seq=conv._next_seq,
                role="tool",
                content='{"success": true}',
                tool_use_id=f"call_{i}",
            )
        )
        if store:
            await store.write_part(conv._next_seq, conv._messages[-1].to_storage_dict())
        conv._next_seq += 1
    # set_output call — must be protected
    so_tc = [_make_tool_call("call_so", "set_output", {"key": "result", "value": "done"})]
    conv._messages.append(
        Message(seq=conv._next_seq, role="assistant", content="Setting output", tool_calls=so_tc)
    )
    if store:
        await store.write_part(conv._next_seq, conv._messages[-1].to_storage_dict())
    conv._next_seq += 1
    conv._messages.append(
        Message(
            seq=conv._next_seq,
            role="tool",
            content="Output 'result' set successfully.",
            tool_use_id="call_so",
        )
    )
    if store:
        await store.write_part(conv._next_seq, conv._messages[-1].to_storage_dict())
    conv._next_seq += 1
    # Recent messages
    await conv.add_user_message("Continue")
    await conv.add_assistant_message("Working on it")
    return conv


# ---------------------------------------------------------------------------
# Tests: aggressive structural compaction
# ---------------------------------------------------------------------------


class TestAggressiveStructuralCompaction:
    @pytest.mark.asyncio
    async def test_aggressive_collapses_tool_pairs(self, tmp_path):
        """Aggressive mode should collapse non-essential tool pairs into a summary."""
        conv = await _build_tool_heavy_conversation()
        spill = str(tmp_path)
        await conv.compact_preserving_structure(
            spillover_dir=spill,
            keep_recent=2,
            aggressive=True,
        )
        # The 5 append_data pairs (10
msgs) + 1 user msg should be collapsed. # Remaining: ref_msg + set_output pair (2 msgs) + 2 recent = 5 assert conv.message_count == 5 assert conv.messages[0].role == "user" # ref message assert "TOOLS ALREADY CALLED" in conv.messages[0].content assert "append_data (5x)" in conv.messages[0].content # set_output pair should be preserved assert conv.messages[1].role == "assistant" assert conv.messages[1].tool_calls is not None assert conv.messages[1].tool_calls[0]["function"]["name"] == "set_output" assert conv.messages[2].role == "tool" # Recent messages intact assert conv.messages[3].content == "Continue" assert conv.messages[4].content == "Working on it" @pytest.mark.asyncio async def test_aggressive_preserves_set_output(self, tmp_path): """set_output tool calls are always protected in aggressive mode.""" conv = await _build_tool_heavy_conversation() spill = str(tmp_path) await conv.compact_preserving_structure( spillover_dir=spill, keep_recent=2, aggressive=True, ) # Find all tool calls in remaining messages tool_names = [] for msg in conv.messages: if msg.tool_calls: for tc in msg.tool_calls: tool_names.append(tc["function"]["name"]) assert "set_output" in tool_names # append_data should NOT be in remaining messages (collapsed) assert "append_data" not in tool_names @pytest.mark.asyncio async def test_aggressive_preserves_errors(self, tmp_path): """Error tool results are always protected in aggressive mode.""" conv = NodeConversation() await conv.add_user_message("Start") # Regular tool call tc1 = [_make_tool_call("call_ok", "web_search", {"query": "test"})] conv._messages.append( Message(seq=conv._next_seq, role="assistant", content="", tool_calls=tc1) ) conv._next_seq += 1 conv._messages.append( Message(seq=conv._next_seq, role="tool", content="results", tool_use_id="call_ok") ) conv._next_seq += 1 # Error tool call tc2 = [_make_tool_call("call_err", "web_scrape", {"url": "http://broken.com"})] conv._messages.append( Message(seq=conv._next_seq, 
role="assistant", content="", tool_calls=tc2) ) conv._next_seq += 1 conv._messages.append( Message( seq=conv._next_seq, role="tool", content="Connection timeout", tool_use_id="call_err", is_error=True, ) ) conv._next_seq += 1 await conv.add_user_message("Next") await conv.add_assistant_message("OK") spill = str(tmp_path) await conv.compact_preserving_structure( spillover_dir=spill, keep_recent=2, aggressive=True, ) # Error pair should be preserved error_msgs = [m for m in conv.messages if m.role == "tool" and m.is_error] assert len(error_msgs) == 1 assert error_msgs[0].content == "Connection timeout" @pytest.mark.asyncio async def test_standard_mode_keeps_all_tool_pairs(self, tmp_path): """Non-aggressive mode should keep all tool pairs (existing behavior).""" conv = await _build_tool_heavy_conversation() spill = str(tmp_path) await conv.compact_preserving_structure( spillover_dir=spill, keep_recent=2, aggressive=False, ) # All 6 tool pairs (12 msgs) should be kept as structural. # Removed: 1 user msg (freeform). 
Remaining: ref + 12 structural + 2 recent = 15 assert conv.message_count == 15 @pytest.mark.asyncio async def test_two_pass_sequence(self, tmp_path): """Standard pass then aggressive pass produces valid result.""" conv = await _build_tool_heavy_conversation() spill = str(tmp_path) # Pass 1: standard await conv.compact_preserving_structure( spillover_dir=spill, keep_recent=2, ) after_standard = conv.message_count assert after_standard == 15 # all structural kept # Pass 2: aggressive await conv.compact_preserving_structure( spillover_dir=spill, keep_recent=2, aggressive=True, ) after_aggressive = conv.message_count assert after_aggressive < after_standard # ref + set_output pair + 2 recent = 5 assert after_aggressive == 5 @pytest.mark.asyncio async def test_aggressive_persists_correctly(self, tmp_path): """Aggressive compaction correctly updates the store.""" store = MockConversationStore() conv = await _build_tool_heavy_conversation(store=store) spill = str(tmp_path) await conv.compact_preserving_structure( spillover_dir=spill, keep_recent=2, aggressive=True, ) # Verify store state matches in-memory state parts = await store.read_parts() assert len(parts) == conv.message_count class TestExtractToolCallHistory: def test_basic_extraction(self): msgs = [ Message( seq=0, role="assistant", content="", tool_calls=[ _make_tool_call("c1", "web_search", {"query": "python async"}), ], ), Message(seq=1, role="tool", content="results", tool_use_id="c1"), Message( seq=2, role="assistant", content="", tool_calls=[ _make_tool_call( "c2", "save_data", {"filename": "output.txt", "content": "data"} ), ], ), Message(seq=3, role="tool", content="saved", tool_use_id="c2"), ] result = extract_tool_call_history(msgs) assert "web_search (1x)" in result assert "save_data (1x)" in result assert "FILES SAVED: output.txt" in result def test_errors_included(self): msgs = [ Message( seq=0, role="tool", content="Connection refused", is_error=True, tool_use_id="c1", ), ] result = 
extract_tool_call_history(msgs) assert "ERRORS" in result assert "Connection refused" in result def test_empty_messages(self): assert extract_tool_call_history([]) == "" # --------------------------------------------------------------------------- # Tests for _is_context_too_large_error # --------------------------------------------------------------------------- class TestIsContextTooLargeError: def test_context_window_class_name(self): from framework.graph.event_loop_node import _is_context_too_large_error class ContextWindowExceededError(Exception): pass assert _is_context_too_large_error(ContextWindowExceededError("x")) def test_openai_context_length(self): from framework.graph.event_loop_node import _is_context_too_large_error err = RuntimeError("This model's maximum context length is 128000 tokens") assert _is_context_too_large_error(err) def test_anthropic_too_long(self): from framework.graph.event_loop_node import _is_context_too_large_error err = RuntimeError("prompt is too long: 150000 tokens > 100000") assert _is_context_too_large_error(err) def test_generic_exceeds_limit(self): from framework.graph.event_loop_node import _is_context_too_large_error err = ValueError("Request exceeds token limit") assert _is_context_too_large_error(err) def test_unrelated_error(self): from framework.graph.event_loop_node import _is_context_too_large_error assert not _is_context_too_large_error(ValueError("connection refused")) assert not _is_context_too_large_error(RuntimeError("timeout")) # --------------------------------------------------------------------------- # Tests for _format_messages_for_summary # --------------------------------------------------------------------------- class TestFormatMessagesForSummary: def test_user_assistant_messages(self): from framework.graph.event_loop_node import EventLoopNode msgs = [ Message(seq=0, role="user", content="Hello world"), Message(seq=1, role="assistant", content="Hi there"), ] result = 
EventLoopNode._format_messages_for_summary(msgs) assert "[user]: Hello world" in result assert "[assistant]: Hi there" in result def test_tool_result_truncated(self): from framework.graph.event_loop_node import EventLoopNode msgs = [ Message(seq=0, role="tool", content="x" * 1000, tool_use_id="c1"), ] result = EventLoopNode._format_messages_for_summary(msgs) assert "[tool result]:" in result assert "..." in result # Should be truncated to 500 + "..." assert len(result) < 600 def test_assistant_with_tool_calls(self): from framework.graph.event_loop_node import EventLoopNode tc = [_make_tool_call("c1", "web_search", {"query": "test"})] msgs = [ Message(seq=0, role="assistant", content="Searching", tool_calls=tc), ] result = EventLoopNode._format_messages_for_summary(msgs) assert "web_search" in result assert "[assistant (calls:" in result # --------------------------------------------------------------------------- # Tests for _llm_compact (recursive binary-search) # --------------------------------------------------------------------------- class TestLlmCompact: """Test the recursive LLM compaction with mock LLM.""" def _make_node(self): """Create a minimal EventLoopNode for testing.""" from framework.graph.event_loop_node import EventLoopNode, LoopConfig config = LoopConfig(max_context_tokens=32000) node = EventLoopNode.__new__(EventLoopNode) node._config = config node._event_bus = None node._judge = None node._approval_callback = None node._tool_executor = None node._adaptive_learner = None # Set class-level constants (already on class, but explicit) return node def _make_ctx(self, llm_responses=None, llm_error=None): """Create a mock NodeContext with controllable LLM.""" from unittest.mock import AsyncMock, MagicMock from framework.graph.node import NodeSpec spec = NodeSpec( id="test", name="Test Node", description="A test node", node_type="event_loop", input_keys=[], output_keys=["result"], ) ctx = MagicMock() ctx.node_spec = spec ctx.node_id = "test" 
ctx.stream_id = "test" ctx.continuous_mode = False ctx.runtime_logger = None mock_llm = AsyncMock() if llm_error: mock_llm.acomplete.side_effect = llm_error elif llm_responses: responses = [] for text in llm_responses: resp = MagicMock() resp.content = text responses.append(resp) mock_llm.acomplete.side_effect = responses else: resp = MagicMock() resp.content = "Summary of conversation." mock_llm.acomplete.return_value = resp ctx.llm = mock_llm return ctx @pytest.mark.asyncio async def test_single_call_success(self): node = self._make_node() ctx = self._make_ctx() msgs = [ Message(seq=0, role="user", content="Do something"), Message(seq=1, role="assistant", content="Done"), ] result = await node._llm_compact(ctx, msgs, None) assert "Summary of conversation." in result ctx.llm.acomplete.assert_called_once() @pytest.mark.asyncio async def test_context_too_large_triggers_split(self): """When LLM raises context error, should split and retry.""" from unittest.mock import MagicMock node = self._make_node() call_count = 0 async def mock_acomplete(**kwargs): nonlocal call_count call_count += 1 # First call with full messages → fail # Subsequent calls with smaller chunks → succeed if call_count == 1: raise RuntimeError("This model's maximum context length is 128000 tokens") resp = MagicMock() resp.content = f"Summary part {call_count}" return resp ctx = self._make_ctx() ctx.llm.acomplete = mock_acomplete msgs = [Message(seq=i, role="user", content=f"Message {i}") for i in range(10)] result = await node._llm_compact(ctx, msgs, None) # Should have split and produced two summaries assert "Summary part" in result assert call_count >= 3 # 1 failure + 2 successful halves @pytest.mark.asyncio async def test_non_context_error_propagates(self): """Non-context errors should propagate, not trigger splitting.""" node = self._make_node() ctx = self._make_ctx(llm_error=ValueError("API key invalid")) msgs = [ Message(seq=0, role="user", content="Hello"), Message(seq=1, role="assistant", 
content="Hi"), ] with pytest.raises(ValueError, match="API key invalid"): await node._llm_compact(ctx, msgs, None) @pytest.mark.asyncio async def test_proactive_split_for_large_input(self): """Messages exceeding char limit should be split proactively.""" node = self._make_node() # Lower the limit for testing node._LLM_COMPACT_CHAR_LIMIT = 100 ctx = self._make_ctx( llm_responses=["Part 1 summary", "Part 2 summary"], ) msgs = [ Message(seq=0, role="user", content="x" * 80), Message(seq=1, role="user", content="y" * 80), ] result = await node._llm_compact(ctx, msgs, None) assert "Part 1 summary" in result assert "Part 2 summary" in result # LLM should have been called twice (no failure, proactive split) assert ctx.llm.acomplete.call_count == 2 @pytest.mark.asyncio async def test_tool_history_appended_at_top_level(self): """Tool history should only be appended at depth 0.""" node = self._make_node() ctx = self._make_ctx() tc = [_make_tool_call("c1", "web_search", {"query": "test"})] msgs = [ Message(seq=0, role="assistant", content="", tool_calls=tc), Message(seq=1, role="tool", content="results", tool_use_id="c1"), ] result = await node._llm_compact(ctx, msgs, None) assert "TOOLS ALREADY CALLED" in result assert "web_search" in result # --------------------------------------------------------------------------- # Orphaned tool result repair # --------------------------------------------------------------------------- class TestRepairOrphanedToolCalls: """Test _repair_orphaned_tool_calls handles both directions.""" def test_orphaned_tool_result_dropped(self): """Tool result with no matching tool_use should be dropped.""" msgs = [ # tool result with no preceding assistant tool_use {"role": "tool", "tool_call_id": "orphan_1", "content": "stale result"}, {"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi"}, ] repaired = NodeConversation._repair_orphaned_tool_calls(msgs) assert len(repaired) == 2 assert repaired[0]["role"] == "user" assert 
repaired[1]["role"] == "assistant" def test_valid_tool_pair_preserved(self): """Tool result with matching tool_use should be kept.""" msgs = [ {"role": "user", "content": "search"}, { "role": "assistant", "content": "", "tool_calls": [{"id": "tc_1", "function": {"name": "search", "arguments": "{}"}}], }, {"role": "tool", "tool_call_id": "tc_1", "content": "results"}, ] repaired = NodeConversation._repair_orphaned_tool_calls(msgs) assert len(repaired) == 3 assert repaired[2]["tool_call_id"] == "tc_1" def test_orphaned_tool_use_gets_stub(self): """Tool use with no following tool result gets a synthetic error stub.""" msgs = [ {"role": "user", "content": "search"}, { "role": "assistant", "content": "", "tool_calls": [{"id": "tc_1", "function": {"name": "search", "arguments": "{}"}}], }, # No tool result follows {"role": "user", "content": "what happened?"}, ] repaired = NodeConversation._repair_orphaned_tool_calls(msgs) # Should insert a synthetic tool result between assistant and user assert len(repaired) == 4 assert repaired[2]["role"] == "tool" assert repaired[2]["tool_call_id"] == "tc_1" assert "interrupted" in repaired[2]["content"].lower() def test_mixed_orphans(self): """Both orphaned results and orphaned calls handled together.""" msgs = [ # Orphaned result (no matching tool_use) {"role": "tool", "tool_call_id": "gone_1", "content": "old result"}, {"role": "user", "content": "try again"}, { "role": "assistant", "content": "", "tool_calls": [{"id": "tc_2", "function": {"name": "fetch", "arguments": "{}"}}], }, # Missing result for tc_2 {"role": "user", "content": "done?"}, ] repaired = NodeConversation._repair_orphaned_tool_calls(msgs) # orphaned result dropped, stub added for tc_2 roles = [m["role"] for m in repaired] assert roles == ["user", "assistant", "tool", "user"] assert repaired[2]["tool_call_id"] == "tc_2"