feat: event loop WP1-4
@@ -69,4 +69,7 @@ exports/*

.agent-builder-sessions/*

.venv
.venv

docs/github-issues/*
core/tests/*
@@ -1,8 +1,31 @@
"""LLM provider abstraction."""

from framework.llm.provider import LLMProvider, LLMResponse
from framework.llm.stream_events import (
    FinishEvent,
    ReasoningDeltaEvent,
    ReasoningStartEvent,
    StreamErrorEvent,
    StreamEvent,
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
    ToolResultEvent,
)

__all__ = ["LLMProvider", "LLMResponse"]
__all__ = [
    "LLMProvider",
    "LLMResponse",
    "StreamEvent",
    "TextDeltaEvent",
    "TextEndEvent",
    "ToolCallEvent",
    "ToolResultEvent",
    "ReasoningStartEvent",
    "ReasoningDeltaEvent",
    "FinishEvent",
    "StreamErrorEvent",
]

try:
    from framework.llm.anthropic import AnthropicProvider  # noqa: F401
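Usage note: with these re-exports, downstream code can consume a stream and dispatch on event type directly from framework.llm. A minimal sketch (the print_stream helper and its wiring are illustrative, not part of this commit):

from framework.llm import (
    FinishEvent,
    StreamErrorEvent,
    TextDeltaEvent,
    ToolCallEvent,
)


async def print_stream(provider, messages: list[dict]) -> None:
    # Render text deltas as they arrive; surface tool calls, errors, and the finish.
    async for event in provider.stream(messages=messages):
        if isinstance(event, TextDeltaEvent):
            print(event.content, end="", flush=True)
        elif isinstance(event, ToolCallEvent):
            print(f"\n[tool] {event.tool_name}({event.tool_input})")
        elif isinstance(event, StreamErrorEvent):
            print(f"\n[error] {event.error}")
        elif isinstance(event, FinishEvent):
            print(f"\n[done] {event.stop_reason}, {event.output_tokens} output tokens")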
@@ -7,10 +7,11 @@ Groq, and local models.
See: https://docs.litellm.ai/docs/providers
"""

import asyncio
import json
import logging
import time
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from datetime import datetime
from pathlib import Path
from typing import Any
@@ -23,6 +24,7 @@ except ImportError:
    RateLimitError = Exception  # type: ignore[assignment, misc]

from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import StreamEvent

logger = logging.getLogger(__name__)
@@ -425,3 +427,166 @@ class LiteLLMProvider(LLMProvider):
                },
            },
        }

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream a completion via litellm.acompletion(stream=True).

        Yields StreamEvent objects as chunks arrive from the provider.
        Tool call arguments are accumulated across chunks and yielded as
        a single ToolCallEvent with fully parsed JSON when complete.

        Empty responses (e.g. Gemini stealth rate-limits that return 200
        with no content) are retried with exponential backoff, mirroring
        the retry behaviour of ``_completion_with_rate_limit_retry``.
        """
        from framework.llm.stream_events import (
            FinishEvent,
            StreamErrorEvent,
            TextDeltaEvent,
            TextEndEvent,
            ToolCallEvent,
        )

        full_messages: list[dict[str, Any]] = []
        if system:
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": full_messages,
            "max_tokens": max_tokens,
            "stream": True,
            **self.extra_kwargs,
        }
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if tools:
            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]

        for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
            buffered_events: list[StreamEvent] = []
            accumulated_text = ""
            tool_calls_acc: dict[int, dict[str, str]] = {}
            input_tokens = 0
            output_tokens = 0

            try:
                response = await litellm.acompletion(**kwargs)  # type: ignore[union-attr]

                async for chunk in response:
                    choice = chunk.choices[0] if chunk.choices else None
                    if not choice:
                        continue

                    delta = choice.delta

                    # --- Text content ---
                    if delta and delta.content:
                        accumulated_text += delta.content
                        buffered_events.append(
                            TextDeltaEvent(
                                content=delta.content,
                                snapshot=accumulated_text,
                            )
                        )

                    # --- Tool calls (accumulate across chunks) ---
                    if delta and delta.tool_calls:
                        for tc in delta.tool_calls:
                            idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
                            if idx not in tool_calls_acc:
                                tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
                            if tc.id:
                                tool_calls_acc[idx]["id"] = tc.id
                            if tc.function:
                                if tc.function.name:
                                    tool_calls_acc[idx]["name"] = tc.function.name
                                if tc.function.arguments:
                                    tool_calls_acc[idx]["arguments"] += tc.function.arguments

                    # --- Finish ---
                    if choice.finish_reason:
                        for _idx, tc_data in sorted(tool_calls_acc.items()):
                            try:
                                parsed_args = json.loads(tc_data["arguments"])
                            except (json.JSONDecodeError, KeyError):
                                parsed_args = {"_raw": tc_data.get("arguments", "")}
                            buffered_events.append(
                                ToolCallEvent(
                                    tool_use_id=tc_data["id"],
                                    tool_name=tc_data["name"],
                                    tool_input=parsed_args,
                                )
                            )

                        if accumulated_text:
                            buffered_events.append(TextEndEvent(full_text=accumulated_text))

                        usage = getattr(chunk, "usage", None)
                        if usage:
                            input_tokens = getattr(usage, "prompt_tokens", 0) or 0
                            output_tokens = getattr(usage, "completion_tokens", 0) or 0

                        buffered_events.append(
                            FinishEvent(
                                stop_reason=choice.finish_reason,
                                input_tokens=input_tokens,
                                output_tokens=output_tokens,
                                model=self.model,
                            )
                        )

                # Check whether the stream produced any real content.
                has_content = accumulated_text or tool_calls_acc
                if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                    token_count, token_method = _estimate_tokens(
                        self.model,
                        full_messages,
                    )
                    dump_path = _dump_failed_request(
                        model=self.model,
                        kwargs=kwargs,
                        error_type="empty_stream",
                        attempt=attempt,
                    )
                    logger.warning(
                        f"[stream-retry] {self.model} returned empty stream — "
                        f"~{token_count} tokens ({token_method}). "
                        f"Request dumped to: {dump_path}. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    await asyncio.sleep(wait)
                    continue

                # Success (or final attempt) — flush buffered events.
                for event in buffered_events:
                    yield event
                return

            except RateLimitError as e:
                if attempt < RATE_LIMIT_MAX_RETRIES:
                    wait = RATE_LIMIT_BACKOFF_BASE * (2**attempt)
                    logger.warning(
                        f"[stream-retry] {self.model} rate limited (429): {e!s}. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    await asyncio.sleep(wait)
                    continue
                yield StreamErrorEvent(error=str(e), recoverable=False)
                return

            except Exception as e:
                yield StreamErrorEvent(error=str(e), recoverable=False)
                return
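Usage note: tool-call arguments arrive as JSON fragments spread across chunks, and the accumulator above only parses them once finish_reason is seen. A standalone sketch of that behaviour, with hypothetical fragment data:

import json

# Hypothetical argument fragments for one tool call, as they might arrive per chunk.
fragments = ['{"query": "event', ' loop", "limit"', ': 5}']

acc = {"id": "call_1", "name": "search", "arguments": ""}
for frag in fragments:
    acc["arguments"] += frag  # mirrors the += accumulation in stream()

print(json.loads(acc["arguments"]))  # {'query': 'event loop', 'limit': 5}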
@@ -2,10 +2,16 @@

import json
import re
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from typing import Any

from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.stream_events import (
    FinishEvent,
    StreamEvent,
    TextDeltaEvent,
    TextEndEvent,
)


class MockLLMProvider(LLMProvider):
@@ -175,3 +181,28 @@ class MockLLMProvider(LLMProvider):
            output_tokens=0,
            stop_reason="mock_complete",
        )

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream a mock completion as word-level TextDeltaEvents.

        Splits the mock response into words and yields each as a separate
        TextDeltaEvent with an accumulating snapshot, exercising the full
        streaming pipeline without any API calls.
        """
        content = self._generate_mock_response(system=system, json_mode=False)
        words = content.split(" ")
        accumulated = ""

        for i, word in enumerate(words):
            chunk = word if i == 0 else " " + word
            accumulated += chunk
            yield TextDeltaEvent(content=chunk, snapshot=accumulated)

        yield TextEndEvent(full_text=accumulated)
        yield FinishEvent(stop_reason="mock_complete", model=self.model)
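Usage note: this makes the streaming path testable offline. A sketch of driving it (the import path and constructor arguments are assumptions, not confirmed by this diff):

import asyncio

from framework.llm.mock import MockLLMProvider  # import path assumed


async def main() -> None:
    provider = MockLLMProvider()  # constructor arguments, if any, omitted
    async for event in provider.stream(messages=[{"role": "user", "content": "hi"}]):
        print(type(event).__name__, getattr(event, "content", ""))


asyncio.run(main())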
@@ -1,7 +1,7 @@
"""LLM Provider abstraction for pluggable LLM backends."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass, field
from typing import Any
@@ -108,3 +108,45 @@ class LLMProvider(ABC):
            Final LLMResponse after tool use completes
        """
        pass

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator["StreamEvent"]:
        """
        Stream a completion as an async iterator of StreamEvents.

        Default implementation wraps complete() with synthetic events.
        Subclasses SHOULD override for true streaming.

        Tool orchestration is the CALLER's responsibility:
        - Caller detects ToolCallEvent, executes tool, adds result
          to messages, calls stream() again.
        """
        from framework.llm.stream_events import (
            FinishEvent,
            TextDeltaEvent,
            TextEndEvent,
        )

        response = self.complete(
            messages=messages,
            system=system,
            tools=tools,
            max_tokens=max_tokens,
        )
        yield TextDeltaEvent(content=response.content, snapshot=response.content)
        yield TextEndEvent(full_text=response.content)
        yield FinishEvent(
            stop_reason=response.stop_reason,
            input_tokens=response.input_tokens,
            output_tokens=response.output_tokens,
            model=response.model,
        )


# Deferred import target for type annotation
from framework.llm.stream_events import StreamEvent as StreamEvent  # noqa: E402, F401
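Since tool orchestration is pushed to the caller, the loop the docstring describes could look roughly like this (a sketch; execute_tool and the tool-result message shape are assumptions, not defined in this diff):

from framework.llm.provider import LLMProvider, Tool
from framework.llm.stream_events import ToolCallEvent


async def run_with_tools(provider: LLMProvider, messages: list[dict], tools: list[Tool]) -> None:
    # Re-invoke stream() until a pass completes with no tool calls.
    while True:
        tool_called = False
        async for event in provider.stream(messages=messages, tools=tools):
            if isinstance(event, ToolCallEvent):
                result = execute_tool(event.tool_name, event.tool_input)  # hypothetical helper
                # Message shape is provider-specific; an OpenAI-style entry is assumed here.
                messages.append(
                    {"role": "tool", "tool_call_id": event.tool_use_id, "content": result}
                )
                tool_called = True
        if not tool_called:
            return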
@@ -0,0 +1,96 @@
"""Stream event types for LLM streaming responses.

Defines a discriminated union of frozen dataclasses representing every event
a streaming LLM call can produce. These types form the contract between the
LLM provider layer, EventLoopNode, event bus, persistence, and monitoring.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal


@dataclass(frozen=True)
class TextDeltaEvent:
    """A chunk of text produced by the LLM."""

    type: Literal["text_delta"] = "text_delta"
    content: str = ""  # this chunk's text
    snapshot: str = ""  # accumulated text so far


@dataclass(frozen=True)
class TextEndEvent:
    """Signals that text generation is complete."""

    type: Literal["text_end"] = "text_end"
    full_text: str = ""


@dataclass(frozen=True)
class ToolCallEvent:
    """The LLM has requested a tool call."""

    type: Literal["tool_call"] = "tool_call"
    tool_use_id: str = ""
    tool_name: str = ""
    tool_input: dict[str, Any] = field(default_factory=dict)


@dataclass(frozen=True)
class ToolResultEvent:
    """Result of executing a tool call."""

    type: Literal["tool_result"] = "tool_result"
    tool_use_id: str = ""
    content: str = ""
    is_error: bool = False


@dataclass(frozen=True)
class ReasoningStartEvent:
    """The LLM has started a reasoning/thinking block."""

    type: Literal["reasoning_start"] = "reasoning_start"


@dataclass(frozen=True)
class ReasoningDeltaEvent:
    """A chunk of reasoning/thinking content."""

    type: Literal["reasoning_delta"] = "reasoning_delta"
    content: str = ""


@dataclass(frozen=True)
class FinishEvent:
    """The LLM has finished generating."""

    type: Literal["finish"] = "finish"
    stop_reason: str = ""
    input_tokens: int = 0
    output_tokens: int = 0
    model: str = ""


@dataclass(frozen=True)
class StreamErrorEvent:
    """An error occurred during streaming."""

    type: Literal["error"] = "error"
    error: str = ""
    recoverable: bool = False


# Discriminated union of all stream event types
StreamEvent = (
    TextDeltaEvent
    | TextEndEvent
    | ToolCallEvent
    | ToolResultEvent
    | ReasoningStartEvent
    | ReasoningDeltaEvent
    | FinishEvent
    | StreamErrorEvent
)
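Because every variant carries a Literal type tag, events serialize into self-describing payloads; a minimal sketch of how persistence might use this (the asdict round-trip is an assumption, not shown in this commit):

from dataclasses import asdict

from framework.llm.stream_events import StreamEvent, TextDeltaEvent


def to_wire(event: StreamEvent) -> dict:
    # The "type" discriminator survives asdict(), so consumers can route on it.
    return asdict(event)


print(to_wire(TextDeltaEvent(content="hi", snapshot="hi")))
# {'type': 'text_delta', 'content': 'hi', 'snapshot': 'hi'}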
@@ -41,6 +41,28 @@ class EventType(str, Enum):
    STREAM_STARTED = "stream_started"
    STREAM_STOPPED = "stream_stopped"

    # Node event-loop lifecycle
    NODE_LOOP_STARTED = "node_loop_started"
    NODE_LOOP_ITERATION = "node_loop_iteration"
    NODE_LOOP_COMPLETED = "node_loop_completed"

    # LLM streaming observability
    LLM_TEXT_DELTA = "llm_text_delta"
    LLM_REASONING_DELTA = "llm_reasoning_delta"

    # Tool lifecycle
    TOOL_CALL_STARTED = "tool_call_started"
    TOOL_CALL_COMPLETED = "tool_call_completed"

    # Client I/O (client_facing=True nodes only)
    CLIENT_OUTPUT_DELTA = "client_output_delta"
    CLIENT_INPUT_REQUESTED = "client_input_requested"

    # Internal node observability (client_facing=False nodes)
    NODE_INTERNAL_OUTPUT = "node_internal_output"
    NODE_INPUT_BLOCKED = "node_input_blocked"
    NODE_STALLED = "node_stalled"

    # Custom events
    CUSTOM = "custom"
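One plausible bridge from provider StreamEvents onto these new observability types (a sketch; the actual EventLoopNode mapping is not part of this hunk, and the events module path is assumed):

from framework.events import AgentEvent, EventType  # module path assumed
from framework.llm.stream_events import ReasoningDeltaEvent, StreamEvent, TextDeltaEvent


def to_bus_event(ev: StreamEvent, stream_id: str, node_id: str) -> AgentEvent | None:
    # Route provider-level deltas to the matching observability event type.
    if isinstance(ev, TextDeltaEvent):
        return AgentEvent(
            type=EventType.LLM_TEXT_DELTA,
            stream_id=stream_id,
            node_id=node_id,
            data={"content": ev.content, "snapshot": ev.snapshot},
        )
    if isinstance(ev, ReasoningDeltaEvent):
        return AgentEvent(
            type=EventType.LLM_REASONING_DELTA,
            stream_id=stream_id,
            node_id=node_id,
            data={"content": ev.content},
        )
    return None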
@@ -51,6 +73,7 @@ class AgentEvent:

    type: EventType
    stream_id: str
    node_id: str | None = None  # Which node emitted this event
    execution_id: str | None = None
    data: dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)
@@ -61,6 +84,7 @@ class AgentEvent:
        return {
            "type": self.type.value,
            "stream_id": self.stream_id,
            "node_id": self.node_id,
            "execution_id": self.execution_id,
            "data": self.data,
            "timestamp": self.timestamp.isoformat(),
@@ -80,6 +104,7 @@ class Subscription:
    event_types: set[EventType]
    handler: EventHandler
    filter_stream: str | None = None  # Only receive events from this stream
    filter_node: str | None = None  # Only receive events from this node
    filter_execution: str | None = None  # Only receive events from this execution
@@ -138,6 +163,7 @@ class EventBus:
        event_types: list[EventType],
        handler: EventHandler,
        filter_stream: str | None = None,
        filter_node: str | None = None,
        filter_execution: str | None = None,
    ) -> str:
        """
@@ -147,6 +173,7 @@ class EventBus:
            event_types: Types of events to receive
            handler: Async function to call when event occurs
            filter_stream: Only receive events from this stream
            filter_node: Only receive events from this node
            filter_execution: Only receive events from this execution

        Returns:
@@ -160,6 +187,7 @@ class EventBus:
            event_types=set(event_types),
            handler=handler,
            filter_stream=filter_stream,
            filter_node=filter_node,
            filter_execution=filter_execution,
        )
@@ -218,6 +246,10 @@ class EventBus:
        if subscription.filter_stream and subscription.filter_stream != event.stream_id:
            return False

        # Check node filter
        if subscription.filter_node and subscription.filter_node != event.node_id:
            return False

        # Check execution filter
        if subscription.filter_execution and subscription.filter_execution != event.execution_id:
            return False
@@ -410,6 +442,7 @@ class EventBus:
        self,
        event_type: EventType,
        stream_id: str | None = None,
        node_id: str | None = None,
        execution_id: str | None = None,
        timeout: float | None = None,
    ) -> AgentEvent | None:
@@ -419,6 +452,7 @@ class EventBus:
        Args:
            event_type: Type of event to wait for
            stream_id: Filter by stream
            node_id: Filter by node
            execution_id: Filter by execution
            timeout: Maximum time to wait (seconds)

@@ -438,6 +472,7 @@ class EventBus:
            event_types=[event_type],
            handler=handler,
            filter_stream=stream_id,
            filter_node=node_id,
            filter_execution=execution_id,
        )
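With the new filters, a consumer can scope a subscription to one node and one execution; a sketch (bus construction and the events module path are assumptions, using the subscribe signature from this diff):

from framework.events import AgentEvent, EventBus, EventType  # module path assumed


async def on_delta(event: AgentEvent) -> None:
    print(event.node_id, event.data.get("content", ""))


bus = EventBus()
sub_id = bus.subscribe(
    event_types=[EventType.LLM_TEXT_DELTA],
    handler=on_delta,
    filter_node="planner",        # only events from this node
    filter_execution="exec-123",  # only events from this execution
)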