feat: event loop WP1-4

Timothy
2026-01-30 11:43:19 -08:00
parent 2f03605980
commit 7e670ce0a8
7 changed files with 400 additions and 5 deletions
+4 -1
@@ -69,4 +69,7 @@ exports/*
.agent-builder-sessions/*
.venv
.venv
docs/github-issues/*
core/tests/*
+24 -1
@@ -1,8 +1,31 @@
"""LLM provider abstraction."""
from framework.llm.provider import LLMProvider, LLMResponse
from framework.llm.stream_events import (
FinishEvent,
ReasoningDeltaEvent,
ReasoningStartEvent,
StreamErrorEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
ToolResultEvent,
)
__all__ = ["LLMProvider", "LLMResponse"]
__all__ = [
"LLMProvider",
"LLMResponse",
"StreamEvent",
"TextDeltaEvent",
"TextEndEvent",
"ToolCallEvent",
"ToolResultEvent",
"ReasoningStartEvent",
"ReasoningDeltaEvent",
"FinishEvent",
"StreamErrorEvent",
]
try:
from framework.llm.anthropic import AnthropicProvider # noqa: F401
+166 -1
@@ -7,10 +7,11 @@ Groq, and local models.
See: https://docs.litellm.ai/docs/providers
"""
import asyncio
import json
import logging
import time
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from datetime import datetime
from pathlib import Path
from typing import Any
@@ -23,6 +24,7 @@ except ImportError:
RateLimitError = Exception # type: ignore[assignment, misc]
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import StreamEvent
logger = logging.getLogger(__name__)
@@ -425,3 +427,166 @@ class LiteLLMProvider(LLMProvider):
},
},
}
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools: list[Tool] | None = None,
max_tokens: int = 4096,
) -> AsyncIterator[StreamEvent]:
"""Stream a completion via litellm.acompletion(stream=True).
Yields StreamEvent objects as chunks arrive from the provider.
Tool call arguments are accumulated across chunks and yielded as
a single ToolCallEvent with fully parsed JSON when complete.
Empty responses (e.g. Gemini stealth rate-limits that return 200
with no content) are retried with exponential backoff, mirroring
the retry behaviour of ``_completion_with_rate_limit_retry``.
"""
from framework.llm.stream_events import (
FinishEvent,
StreamErrorEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
full_messages: list[dict[str, Any]] = []
if system:
full_messages.append({"role": "system", "content": system})
full_messages.extend(messages)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": full_messages,
"max_tokens": max_tokens,
"stream": True,
**self.extra_kwargs,
}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
kwargs["api_base"] = self.api_base
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
buffered_events: list[StreamEvent] = []
accumulated_text = ""
tool_calls_acc: dict[int, dict[str, str]] = {}
input_tokens = 0
output_tokens = 0
try:
response = await litellm.acompletion(**kwargs) # type: ignore[union-attr]
async for chunk in response:
choice = chunk.choices[0] if chunk.choices else None
if not choice:
continue
delta = choice.delta
# --- Text content ---
if delta and delta.content:
accumulated_text += delta.content
buffered_events.append(
TextDeltaEvent(
content=delta.content,
snapshot=accumulated_text,
)
)
# --- Tool calls (accumulate across chunks) ---
if delta and delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
if tc.id:
tool_calls_acc[idx]["id"] = tc.id
if tc.function:
if tc.function.name:
tool_calls_acc[idx]["name"] = tc.function.name
if tc.function.arguments:
tool_calls_acc[idx]["arguments"] += tc.function.arguments
# --- Finish ---
if choice.finish_reason:
for _idx, tc_data in sorted(tool_calls_acc.items()):
try:
parsed_args = json.loads(tc_data["arguments"])
except (json.JSONDecodeError, KeyError):
parsed_args = {"_raw": tc_data.get("arguments", "")}
buffered_events.append(
ToolCallEvent(
tool_use_id=tc_data["id"],
tool_name=tc_data["name"],
tool_input=parsed_args,
)
)
if accumulated_text:
buffered_events.append(TextEndEvent(full_text=accumulated_text))
usage = getattr(chunk, "usage", None)
if usage:
input_tokens = getattr(usage, "prompt_tokens", 0) or 0
output_tokens = getattr(usage, "completion_tokens", 0) or 0
buffered_events.append(
FinishEvent(
stop_reason=choice.finish_reason,
input_tokens=input_tokens,
output_tokens=output_tokens,
model=self.model,
)
)
# Check whether the stream produced any real content.
has_content = accumulated_text or tool_calls_acc
if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
token_count, token_method = _estimate_tokens(
self.model,
full_messages,
)
dump_path = _dump_failed_request(
model=self.model,
kwargs=kwargs,
error_type="empty_stream",
attempt=attempt,
)
logger.warning(
f"[stream-retry] {self.model} returned empty stream — "
f"~{token_count} tokens ({token_method}). "
f"Request dumped to: {dump_path}. "
f"Retrying in {wait}s "
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
)
await asyncio.sleep(wait)
continue
# Success (or final attempt) — flush buffered events.
for event in buffered_events:
yield event
return
except RateLimitError as e:
if attempt < RATE_LIMIT_MAX_RETRIES:
wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
logger.warning(
f"[stream-retry] {self.model} rate limited (429): {e!s}. "
f"Retrying in {wait}s "
f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
)
await asyncio.sleep(wait)
continue
yield StreamErrorEvent(error=str(e), recoverable=False)
return
except Exception as e:
yield StreamErrorEvent(error=str(e), recoverable=False)
return
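A minimal consumption sketch for the new stream() API (the module path, constructor arguments, and message content below are illustrative; only the stream() signature and the event types come from this commit):

import asyncio

from framework.llm.litellm_provider import LiteLLMProvider  # path assumed
from framework.llm.stream_events import (
    FinishEvent,
    StreamErrorEvent,
    TextDeltaEvent,
    ToolCallEvent,
)

async def main() -> None:
    provider = LiteLLMProvider(model="gpt-4o-mini")  # hypothetical args
    async for event in provider.stream(
        messages=[{"role": "user", "content": "Hello"}],
        system="You are concise.",
    ):
        if isinstance(event, TextDeltaEvent):
            print(event.content, end="", flush=True)  # incremental text
        elif isinstance(event, ToolCallEvent):
            print(f"\n[tool] {event.tool_name}({event.tool_input})")
        elif isinstance(event, FinishEvent):
            print(f"\n[finish] {event.stop_reason}, out={event.output_tokens}")
        elif isinstance(event, StreamErrorEvent):
            print(f"\n[error] {event.error} (recoverable={event.recoverable})")

asyncio.run(main())

Note the per-attempt buffering in the diff above: events accumulate in buffered_events and are flushed only once the attempt is known to contain real content, so a retried empty stream never leaks partial deltas to a consumer like this one.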
+32 -1
@@ -2,10 +2,16 @@
import json
import re
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from typing import Any
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
FinishEvent,
StreamEvent,
TextDeltaEvent,
TextEndEvent,
)
class MockLLMProvider(LLMProvider):
@@ -175,3 +181,28 @@ class MockLLMProvider(LLMProvider):
output_tokens=0,
stop_reason="mock_complete",
)
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools: list[Tool] | None = None,
max_tokens: int = 4096,
) -> AsyncIterator[StreamEvent]:
"""Stream a mock completion as word-level TextDeltaEvents.
Splits the mock response into words and yields each as a separate
TextDeltaEvent with an accumulating snapshot, exercising the full
streaming pipeline without any API calls.
"""
content = self._generate_mock_response(system=system, json_mode=False)
words = content.split(" ")
accumulated = ""
for i, word in enumerate(words):
chunk = word if i == 0 else " " + word
accumulated += chunk
yield TextDeltaEvent(content=chunk, snapshot=accumulated)
yield TextEndEvent(full_text=accumulated)
yield FinishEvent(stop_reason="mock_complete", model=self.model)
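A pytest-style sketch exercising the mock stream end to end (the module path and a zero-argument MockLLMProvider() constructor are assumptions):

import asyncio

from framework.llm.mock import MockLLMProvider  # path assumed
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, TextEndEvent

async def check_mock_stream() -> None:
    provider = MockLLMProvider()
    events = [
        e async for e in provider.stream(messages=[{"role": "user", "content": "hi"}])
    ]
    deltas = [e for e in events if isinstance(e, TextDeltaEvent)]
    end = next(e for e in events if isinstance(e, TextEndEvent))
    # Snapshots accumulate word by word; the last snapshot equals the
    # final full text, and a FinishEvent closes the stream.
    assert deltas and deltas[-1].snapshot == end.full_text
    assert isinstance(events[-1], FinishEvent)

asyncio.run(check_mock_stream())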
+43 -1
@@ -1,7 +1,7 @@
"""LLM Provider abstraction for pluggable LLM backends."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass, field
from typing import Any
@@ -108,3 +108,45 @@ class LLMProvider(ABC):
Final LLMResponse after tool use completes
"""
pass
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools: list[Tool] | None = None,
max_tokens: int = 4096,
) -> AsyncIterator["StreamEvent"]:
"""
Stream a completion as an async iterator of StreamEvents.
Default implementation wraps complete() with synthetic events.
Subclasses SHOULD override for true streaming.
Tool orchestration is the CALLER's responsibility:
- Caller detects ToolCallEvent, executes tool, adds result
to messages, calls stream() again.
"""
from framework.llm.stream_events import (
FinishEvent,
TextDeltaEvent,
TextEndEvent,
)
response = self.complete(
messages=messages,
system=system,
tools=tools,
max_tokens=max_tokens,
)
yield TextDeltaEvent(content=response.content, snapshot=response.content)
yield TextEndEvent(full_text=response.content)
yield FinishEvent(
stop_reason=response.stop_reason,
input_tokens=response.input_tokens,
output_tokens=response.output_tokens,
model=response.model,
)
# Deferred import target for type annotation
from framework.llm.stream_events import StreamEvent as StreamEvent # noqa: E402, F401
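A hedged sketch of the caller-side tool loop the docstring describes: collect ToolCallEvents, execute each tool, append the results to the conversation, and call stream() again. The execute_tool helper and the result-message shape are illustrative assumptions, not part of this commit:

from typing import Any

from framework.llm.stream_events import TextDeltaEvent, ToolCallEvent

async def run_with_tools(provider, messages: list[dict[str, Any]], tools) -> str:
    text = ""
    while True:
        calls: list[ToolCallEvent] = []
        async for event in provider.stream(messages=messages, tools=tools):
            if isinstance(event, TextDeltaEvent):
                text = event.snapshot  # latest accumulated text
            elif isinstance(event, ToolCallEvent):
                calls.append(event)
        if not calls:
            return text  # no tool calls requested; generation is final
        for call in calls:
            result = execute_tool(call.tool_name, call.tool_input)  # hypothetical
            # OpenAI-style tool-result message; the exact shape (and any
            # preceding assistant tool_calls message) is provider-specific.
            messages.append({
                "role": "tool",
                "tool_call_id": call.tool_use_id,
                "content": result,
            })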
+96
@@ -0,0 +1,96 @@
"""Stream event types for LLM streaming responses.
Defines a discriminated union of frozen dataclasses representing every event
a streaming LLM call can produce. These types form the contract between the
LLM provider layer, EventLoopNode, event bus, persistence, and monitoring.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
@dataclass(frozen=True)
class TextDeltaEvent:
"""A chunk of text produced by the LLM."""
type: Literal["text_delta"] = "text_delta"
content: str = "" # this chunk's text
snapshot: str = "" # accumulated text so far
@dataclass(frozen=True)
class TextEndEvent:
"""Signals that text generation is complete."""
type: Literal["text_end"] = "text_end"
full_text: str = ""
@dataclass(frozen=True)
class ToolCallEvent:
"""The LLM has requested a tool call."""
type: Literal["tool_call"] = "tool_call"
tool_use_id: str = ""
tool_name: str = ""
tool_input: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class ToolResultEvent:
"""Result of executing a tool call."""
type: Literal["tool_result"] = "tool_result"
tool_use_id: str = ""
content: str = ""
is_error: bool = False
@dataclass(frozen=True)
class ReasoningStartEvent:
"""The LLM has started a reasoning/thinking block."""
type: Literal["reasoning_start"] = "reasoning_start"
@dataclass(frozen=True)
class ReasoningDeltaEvent:
"""A chunk of reasoning/thinking content."""
type: Literal["reasoning_delta"] = "reasoning_delta"
content: str = ""
@dataclass(frozen=True)
class FinishEvent:
"""The LLM has finished generating."""
type: Literal["finish"] = "finish"
stop_reason: str = ""
input_tokens: int = 0
output_tokens: int = 0
model: str = ""
@dataclass(frozen=True)
class StreamErrorEvent:
"""An error occurred during streaming."""
type: Literal["error"] = "error"
error: str = ""
recoverable: bool = False
# Discriminated union of all stream event types
StreamEvent = (
TextDeltaEvent
| TextEndEvent
| ToolCallEvent
| ToolResultEvent
| ReasoningStartEvent
| ReasoningDeltaEvent
| FinishEvent
| StreamErrorEvent
)
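Because every event is a frozen dataclass carrying a Literal type discriminator, serializing for the event bus or persistence is mechanical, and a Python 3.10 match statement can dispatch on the class directly; a small sketch:

import json
from dataclasses import asdict

from framework.llm.stream_events import TextDeltaEvent

event = TextDeltaEvent(content="Hi", snapshot="Hi")
print(json.dumps(asdict(event)))
# {"type": "text_delta", "content": "Hi", "snapshot": "Hi"}

match event:  # structural dispatch without isinstance chains
    case TextDeltaEvent(content=chunk):
        print(f"delta: {chunk!r}")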
+35
@@ -41,6 +41,28 @@ class EventType(str, Enum):
STREAM_STARTED = "stream_started"
STREAM_STOPPED = "stream_stopped"
# Node event-loop lifecycle
NODE_LOOP_STARTED = "node_loop_started"
NODE_LOOP_ITERATION = "node_loop_iteration"
NODE_LOOP_COMPLETED = "node_loop_completed"
# LLM streaming observability
LLM_TEXT_DELTA = "llm_text_delta"
LLM_REASONING_DELTA = "llm_reasoning_delta"
# Tool lifecycle
TOOL_CALL_STARTED = "tool_call_started"
TOOL_CALL_COMPLETED = "tool_call_completed"
# Client I/O (client_facing=True nodes only)
CLIENT_OUTPUT_DELTA = "client_output_delta"
CLIENT_INPUT_REQUESTED = "client_input_requested"
# Internal node observability (client_facing=False nodes)
NODE_INTERNAL_OUTPUT = "node_internal_output"
NODE_INPUT_BLOCKED = "node_input_blocked"
NODE_STALLED = "node_stalled"
# Custom events
CUSTOM = "custom"
@@ -51,6 +73,7 @@ class AgentEvent:
type: EventType
stream_id: str
node_id: str | None = None # Which node emitted this event
execution_id: str | None = None
data: dict[str, Any] = field(default_factory=dict)
timestamp: datetime = field(default_factory=datetime.now)
@@ -61,6 +84,7 @@ class AgentEvent:
return {
"type": self.type.value,
"stream_id": self.stream_id,
"node_id": self.node_id,
"execution_id": self.execution_id,
"data": self.data,
"timestamp": self.timestamp.isoformat(),
@@ -80,6 +104,7 @@ class Subscription:
event_types: set[EventType]
handler: EventHandler
filter_stream: str | None = None # Only receive events from this stream
filter_node: str | None = None # Only receive events from this node
filter_execution: str | None = None # Only receive events from this execution
@@ -138,6 +163,7 @@ class EventBus:
event_types: list[EventType],
handler: EventHandler,
filter_stream: str | None = None,
filter_node: str | None = None,
filter_execution: str | None = None,
) -> str:
"""
@@ -147,6 +173,7 @@ class EventBus:
event_types: Types of events to receive
handler: Async function to call when event occurs
filter_stream: Only receive events from this stream
filter_node: Only receive events from this node
filter_execution: Only receive events from this execution
Returns:
@@ -160,6 +187,7 @@ class EventBus:
event_types=set(event_types),
handler=handler,
filter_stream=filter_stream,
filter_node=filter_node,
filter_execution=filter_execution,
)
@@ -218,6 +246,10 @@ class EventBus:
if subscription.filter_stream and subscription.filter_stream != event.stream_id:
return False
# Check node filter
if subscription.filter_node and subscription.filter_node != event.node_id:
return False
# Check execution filter
if subscription.filter_execution and subscription.filter_execution != event.execution_id:
return False
@@ -410,6 +442,7 @@ class EventBus:
self,
event_type: EventType,
stream_id: str | None = None,
node_id: str | None = None,
execution_id: str | None = None,
timeout: float | None = None,
) -> AgentEvent | None:
@@ -419,6 +452,7 @@ class EventBus:
Args:
event_type: Type of event to wait for
stream_id: Filter by stream
node_id: Filter by node
execution_id: Filter by execution
timeout: Maximum time to wait (seconds)
@@ -438,6 +472,7 @@ class EventBus:
event_types=[event_type],
handler=handler,
filter_stream=stream_id,
filter_node=node_id,
filter_execution=execution_id,
)