"""WP-8: Tests for EventLoopNode, OutputAccumulator, LoopConfig, JudgeProtocol.
|
|
|
|
Uses real FileConversationStore (no mocks for storage) and a MockStreamingLLM
|
|
that yields pre-programmed StreamEvents to control the loop deterministically.
|
|
"""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from framework.graph.conversation import NodeConversation
from framework.graph.event_loop_node import (
    EventLoopNode,
    JudgeProtocol,
    JudgeVerdict,
    LoopConfig,
    OutputAccumulator,
)
from framework.graph.node import NodeContext, NodeProtocol, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
    FinishEvent,
    StreamErrorEvent,
    TextDeltaEvent,
    ToolCallEvent,
)
from framework.runtime.core import Runtime
from framework.runtime.event_bus import EventBus, EventType
from framework.storage.conversation_store import FileConversationStore

# ---------------------------------------------------------------------------
# Mock LLM that yields pre-programmed stream events
# ---------------------------------------------------------------------------


class MockStreamingLLM(LLMProvider):
    """Mock LLM that yields pre-programmed StreamEvent sequences.

    Each call to stream() consumes the next scenario from the list.
    Cycles back to the beginning if more calls are made than scenarios.
    """

    def __init__(self, scenarios: list[list] | None = None):
        self.scenarios = scenarios or []
        self._call_index = 0
        self.stream_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator:
        self.stream_calls.append({"messages": messages, "system": system, "tools": tools})
        if not self.scenarios:
            return
        events = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for event in events:
            yield event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


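# Shape of the event list a single scenario expands to (one list per stream() call),
# e.g. a plain text turn:
#     [TextDeltaEvent(content="Hello", snapshot="Hello"),
#      FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock")]

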
# ---------------------------------------------------------------------------
# Helpers: build simple stream scenarios (text-only and tool-call)
# ---------------------------------------------------------------------------


def text_scenario(text: str, input_tokens: int = 10, output_tokens: int = 5) -> list:
    """Build a stream scenario that produces text and finishes."""
    return [
        TextDeltaEvent(content=text, snapshot=text),
        FinishEvent(
            stop_reason="stop", input_tokens=input_tokens, output_tokens=output_tokens, model="mock"
        ),
    ]


def tool_call_scenario(
    tool_name: str,
    tool_input: dict,
    tool_use_id: str = "call_1",
    text: str = "",
) -> list:
    """Build a stream scenario that produces a tool call."""
    events = []
    if text:
        events.append(TextDeltaEvent(content=text, snapshot=text))
    events.append(
        ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input)
    )
    events.append(
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock")
    )
    return events


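# Example of composing the helpers into a multi-turn mock (the pattern used in the
# tests below): turn 1 calls a tool, turn 2 answers in plain text --
#     MockStreamingLLM(scenarios=[
#         tool_call_scenario("set_output", {"key": "result", "value": "42"}),
#         text_scenario("Done"),
#     ])

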
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def runtime():
    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_1")
    rt.decide = MagicMock(return_value="dec_1")
    rt.record_outcome = MagicMock()
    rt.end_run = MagicMock()
    rt.report_problem = MagicMock()
    rt.set_node = MagicMock()
    return rt


@pytest.fixture
def node_spec():
    return NodeSpec(
        id="test_loop",
        name="Test Loop",
        description="A test event loop node",
        node_type="event_loop",
        output_keys=["result"],
        system_prompt="You are a test assistant.",
    )


@pytest.fixture
def memory():
    return SharedMemory()


def build_ctx(runtime, node_spec, memory, llm, tools=None, input_data=None, goal_context=""):
    """Build a NodeContext for testing."""
    return NodeContext(
        runtime=runtime,
        node_id=node_spec.id,
        node_spec=node_spec,
        memory=memory,
        input_data=input_data or {},
        llm=llm,
        available_tools=tools or [],
        goal_context=goal_context,
    )


# ===========================================================================
# NodeProtocol conformance
# ===========================================================================


class TestNodeProtocolConformance:
    def test_subclasses_node_protocol(self):
        """EventLoopNode must be a subclass of NodeProtocol."""
        assert issubclass(EventLoopNode, NodeProtocol)

    def test_has_execute_method(self):
        node = EventLoopNode()
        assert hasattr(node, "execute")
        assert asyncio.iscoroutinefunction(node.execute)

    def test_has_validate_input(self):
        node = EventLoopNode()
        assert hasattr(node, "validate_input")


# ===========================================================================
# Basic loop execution
# ===========================================================================


class TestBasicLoop:
    @pytest.mark.asyncio
    async def test_basic_text_only_implicit_accept(self, runtime, node_spec, memory):
        """No tools, no judge. LLM produces text, implicit accept on stop."""
        # Override to no output_keys so implicit judge accepts immediately
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello world")])
        ctx = build_ctx(runtime, node_spec, memory, llm)

        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.tokens_used > 0

    @pytest.mark.asyncio
    async def test_no_llm_returns_failure(self, runtime, node_spec, memory):
        """ctx.llm=None should return failure immediately."""
        ctx = build_ctx(runtime, node_spec, memory, llm=None)

        node = EventLoopNode()
        result = await node.execute(ctx)

        assert result.success is False
        assert "LLM" in result.error

    @pytest.mark.asyncio
    async def test_max_iterations_failure(self, runtime, node_spec, memory):
        """When max_iterations is reached without acceptance, should fail."""
        # LLM always produces text but never calls set_output, so implicit
        # judge retries asking for missing keys
        llm = MockStreamingLLM(scenarios=[text_scenario("thinking...")])
        ctx = build_ctx(runtime, node_spec, memory, llm)

        node = EventLoopNode(config=LoopConfig(max_iterations=2))
        result = await node.execute(ctx)

        assert result.success is False
        assert "Max iterations" in result.error


# ===========================================================================
# Judge integration
# ===========================================================================


class TestJudgeIntegration:
    @pytest.mark.asyncio
    async def test_judge_accept(self, runtime, node_spec, memory):
        """Mock judge ACCEPT -> success."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Done!")])

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        judge.evaluate.assert_called_once()

    @pytest.mark.asyncio
    async def test_judge_escalate(self, runtime, node_spec, memory):
        """Mock judge ESCALATE -> failure."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Attempt")])

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(
            return_value=JudgeVerdict(action="ESCALATE", feedback="Tone violation")
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is False
        assert "escalated" in result.error.lower()
        assert "Tone violation" in result.error

    @pytest.mark.asyncio
    async def test_judge_retry_then_accept(self, runtime, node_spec, memory):
        """RETRY twice, then ACCEPT. Should run 3 iterations."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("attempt 1"),
                text_scenario("attempt 2"),
                text_scenario("attempt 3"),
            ]
        )

        call_count = 0

        async def evaluate_fn(context):
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return JudgeVerdict(action="RETRY", feedback="Try harder")
            return JudgeVerdict(action="ACCEPT")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)

        assert result.success is True
        assert call_count == 3


# ===========================================================================
# set_output tool
# ===========================================================================


class TestSetOutput:
    @pytest.mark.asyncio
    async def test_set_output_accumulates(self, runtime, node_spec, memory):
        """LLM calls set_output -> values appear in NodeResult.output."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output
                tool_call_scenario("set_output", {"key": "result", "value": "42"}),
                # Turn 2: text response (triggers implicit judge)
                text_scenario("Done, result is 42"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "42"

    @pytest.mark.asyncio
    async def test_set_output_rejects_invalid_key(self, runtime, node_spec, memory):
        """set_output with key not in output_keys -> is_error=True."""
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call set_output with bad key
                tool_call_scenario("set_output", {"key": "bad_key", "value": "x"}),
                # Turn 2: call set_output with good key
                tool_call_scenario("set_output", {"key": "result", "value": "ok"}),
                # Turn 3: text done
                text_scenario("Done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "ok"
        assert "bad_key" not in result.output

    @pytest.mark.asyncio
    async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
        """Judge accepts but output keys are missing -> retry with hint."""
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="ACCEPT"))

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: text without set_output -> judge accepts but keys missing -> retry
                text_scenario("I'll get to it"),
                # Turn 2: set_output
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
                # Turn 3: text -> judge accepts, keys present -> success
                text_scenario("All done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(judge=judge, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "done"


# ===========================================================================
# Stall detection
# ===========================================================================


class TestStallDetection:
    @pytest.mark.asyncio
    async def test_stall_detection(self, runtime, node_spec, memory):
        """3 identical responses should trigger stall detection."""
        node_spec.output_keys = []  # so implicit judge would accept
        # But we need the judge to RETRY so we actually get 3 identical responses
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))

        llm = MockStreamingLLM(scenarios=[text_scenario("same answer")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=10, stall_detection_threshold=3),
        )
        result = await node.execute(ctx)

        assert result.success is False
        assert "stalled" in result.error.lower()


# ===========================================================================
# EventBus lifecycle events
# ===========================================================================


class TestEventBusLifecycle:
    @pytest.mark.asyncio
    async def test_lifecycle_events_published(self, runtime, node_spec, memory):
        """NODE_LOOP_STARTED, NODE_LOOP_ITERATION, NODE_LOOP_COMPLETED should be published."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("ok")])
        bus = EventBus()

        received_events = []
        bus.subscribe(
            event_types=[
                EventType.NODE_LOOP_STARTED,
                EventType.NODE_LOOP_ITERATION,
                EventType.NODE_LOOP_COMPLETED,
            ],
            handler=lambda e: received_events.append(e.type),
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert EventType.NODE_LOOP_STARTED in received_events
        assert EventType.NODE_LOOP_ITERATION in received_events
        assert EventType.NODE_LOOP_COMPLETED in received_events

    @pytest.mark.asyncio
    async def test_client_facing_uses_client_output_delta(self, runtime, memory):
        """client_facing=True should emit CLIENT_OUTPUT_DELTA instead of LLM_TEXT_DELTA."""
        spec = NodeSpec(
            id="ui_node",
            name="UI Node",
            description="Streams to user",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )
        llm = MockStreamingLLM(scenarios=[text_scenario("visible to user")])
        bus = EventBus()

        received_types = []
        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.LLM_TEXT_DELTA],
            handler=lambda e: received_types.append(e.type),
        )

        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        await node.execute(ctx)

        assert EventType.CLIENT_OUTPUT_DELTA in received_types
        assert EventType.LLM_TEXT_DELTA not in received_types


# ===========================================================================
# Tool execution
# ===========================================================================


class TestToolExecution:
    @pytest.mark.asyncio
    async def test_tool_execution_feedback(self, runtime, node_spec, memory):
        """Tool call -> result fed back to conversation via stream loop."""
        node_spec.output_keys = []

        def my_tool_executor(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content=f"Result for {tool_use.name}",
                is_error=False,
            )

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 1: call a tool
                tool_call_scenario("search", {"query": "test"}, tool_use_id="call_search"),
                # Turn 2: text response after seeing tool result
                text_scenario("Found the answer"),
            ]
        )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="search", description="Search", parameters={})],
        )
        node = EventLoopNode(
            tool_executor=my_tool_executor,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        # stream() should have been called at least twice (tool call turn + final text turn)
        assert llm._call_index >= 2


# ===========================================================================
# Write-through persistence with real FileConversationStore
# ===========================================================================


class TestWriteThroughPersistence:
    @pytest.mark.asyncio
    async def test_messages_written_to_store(self, tmp_path, runtime, node_spec, memory):
        """Messages should be persisted immediately via write-through."""
        store = FileConversationStore(tmp_path / "conv")
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("Hello")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True

        # Verify parts were written to disk
        parts = await store.read_parts()
        assert len(parts) >= 2  # at least initial user msg + assistant msg

    @pytest.mark.asyncio
    async def test_output_accumulator_write_through(self, tmp_path, runtime, node_spec, memory):
        """set_output values should be persisted in cursor immediately."""
        store = FileConversationStore(tmp_path / "conv")
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "persisted_value"}),
                text_scenario("Done"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "persisted_value"

        # Verify output was written to cursor on disk
        cursor = await store.read_cursor()
        assert cursor is not None
        assert cursor["outputs"]["result"] == "persisted_value"


# ===========================================================================
# Crash recovery (restore from real FileConversationStore)
# ===========================================================================


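# Cursor/checkpoint shape assumed by this recovery test (and produced by the
# write-through tests above):
#     {"iteration": <int>, "next_seq": <int>, "outputs": {<key>: <value>, ...}}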
class TestCrashRecovery:
    @pytest.mark.asyncio
    async def test_restore_from_checkpoint(self, tmp_path, runtime, node_spec, memory):
        """Populate a store with state, then verify EventLoopNode restores from it."""
        store = FileConversationStore(tmp_path / "conv")

        # Simulate a previous run that wrote conversation + cursor
        conv = NodeConversation(
            system_prompt="You are a test assistant.",
            output_keys=["result"],
            store=store,
        )
        await conv.add_user_message("Initial input")
        await conv.add_assistant_message("Working on it...")

        # Write cursor with iteration and outputs
        await store.write_cursor(
            {
                "iteration": 1,
                "next_seq": conv.next_seq,
                "outputs": {"result": "partial_value"},
            }
        )

        # Now create a new EventLoopNode and execute -- it should restore
        node_spec.output_keys = []  # no required keys so implicit accept works
        llm = MockStreamingLLM(scenarios=[text_scenario("Continuing...")])

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            conversation_store=store,
            config=LoopConfig(max_iterations=5),
        )
        result = await node.execute(ctx)

        assert result.success is True
        # Should have the restored output
        assert result.output.get("result") == "partial_value"


# ===========================================================================
# External event injection
# ===========================================================================


class TestEventInjection:
    @pytest.mark.asyncio
    async def test_inject_event(self, runtime, node_spec, memory):
        """inject_event() content should appear as user message in next iteration."""
        node_spec.output_keys = []

        judge_calls = []

        async def evaluate_fn(context):
            judge_calls.append(context)
            if len(judge_calls) >= 2:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(side_effect=evaluate_fn)

        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("iteration 1"),
                text_scenario("iteration 2"),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(
            judge=judge,
            config=LoopConfig(max_iterations=5),
        )

        # Pre-inject an event before execute runs
        await node.inject_event("Priority: CEO wants meeting rescheduled")

        result = await node.execute(ctx)
        assert result.success is True

        # Verify the injected content made it into the LLM messages
        all_messages = []
        for call in llm.stream_calls:
            all_messages.extend(call["messages"])
        injected_found = any("[External event]" in str(m.get("content", "")) for m in all_messages)
        assert injected_found


# ===========================================================================
# Pause/resume
# ===========================================================================


class TestPauseResume:
    @pytest.mark.asyncio
    async def test_pause_returns_early(self, runtime, node_spec, memory):
        """pause_requested in input_data should trigger early return."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(scenarios=[text_scenario("should not run")])

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            input_data={"pause_requested": True},
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=10))
        result = await node.execute(ctx)

        # Should return success (paused, not failed)
        assert result.success is True
        # LLM should not have been called (paused before first turn)
        assert llm._call_index == 0


# ===========================================================================
# Stream errors
# ===========================================================================


class TestStreamErrors:
    @pytest.mark.asyncio
    async def test_non_recoverable_stream_error_raises(self, runtime, node_spec, memory):
        """Non-recoverable StreamErrorEvent should raise RuntimeError."""
        node_spec.output_keys = []
        llm = MockStreamingLLM(
            scenarios=[
                [StreamErrorEvent(error="Connection lost", recoverable=False)],
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))

        with pytest.raises(RuntimeError, match="Stream error"):
            await node.execute(ctx)


# ===========================================================================
# OutputAccumulator unit tests
# ===========================================================================


class TestOutputAccumulator:
    @pytest.mark.asyncio
    async def test_set_and_get(self):
        acc = OutputAccumulator()
        await acc.set("key1", "value1")
        assert acc.get("key1") == "value1"
        assert acc.get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_to_dict(self):
        acc = OutputAccumulator()
        await acc.set("a", 1)
        await acc.set("b", 2)
        assert acc.to_dict() == {"a": 1, "b": 2}

    @pytest.mark.asyncio
    async def test_has_all_keys(self):
        acc = OutputAccumulator()
        assert acc.has_all_keys([]) is True
        assert acc.has_all_keys(["x"]) is False
        await acc.set("x", "val")
        assert acc.has_all_keys(["x"]) is True

    @pytest.mark.asyncio
    async def test_write_through_to_real_store(self, tmp_path):
        """OutputAccumulator should write through to FileConversationStore cursor."""
        store = FileConversationStore(tmp_path / "acc_test")
        acc = OutputAccumulator(store=store)

        await acc.set("result", "hello")

        cursor = await store.read_cursor()
        assert cursor["outputs"]["result"] == "hello"

    @pytest.mark.asyncio
    async def test_restore_from_real_store(self, tmp_path):
        """OutputAccumulator.restore() should rebuild from FileConversationStore."""
        store = FileConversationStore(tmp_path / "acc_restore")
        await store.write_cursor({"outputs": {"key1": "val1", "key2": "val2"}})

        acc = await OutputAccumulator.restore(store)
        assert acc.get("key1") == "val1"
        assert acc.get("key2") == "val2"
        assert acc.has_all_keys(["key1", "key2"]) is True