Compare commits
16 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 9e252fc33e | |
| | bf57220c8f | |
| | cb1c07e60c | |
| | 2aa38ad9bb | |
| | c0b8980447 | |
| | 601a5d87e9 | |
| | 29e85a13c7 | |
| | 220beb5c64 | |
| | f56600f7af | |
| | 148f61ac3e | |
| | e28d989c92 | |
| | 4681e52f86 | |
| | 8351c808dc | |
| | c959cab9c2 | |
| | 9526570e0a | |
| | e87a40a7c3 | |
@@ -13,6 +13,7 @@ from pathlib import Path
 from typing import Any
 
 from framework.graph.edge import DEFAULT_MAX_TOKENS
+from framework.llm.codex_backend import CODEX_API_BASE, build_codex_litellm_kwargs
 
 # ---------------------------------------------------------------------------
 # Low-level config file access
@@ -125,7 +126,6 @@ def get_worker_api_key() -> str | None:
             return token
         except ImportError:
             pass
-
     api_key_env_var = worker_llm.get("api_key_env_var")
     if api_key_env_var:
         return os.environ.get(api_key_env_var)
@@ -141,7 +141,7 @@ def get_worker_api_base() -> str | None:
         return get_api_base()
 
     if worker_llm.get("use_codex_subscription"):
-        return "https://chatgpt.com/backend-api/codex"
+        return CODEX_API_BASE
     if worker_llm.get("use_kimi_code_subscription"):
         return "https://api.kimi.com/coding"
    if worker_llm.get("use_antigravity_subscription"):
@@ -169,25 +169,14 @@ def get_worker_llm_extra_kwargs() -> dict[str, Any]:
     if worker_llm.get("use_codex_subscription"):
         api_key = get_worker_api_key()
-        if api_key:
-            headers: dict[str, str] = {
-                "Authorization": f"Bearer {api_key}",
-                "User-Agent": "CodexBar",
-            }
         account_id = None
         try:
             from framework.runner.runner import get_codex_account_id
 
             account_id = get_codex_account_id()
-            if account_id:
-                headers["ChatGPT-Account-Id"] = account_id
         except ImportError:
             pass
-        return {
-            "extra_headers": headers,
-            "store": False,
-            "allowed_openai_params": ["store"],
-        }
+        return build_codex_litellm_kwargs(api_key, account_id=account_id)
     if worker_llm.get("provider") == "ollama":
         return {"num_ctx": worker_llm.get("num_ctx", 16384)}
     return {}
@@ -276,7 +265,6 @@ def get_api_key() -> str | None:
             return token
         except ImportError:
             pass
-
     # Standard env-var path (covers ZAI Code and all API-key providers)
     api_key_env_var = llm.get("api_key_env_var")
     if api_key_env_var:
@@ -382,7 +370,7 @@ def get_api_base() -> str | None:
     llm = get_hive_config().get("llm", {})
     if llm.get("use_codex_subscription"):
         # Codex subscription routes through the ChatGPT backend, not api.openai.com.
-        return "https://chatgpt.com/backend-api/codex"
+        return CODEX_API_BASE
     if llm.get("use_kimi_code_subscription"):
         # Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
         return "https://api.kimi.com/coding"
@@ -417,27 +405,15 @@ def get_llm_extra_kwargs() -> dict[str, Any]:
     if llm.get("use_codex_subscription"):
         api_key = get_api_key()
-        if api_key:
-            headers: dict[str, str] = {
-                "Authorization": f"Bearer {api_key}",
-                "User-Agent": "CodexBar",
-            }
         account_id = None
         try:
             from framework.runner.runner import get_codex_account_id
 
             account_id = get_codex_account_id()
-            if account_id:
-                headers["ChatGPT-Account-Id"] = account_id
         except ImportError:
             pass
-        return {
-            "extra_headers": headers,
-            "store": False,
-            "allowed_openai_params": ["store"],
-        }
+        return build_codex_litellm_kwargs(api_key, account_id=account_id)
     if llm.get("provider") == "ollama":
         # Pass num_ctx to Ollama so it doesn't silently truncate the ~9.5k Queen prompt.
         # Ollama's default num_ctx is only 2048. We set it to 16384 here so LiteLLM
         # passes it through as a provider-specific option.
         return {"num_ctx": llm.get("num_ctx", 16384)}
     return {}
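Both `get_worker_llm_extra_kwargs()` and `get_llm_extra_kwargs()` previously hand-built the same headers/`store`/`allowed_openai_params` dict; after this change they delegate to `build_codex_litellm_kwargs()`. A minimal sketch of the equivalence (token and account values are invented, not real credentials):

```python
from framework.llm.codex_backend import build_codex_litellm_kwargs

# Hypothetical token/account values, for illustration only.
kwargs = build_codex_litellm_kwargs("sk-example-token", account_id="acct-123")
assert kwargs == {
    "extra_headers": {
        "Authorization": "Bearer sk-example-token",
        "User-Agent": "CodexBar",
        "ChatGPT-Account-Id": "acct-123",
    },
    "store": False,
    "allowed_openai_params": ["store"],
}
```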
@@ -351,13 +351,15 @@ class NodeConversation:
     def system_prompt(self) -> str:
         return self._system_prompt
 
-    def update_system_prompt(self, new_prompt: str) -> None:
+    def update_system_prompt(self, new_prompt: str, output_keys: list[str] | None = None) -> None:
         """Update the system prompt.
 
         Used in continuous conversation mode at phase transitions to swap
         Layer 3 (focus) while preserving the conversation history.
         """
         self._system_prompt = new_prompt
+        if output_keys is not None:
+            self._output_keys = output_keys
         self._meta_persisted = False  # re-persist with new prompt
 
     def set_current_phase(self, phase_id: str) -> None:
@@ -771,7 +773,7 @@ class NodeConversation:
         delete_before = recent_messages[0].seq if recent_messages else self._next_seq
         await self._store.delete_parts_before(delete_before)
         await self._store.write_part(summary_msg.seq, summary_msg.to_storage_dict())
-        await self._store.write_cursor({"next_seq": self._next_seq})
+        await self._write_cursor_update({"next_seq": self._next_seq})
 
         self._messages = [summary_msg] + recent_messages
         self._last_api_input_tokens = None  # reset; next LLM call will recalibrate
@@ -975,7 +977,7 @@ class NodeConversation:
         # Write kept structural messages (they may have been modified)
         for msg in kept_structural:
             await self._store.write_part(msg.seq, msg.to_storage_dict())
-        await self._store.write_cursor({"next_seq": self._next_seq})
+        await self._write_cursor_update({"next_seq": self._next_seq})
 
         # Reassemble: reference + kept structural (in original order) + recent
         self._messages = [ref_msg] + kept_structural + recent_messages
@@ -1012,7 +1014,7 @@ class NodeConversation:
         """Remove all messages, keep system prompt, preserve ``_next_seq``."""
         if self._store:
             await self._store.delete_parts_before(self._next_seq)
-            await self._store.write_cursor({"next_seq": self._next_seq})
+            await self._write_cursor_update({"next_seq": self._next_seq})
         self._messages.clear()
         self._last_api_input_tokens = None
 
@@ -1047,6 +1049,14 @@ class NodeConversation:
 
     # --- Persistence internals ---------------------------------------------
 
+    async def _write_cursor_update(self, data: dict[str, Any]) -> None:
+        """Merge cursor updates instead of clobbering existing crash-recovery state."""
+        if self._store is None:
+            return
+        cursor = await self._store.read_cursor() or {}
+        cursor.update(data)
+        await self._store.write_cursor(cursor)
+
     async def _persist(self, message: Message) -> None:
         """Write-through a single message. No-op when store is None."""
         if self._store is None:
@@ -1054,7 +1064,7 @@ class NodeConversation:
         if not self._meta_persisted:
             await self._persist_meta()
         await self._store.write_part(message.seq, message.to_storage_dict())
-        await self._store.write_cursor({"next_seq": self._next_seq})
+        await self._write_cursor_update({"next_seq": self._next_seq})
 
     async def _persist_meta(self) -> None:
         """Lazily write conversation metadata to the store (called once)."""
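The point of `_write_cursor_update()` is read-modify-write semantics: a bare `write_cursor({"next_seq": ...})` drops any other keys already stored in the cursor. A minimal sketch of the difference, with a hypothetical cursor value (the `recovery_marker` key is invented for illustration):

```python
# Hypothetical stored cursor state.
cursor = {"next_seq": 41, "recovery_marker": "phase-2"}

# Clobbering write: replaces the whole dict and loses recovery_marker.
clobbered = {"next_seq": 42}

# Merging write (what _write_cursor_update does): read, update, write back.
merged = {**cursor, "next_seq": 42}
assert merged == {"next_seq": 42, "recovery_marker": "phase-2"}
```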
+2945 −447 (file diff suppressed because it is too large)
@@ -1480,7 +1480,22 @@ class GraphExecutor:
                 narrative=narrative,
                 accounts_prompt=_node_accounts,
             )
-            continuous_conversation.update_system_prompt(new_system)
+            continuous_conversation.update_system_prompt(
+                new_system,
+                output_keys=list(next_spec.output_keys or []),
+            )
+
+            # Stamp the next phase before inserting the transition
+            # marker so the marker itself is preserved with the
+            # phase it introduces during compaction/restore.
+            continuous_conversation.set_current_phase(next_spec.id)
+
+            transition_tool_names = set(cumulative_tool_names)
+            transition_tool_names.update(next_spec.tools or [])
+            if next_spec.output_keys:
+                transition_tool_names.add("set_output")
+            if next_spec.client_facing:
+                transition_tool_names.update({"ask_user", "ask_user_multiple"})
 
             # Insert transition marker into conversation
             data_dir = str(self._storage_path / "data") if self._storage_path else None
@@ -1488,7 +1503,7 @@ class GraphExecutor:
                 previous_node=node_spec,
                 next_node=next_spec,
                 memory=memory,
-                cumulative_tool_names=sorted(cumulative_tool_names),
+                cumulative_tool_names=sorted(transition_tool_names),
                 data_dir=data_dir,
                 adapt_content=_adapt_text,
             )
@@ -1497,9 +1512,6 @@ class GraphExecutor:
                 is_transition_marker=True,
             )
 
-            # Set current phase for phase-aware compaction
-            continuous_conversation.set_current_phase(next_spec.id)
-
             # Phase-boundary compaction (same flow as EventLoopNode._compact)
             if continuous_conversation.usage_ratio() > 0.5:
                 await continuous_conversation.prune_old_tool_results(
@@ -152,8 +152,6 @@ def compose_system_prompt(
     accounts_prompt: str | None = None,
     skills_catalog_prompt: str | None = None,
     protocols_prompt: str | None = None,
-    execution_preamble: str | None = None,
-    node_type_preamble: str | None = None,
 ) -> str:
     """Compose the multi-layer system prompt.
 
@@ -164,10 +162,6 @@ def compose_system_prompt(
         accounts_prompt: Connected accounts block (sits between identity and narrative).
         skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
         protocols_prompt: Default skill operational protocols section.
-        execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
-            (prepended before focus so the LLM knows its pipeline scope).
-        node_type_preamble: Node-type-specific preamble, e.g. GCU browser
-            best-practices prompt (prepended before focus).
 
     Returns:
         Composed system prompt with all layers present, plus current datetime.
@@ -194,15 +188,6 @@ def compose_system_prompt(
     if narrative:
         parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
 
-    # Execution scope preamble (worker nodes — tells the LLM it is one
-    # step in a multi-node pipeline and should not overreach)
-    if execution_preamble:
-        parts.append(f"\n{execution_preamble}")
-
-    # Node-type preamble (e.g. GCU browser best-practices)
-    if node_type_preamble:
-        parts.append(f"\n{node_type_preamble}")
-
     # Layer 3: Focus (current phase directive)
     if focus_prompt:
         parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
@@ -320,7 +305,8 @@ def build_transition_marker(
                 file_size = (data_path / filename).stat().st_size
                 val_str = (
                     f"[Saved to '{filename}' ({file_size:,} bytes). "
-                    f"Use load_data(filename='{filename}') to access.]"
+                    f"Use load_data(filename='{filename}') to access from session data. "
+                    "Do NOT open it as a workspace file or expect it in the current directory.]"
                 )
             except Exception:
                 val_str = val_str[:300] + "..."
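For a concrete sense of the new placeholder wording, here is what `val_str` evaluates to for a hypothetical `report.json` of 12,345 bytes (filename and size invented for illustration):

```python
filename, file_size = "report.json", 12345  # hypothetical values
val_str = (
    f"[Saved to '{filename}' ({file_size:,} bytes). "
    f"Use load_data(filename='{filename}') to access from session data. "
    "Do NOT open it as a workspace file or expect it in the current directory.]"
)
# -> [Saved to 'report.json' (12,345 bytes). Use load_data(filename='report.json')
#    to access from session data. Do NOT open it as a workspace file or expect
#    it in the current directory.]
```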
@@ -0,0 +1,255 @@
"""Codex adapter for Hive's LiteLLM provider.

Codex CLI is tool-first and event-structured: tool invocations and tool results
are emitted as explicit response items, not as plain-text workflow narration.
This adapter keeps the ChatGPT Codex backend aligned with Hive's normal
provider contract by normalizing Codex request shaping and response recovery at
the provider boundary.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from framework.llm.codex_backend import (
    build_codex_extra_headers,
    is_codex_api_base,
    merge_codex_allowed_openai_params,
    normalize_codex_api_base,
)
from framework.llm.provider import Tool

if TYPE_CHECKING:
    from collections.abc import Callable

    from framework.llm.litellm import LiteLLMProvider
    from framework.llm.stream_events import StreamEvent

logger = logging.getLogger(__name__)

_CODEX_CRITICAL_TOOL_NAMES = frozenset(
    {
        "ask_user",
        "ask_user_multiple",
        "set_output",
        "escalate",
        "save_agent_draft",
        "confirm_and_build",
        "initialize_and_build_agent",
    }
)
_CODEX_SYSTEM_CHUNK_CHARS = 3500
_CODEX_SYSTEM_PREAMBLE = """# Codex Execution Contract
Follow the system sections below in order.
- Obey every CRITICAL, MUST, NEVER, and ONLY instruction exactly.
- When tools are available, emit structured tool calls instead of replying with plain-text promises.
- Do not skip required workflow boundaries or approval gates.
"""


class CodexResponsesAdapter:
    """Normalize the ChatGPT Codex backend to Hive's standard provider semantics."""

    def __init__(self, provider: LiteLLMProvider):
        self._provider = provider

    @property
    def enabled(self) -> bool:
        """Return True when the provider targets the ChatGPT Codex backend."""
        return is_codex_api_base(self._provider.api_base)

    def chunk_system_prompt(self, system: str) -> list[str]:
        """Break large system prompts into smaller Codex-friendly chunks."""
        normalized = system.replace("\r\n", "\n").strip()
        if not normalized:
            return []

        sections: list[str] = []
        current: list[str] = []
        for line in normalized.splitlines():
            if line.startswith("#") and current:
                sections.append("\n".join(current).strip())
                current = [line]
            else:
                current.append(line)
        if current:
            sections.append("\n".join(current).strip())

        chunks: list[str] = []
        for section in sections:
            if len(section) <= _CODEX_SYSTEM_CHUNK_CHARS:
                chunks.append(section)
                continue

            paragraphs = [
                paragraph.strip() for paragraph in section.split("\n\n") if paragraph.strip()
            ]
            current_chunk = ""
            for paragraph in paragraphs:
                candidate = paragraph if not current_chunk else f"{current_chunk}\n\n{paragraph}"
                if current_chunk and len(candidate) > _CODEX_SYSTEM_CHUNK_CHARS:
                    chunks.append(current_chunk)
                    current_chunk = paragraph
                else:
                    current_chunk = candidate
            if current_chunk:
                chunks.append(current_chunk)

        return chunks or [normalized]

    def build_system_messages(
        self,
        system: str,
        *,
        json_mode: bool,
    ) -> list[dict[str, Any]]:
        """Build Codex system messages in the tool-first format Codex CLI expects."""
        system_messages: list[dict[str, Any]] = []
        if system:
            chunks = self.chunk_system_prompt(system)
            if len(chunks) > 1 or len(chunks[0]) > _CODEX_SYSTEM_CHUNK_CHARS:
                system_messages.append({"role": "system", "content": _CODEX_SYSTEM_PREAMBLE})
            for chunk in chunks:
                system_messages.append({"role": "system", "content": chunk})
        else:
            system_messages.append({"role": "system", "content": "You are a helpful assistant."})

        if json_mode:
            system_messages.append(
                {"role": "system", "content": "Please respond with a valid JSON object."}
            )
        return system_messages

    def derive_tool_choice(
        self,
        messages: list[dict[str, Any]],
        tools: list[Tool] | None,
    ) -> str | dict[str, Any] | None:
        """Force structured tool use when Codex sees critical framework tools."""
        if not tools:
            return None

        tool_names = {tool.name for tool in tools}
        if not (tool_names & _CODEX_CRITICAL_TOOL_NAMES):
            return None

        last_role = next(
            (m.get("role") for m in reversed(messages) if m.get("role") != "system"),
            None,
        )
        if last_role == "assistant":
            return None
        return "required"

    def harden_request_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
        """Strip unsupported params and inject the Codex backend headers."""
        cleaned = dict(kwargs)
        cleaned["api_base"] = normalize_codex_api_base(
            cleaned.get("api_base") or self._provider.api_base
        )
        cleaned["store"] = False
        cleaned["allowed_openai_params"] = merge_codex_allowed_openai_params(
            cleaned.get("allowed_openai_params")
        )
        cleaned.pop("max_tokens", None)
        cleaned.pop("stream_options", None)

        extra_headers = dict(cleaned.get("extra_headers") or {})
        if "ChatGPT-Account-Id" not in extra_headers:
            try:
                from framework.runner.runner import get_codex_account_id

                account_id = get_codex_account_id()
                if account_id:
                    extra_headers["ChatGPT-Account-Id"] = account_id
            except Exception:
                logger.debug("Could not populate ChatGPT-Account-Id", exc_info=True)

        cleaned["extra_headers"] = build_codex_extra_headers(
            self._provider.api_key,
            account_id=extra_headers.get("ChatGPT-Account-Id"),
            extra_headers=extra_headers,
        )
        return cleaned

    async def recover_empty_stream(
        self,
        kwargs: dict[str, Any],
        *,
        last_role: str | None,
        acompletion: Callable[..., Any],
    ) -> list[StreamEvent] | None:
        """Try a non-stream completion when Codex returns an empty stream."""
        fallback_kwargs = dict(kwargs)
        fallback_kwargs.pop("stream", None)
        fallback_kwargs.pop("stream_options", None)
        fallback_kwargs = self._provider._sanitize_request_kwargs(fallback_kwargs, stream=False)

        try:
            response = await acompletion(**fallback_kwargs)
        except Exception as exc:
            logger.debug(
                "[stream-recover] %s non-stream fallback after empty %s stream failed: %s",
                self._provider.model,
                last_role,
                exc,
            )
            return None

        events = self._provider._build_stream_events_from_nonstream_response(response)
        if events:
            logger.info(
                "[stream-recover] %s recovered empty %s stream via non-stream completion",
                self._provider.model,
                last_role,
            )
            return events
        return None

    def merge_tool_call_chunk(
        self,
        tool_calls_acc: dict[int, dict[str, str]],
        tc: Any,
        last_tool_idx: int,
    ) -> int:
        """Merge a streamed tool-call chunk, compensating for broken bridge indexes."""
        idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
        tc_id = getattr(tc, "id", None) or ""
        func = getattr(tc, "function", None)
        func_name = getattr(func, "name", "") if func is not None else ""
        func_args = getattr(func, "arguments", "") if func is not None else ""

        if tc_id:
            existing_idx = next(
                (key for key, value in tool_calls_acc.items() if value["id"] == tc_id),
                None,
            )
            if existing_idx is not None:
                idx = existing_idx
            elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in ("", tc_id):
                idx = max(tool_calls_acc.keys(), default=-1) + 1
                last_tool_idx = idx
        elif func_name:
            if (
                last_tool_idx in tool_calls_acc
                and tool_calls_acc[last_tool_idx]["name"]
                and tool_calls_acc[last_tool_idx]["name"] != func_name
                and tool_calls_acc[last_tool_idx]["arguments"]
            ):
                idx = max(tool_calls_acc.keys(), default=-1) + 1
                last_tool_idx = idx
            else:
                idx = last_tool_idx if tool_calls_acc else idx
        else:
            idx = last_tool_idx if tool_calls_acc else idx

        if idx not in tool_calls_acc:
            tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
        if tc_id:
            tool_calls_acc[idx]["id"] = tc_id
        if func_name:
            tool_calls_acc[idx]["name"] = func_name
        if func_args:
            tool_calls_acc[idx]["arguments"] += func_args
        return idx
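`chunk_system_prompt()` first splits on lines that start with `#` (markdown-style section headings) and only falls back to paragraph packing when a section exceeds `_CODEX_SYSTEM_CHUNK_CHARS`. A small sketch of the heading split, using a toy prompt invented for illustration:

```python
adapter = CodexResponsesAdapter(provider)  # provider: any LiteLLMProvider instance

prompt = "# Identity\nYou are the Queen node.\n# Rules\nNever skip approval gates."
chunks = adapter.chunk_system_prompt(prompt)
# Both sections are far below the 3,500-char limit, so each heading
# becomes its own chunk:
assert chunks == [
    "# Identity\nYou are the Queen node.",
    "# Rules\nNever skip approval gates.",
]
```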
@@ -0,0 +1,85 @@
"""Shared helpers for Codex's ChatGPT-backed transport.

Codex CLI talks to the ChatGPT Codex backend, which is not the standard
platform OpenAI API. Hive keeps its normal provider contract by centralizing
the transport-specific headers and request kwargs here.
"""

from __future__ import annotations

from typing import Any
from urllib.parse import urlparse, urlunparse

CODEX_API_BASE = "https://chatgpt.com/backend-api/codex"
CODEX_USER_AGENT = "CodexBar"
CODEX_ALLOWED_OPENAI_PARAMS = ("store",)
_CODEX_HOST = "chatgpt.com"
_CODEX_PATH = "/backend-api/codex"


def is_codex_api_base(api_base: str | None) -> bool:
    """Return True when *api_base* targets the ChatGPT Codex backend."""
    if not api_base:
        return False
    parsed = urlparse(api_base)
    path = parsed.path.rstrip("/")
    return (
        parsed.scheme in {"http", "https"}
        and parsed.hostname == _CODEX_HOST
        and (path == _CODEX_PATH or path == f"{_CODEX_PATH}/responses")
    )


def normalize_codex_api_base(api_base: str | None) -> str | None:
    """Normalize ChatGPT Codex backend URLs to the stable base endpoint."""
    if not api_base:
        return api_base
    parsed = urlparse(api_base)
    path = parsed.path.rstrip("/")
    if not is_codex_api_base(api_base):
        return api_base.rstrip("/")
    if path.endswith("/responses"):
        path = path[: -len("/responses")]
    normalized = parsed._replace(path=path, params="", query="", fragment="")
    return urlunparse(normalized).rstrip("/")


def merge_codex_allowed_openai_params(params: list[str] | tuple[str, ...] | None) -> list[str]:
    """Ensure Codex-required pass-through params are always present."""
    allowed = set(params or [])
    allowed.update(CODEX_ALLOWED_OPENAI_PARAMS)
    return sorted(allowed)


def build_codex_extra_headers(
    api_key: str | None,
    *,
    account_id: str | None = None,
    extra_headers: dict[str, str] | None = None,
) -> dict[str, str]:
    """Build headers for the ChatGPT Codex backend."""
    headers = dict(extra_headers or {})
    if api_key:
        headers.setdefault("Authorization", f"Bearer {api_key}")
    headers.setdefault("User-Agent", CODEX_USER_AGENT)
    if account_id:
        headers.setdefault("ChatGPT-Account-Id", account_id)
    return headers


def build_codex_litellm_kwargs(
    api_key: str | None,
    *,
    account_id: str | None = None,
    extra_headers: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Return the LiteLLM kwargs required by the ChatGPT Codex backend."""
    return {
        "extra_headers": build_codex_extra_headers(
            api_key,
            account_id=account_id,
            extra_headers=extra_headers,
        ),
        "store": False,
        "allowed_openai_params": list(CODEX_ALLOWED_OPENAI_PARAMS),
    }
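A quick sketch of how the two URL helpers behave on representative inputs (the Codex URLs are the real constants above; the non-Codex base is invented for illustration):

```python
from framework.llm.codex_backend import is_codex_api_base, normalize_codex_api_base

assert is_codex_api_base("https://chatgpt.com/backend-api/codex") is True
assert is_codex_api_base("https://chatgpt.com/backend-api/codex/responses") is True
assert is_codex_api_base("https://api.openai.com/v1") is False

# The /responses suffix (and any query/fragment) is stripped back to the base:
assert (
    normalize_codex_api_base("https://chatgpt.com/backend-api/codex/responses")
    == "https://chatgpt.com/backend-api/codex"
)
# Non-Codex bases pass through, minus a trailing slash:
assert normalize_codex_api_base("https://api.example.com/v1/") == "https://api.example.com/v1"
```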
+404 −185
@@ -28,6 +28,8 @@ except ImportError:
     RateLimitError = Exception  # type: ignore[assignment, misc]
 
 from framework.config import HIVE_LLM_ENDPOINT as HIVE_API_BASE
+from framework.llm.codex_adapter import CodexResponsesAdapter
+from framework.llm.codex_backend import normalize_codex_api_base
 from framework.llm.provider import LLMProvider, LLMResponse, Tool
 from framework.llm.stream_events import StreamEvent
 
@@ -534,7 +536,9 @@ class LiteLLMProvider(LLMProvider):
             api_base = api_base.rstrip("/")[:-3]
         self.model = model
         self.api_key = api_key
-        self.api_base = api_base or self._default_api_base_for_model(_original_model)
+        self.api_base = normalize_codex_api_base(
+            api_base or self._default_api_base_for_model(_original_model)
+        )
         self.extra_kwargs = kwargs
         # Detect Claude Code OAuth subscription by checking the api_key prefix.
         self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
@@ -542,24 +546,28 @@ class LiteLLMProvider(LLMProvider):
             # Anthropic requires a specific User-Agent for OAuth requests.
             eh = self.extra_kwargs.setdefault("extra_headers", {})
             eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
-        # The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
-        # several standard OpenAI params: max_output_tokens, stream_options.
-        self._codex_backend = bool(
-            self.api_base and "chatgpt.com/backend-api/codex" in self.api_base
-        )
         # Antigravity routes through a local OpenAI-compatible proxy — no patches needed.
         self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base)
+        self._codex_adapter = CodexResponsesAdapter(self)
+        # Backward-compatible alias for existing tests/callers.
+        self._codex_backend = self._codex_adapter.enabled
 
         if litellm is None:
             raise ImportError(
                 "LiteLLM is not installed. Please install it with: uv pip install litellm"
             )
 
-        # Note: The Codex ChatGPT backend is a Responses API endpoint at
+        # The Codex ChatGPT backend is a Responses API endpoint at
         # chatgpt.com/backend-api/codex/responses. LiteLLM's model registry
-        # correctly marks codex models with mode="responses", so we do NOT
-        # override the mode. The responses_api_bridge in litellm handles
-        # converting Chat Completions requests to Responses API format.
+        # marks legacy codex models (gpt-5.3-codex) with mode="responses",
+        # but newer models like gpt-5.4 default to mode="chat". Force
+        # mode="responses" so litellm routes through the responses_api_bridge.
+        if self._codex_backend and litellm is not None:
+            _strip = self.model.removeprefix("openai/")
+            _entry = litellm.model_cost.get(_strip, {})
+            if _entry.get("mode") != "responses":
+                litellm.model_cost.setdefault(_strip, {})
+                litellm.model_cost[_strip]["mode"] = "responses"
     @staticmethod
     def _default_api_base_for_model(model: str) -> str | None:
@@ -575,6 +583,134 @@ class LiteLLMProvider(LLMProvider):
             return HIVE_API_BASE
         return None
 
+    @staticmethod
+    def _normalize_codex_api_base(api_base: str | None) -> str | None:
+        """Normalize ChatGPT Codex backend URLs to the stable base endpoint."""
+        return normalize_codex_api_base(api_base)
+
+    def _chunk_codex_system_prompt(self, system: str) -> list[str]:
+        """Break large system prompts into smaller Codex-friendly chunks."""
+        return self._codex_adapter.chunk_system_prompt(system)
+
+    def _build_request_messages(
+        self,
+        messages: list[dict[str, Any]],
+        system: str,
+        *,
+        json_mode: bool,
+    ) -> list[dict[str, Any]]:
+        """Build request messages, including Codex-specific prompt chunking."""
+        full_messages: list[dict[str, Any]] = []
+        if self._claude_code_oauth:
+            billing = _claude_code_billing_header(messages)
+            full_messages.append({"role": "system", "content": billing})
+
+        system_messages: list[dict[str, Any]] = []
+        if system:
+            if self._codex_backend:
+                system_messages.extend(
+                    self._codex_adapter.build_system_messages(system, json_mode=json_mode)
+                )
+            else:
+                sys_msg: dict[str, Any] = {"role": "system", "content": system}
+                if _model_supports_cache_control(self.model):
+                    sys_msg["cache_control"] = {"type": "ephemeral"}
+                system_messages.append(sys_msg)
+        elif self._codex_backend:
+            system_messages.extend(
+                self._codex_adapter.build_system_messages("", json_mode=json_mode)
+            )
+
+        if json_mode and not self._codex_backend:
+            json_instruction = "Please respond with a valid JSON object."
+            if system_messages:
+                system_messages[0] = {
+                    **system_messages[0],
+                    "content": f"{system_messages[0]['content']}\n\n{json_instruction}",
+                }
+            else:
+                system_messages.append({"role": "system", "content": json_instruction})
+
+        full_messages.extend(system_messages)
+        full_messages.extend(messages)
+
+        return [
+            m
+            for m in full_messages
+            if not (
+                m.get("role") == "assistant" and not m.get("content") and not m.get("tool_calls")
+            )
+        ]
+
+    def _derive_codex_tool_choice(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list[Tool] | None,
+    ) -> str | dict[str, Any] | None:
+        """Force tool use for Codex when critical framework tools are available."""
+        if not self._codex_backend:
+            return None
+        return self._codex_adapter.derive_tool_choice(messages, tools)
+
+    def _sanitize_request_kwargs(
+        self,
+        kwargs: dict[str, Any],
+        *,
+        stream: bool,
+    ) -> dict[str, Any]:
+        """Normalize provider kwargs, with extra hardening for Codex."""
+        cleaned = dict(kwargs)
+        if cleaned.get("metadata") is None:
+            cleaned.pop("metadata", None)
+
+        if self._codex_backend:
+            cleaned = self._codex_adapter.harden_request_kwargs(cleaned)
+
+        if stream:
+            cleaned["stream"] = True
+        return cleaned
+
+    def _build_completion_kwargs(
+        self,
+        messages: list[dict[str, Any]],
+        system: str,
+        *,
+        tools: list[Tool] | None,
+        max_tokens: int,
+        response_format: dict[str, Any] | None,
+        json_mode: bool,
+        stream: bool,
+    ) -> dict[str, Any]:
+        """Build request kwargs for completion/stream calls."""
+        full_messages = self._build_request_messages(messages, system, json_mode=json_mode)
+        kwargs: dict[str, Any] = {
+            "model": self.model,
+            "messages": full_messages,
+            **self.extra_kwargs,
+        }
+        if not stream:
+            kwargs["max_tokens"] = max_tokens
+        else:
+            kwargs["max_tokens"] = max_tokens
+            if not self._is_anthropic_model():
+                kwargs["stream_options"] = {"include_usage": True}
+
+        if self.api_key:
+            kwargs["api_key"] = self.api_key
+        if self.api_base:
+            kwargs["api_base"] = self.api_base
+        if tools:
+            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
+            tool_choice = self._derive_codex_tool_choice(full_messages, tools)
+            if tool_choice is not None:
+                kwargs["tool_choice"] = tool_choice
+            elif _is_ollama_model(self.model):
+                kwargs.setdefault("tool_choice", "auto")
+        if response_format:
+            kwargs["response_format"] = response_format
+
+        return self._sanitize_request_kwargs(kwargs, stream=stream)
+
     def _completion_with_rate_limit_retry(
         self, max_retries: int | None = None, **kwargs: Any
     ) -> Any:
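Under this builder, a Codex request's message list is assembled in a fixed order: optional Claude-Code billing header, the system chunk(s) (preceded by the execution-contract preamble when the prompt is large), the JSON-mode instruction, then the conversation, with ghost empty assistant turns filtered out. A hypothetical shape for a chunked Codex prompt (contents abbreviated, boundaries depend on the actual prompt):

```python
# Illustrative only — not literal output.
full_messages = [
    {"role": "system", "content": "# Codex Execution Contract\n..."},  # preamble
    {"role": "system", "content": "# Identity\n..."},                  # chunk 1
    {"role": "system", "content": "# Rules\n..."},                     # chunk 2
    {"role": "system", "content": "Please respond with a valid JSON object."},
    {"role": "user", "content": "Build the report."},
]
```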
@@ -713,46 +849,15 @@ class LiteLLMProvider(LLMProvider):
             )
         )
 
-        # Prepare messages with system prompt
-        full_messages = []
-        if system:
-            full_messages.append({"role": "system", "content": system})
-        full_messages.extend(messages)
-
-        # Add JSON mode via prompt engineering (works across all providers)
-        if json_mode:
-            json_instruction = "\n\nPlease respond with a valid JSON object."
-            # Append to system message if present, otherwise add as system message
-            if full_messages and full_messages[0]["role"] == "system":
-                full_messages[0]["content"] += json_instruction
-            else:
-                full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
-
-        # Build kwargs
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": full_messages,
-            "max_tokens": max_tokens,
-            **self.extra_kwargs,
-        }
-
-        if self.api_key:
-            kwargs["api_key"] = self.api_key
-        if self.api_base:
-            kwargs["api_base"] = self.api_base
-
-        # Add tools if provided
-        if tools:
-            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
-            if _is_ollama_model(self.model):
-                # Ollama requires explicit tool_choice=auto for function calling
-                # so future readers don't have to guess.
-                kwargs.setdefault("tool_choice", "auto")
-
-        # Add response_format for structured output
-        # LiteLLM passes this through to the underlying provider
-        if response_format:
-            kwargs["response_format"] = response_format
+        kwargs = self._build_completion_kwargs(
+            messages,
+            system,
+            tools=tools,
+            max_tokens=max_tokens,
+            response_format=response_format,
+            json_mode=json_mode,
+            stream=False,
+        )
 
         # Make the call
         response = self._completion_with_rate_limit_retry(max_retries=max_retries, **kwargs)
@@ -913,44 +1018,15 @@ class LiteLLMProvider(LLMProvider):
                 json_mode=json_mode,
             )
             return await self._collect_stream_to_response(stream_iter)
 
-        full_messages: list[dict[str, Any]] = []
-        if self._claude_code_oauth:
-            billing = _claude_code_billing_header(messages)
-            full_messages.append({"role": "system", "content": billing})
-        if system:
-            sys_msg: dict[str, Any] = {"role": "system", "content": system}
-            if _model_supports_cache_control(self.model):
-                sys_msg["cache_control"] = {"type": "ephemeral"}
-            full_messages.append(sys_msg)
-        full_messages.extend(messages)
-
-        if json_mode:
-            json_instruction = "\n\nPlease respond with a valid JSON object."
-            if full_messages and full_messages[0]["role"] == "system":
-                full_messages[0]["content"] += json_instruction
-            else:
-                full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": full_messages,
-            "max_tokens": max_tokens,
-            **self.extra_kwargs,
-        }
-
-        if self.api_key:
-            kwargs["api_key"] = self.api_key
-        if self.api_base:
-            kwargs["api_base"] = self.api_base
-        if tools:
-            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
-            if _is_ollama_model(self.model):
-                # Ollama requires explicit tool_choice=auto for function calling
-                # so future readers don't have to guess.
-                kwargs.setdefault("tool_choice", "auto")
-        if response_format:
-            kwargs["response_format"] = response_format
+        kwargs = self._build_completion_kwargs(
+            messages,
+            system,
+            tools=tools,
+            max_tokens=max_tokens,
+            response_format=response_format,
+            json_mode=json_mode,
+            stream=False,
+        )
 
         response = await self._acompletion_with_rate_limit_retry(max_retries=max_retries, **kwargs)
@@ -1200,17 +1276,92 @@ class LiteLLMProvider(LLMProvider):
             return parsed
         return None
 
+    @staticmethod
+    def _normalize_pythonish_tool_arguments(raw_arguments: str) -> str:
+        """Convert common JSON-like literals into a form ast.literal_eval can parse."""
+        replacements = {
+            "true": "True",
+            "false": "False",
+            "null": "None",
+        }
+        out: list[str] = []
+        token: list[str] = []
+        in_string = False
+        string_quote = ""
+        escaped = False
+
+        def flush_token() -> None:
+            if not token:
+                return
+            word = "".join(token)
+            out.append(replacements.get(word, word))
+            token.clear()
+
+        for char in raw_arguments:
+            if in_string:
+                out.append(char)
+                if escaped:
+                    escaped = False
+                elif char == "\\":
+                    escaped = True
+                elif char == string_quote:
+                    in_string = False
+                continue
+
+            if char in {'"', "'"}:
+                flush_token()
+                in_string = True
+                string_quote = char
+                out.append(char)
+                continue
+
+            if char.isalpha():
+                token.append(char)
+                continue
+
+            flush_token()
+            out.append(char)
+
+        flush_token()
+        return "".join(out)
+
+    @staticmethod
+    def _strip_tool_argument_fence(raw_arguments: str) -> str:
+        """Remove surrounding fenced-code markers from streamed tool arguments."""
+        stripped = raw_arguments.strip()
+        if not stripped.startswith("```") or not stripped.endswith("```"):
+            return stripped
+
+        lines = stripped.splitlines()
+        if len(lines) >= 2:
+            return "\n".join(lines[1:-1]).strip()
+        return stripped.strip("`").strip()
+
+    def _parse_pythonish_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None:
+        """Parse single-quoted / trailing-comma argument payloads safely."""
+        stripped = self._strip_tool_argument_fence(raw_arguments)
+        if not stripped or stripped[0] != "{":
+            return None
+        candidate = self._close_truncated_json_fragment(stripped)
+        candidate = self._normalize_pythonish_tool_arguments(candidate)
+        try:
+            parsed = ast.literal_eval(candidate)
+        except (SyntaxError, ValueError):
+            return None
+        return parsed if isinstance(parsed, dict) else None
+
     def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]:
         """Parse streamed tool arguments, repairing truncation when possible."""
+        stripped = self._strip_tool_argument_fence(raw_arguments)
         try:
-            parsed = json.loads(raw_arguments) if raw_arguments else {}
+            parsed = json.loads(stripped) if stripped else {}
         except json.JSONDecodeError:
             parsed = None
 
         if isinstance(parsed, dict):
             return parsed
 
-        repaired = self._repair_truncated_tool_arguments(raw_arguments)
+        repaired = self._repair_truncated_tool_arguments(stripped)
         if repaired is not None:
             logger.warning(
                 "[tool-args] Recovered truncated arguments for %s on %s",
@@ -1219,6 +1370,15 @@ class LiteLLMProvider(LLMProvider):
             )
             return repaired
 
+        pythonish = self._parse_pythonish_tool_arguments(stripped)
+        if pythonish is not None:
+            logger.warning(
+                "[tool-args] Recovered malformed arguments for %s on %s",
+                tool_name,
+                self.model,
+            )
+            return pythonish
+
         raise ValueError(
             f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)."
         )
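To see what `_normalize_pythonish_tool_arguments` buys, consider a payload a model might stream with Python-style literals (the payload below is invented): bare `true`/`false`/`null` outside of strings are rewritten so `ast.literal_eval` can parse the single-quoted dict, while string contents are left untouched.

```python
import ast

raw = "{'query': 'status is null', 'limit': 5, 'dry_run': true, 'cursor': null}"
# After normalization: 'status is null' is preserved verbatim (inside a string),
# but the bare literals become Python equivalents.
normalized = "{'query': 'status is null', 'limit': 5, 'dry_run': True, 'cursor': None}"
assert ast.literal_eval(normalized) == {
    "query": "status is null",
    "limit": 5,
    "dry_run": True,
    "cursor": None,
}
```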
@@ -1546,6 +1706,139 @@ class LiteLLMProvider(LLMProvider):
             model=response.model,
         )
 
+    def _build_stream_events_from_nonstream_response(
+        self,
+        response: Any,
+    ) -> list[StreamEvent]:
+        """Convert a non-stream completion response into stream events."""
+        from framework.llm.stream_events import (
+            FinishEvent,
+            TextDeltaEvent,
+            TextEndEvent,
+            ToolCallEvent,
+        )
+
+        choices = getattr(response, "choices", None) or []
+        if not choices:
+            output_text = getattr(response, "output_text", "") or ""
+            if not output_text:
+                return []
+            usage = getattr(response, "usage", None)
+            return [
+                TextDeltaEvent(content=output_text, snapshot=output_text),
+                TextEndEvent(full_text=output_text),
+                FinishEvent(
+                    stop_reason="stop",
+                    input_tokens=getattr(usage, "prompt_tokens", 0) or 0 if usage else 0,
+                    output_tokens=getattr(usage, "completion_tokens", 0) or 0 if usage else 0,
+                    model=getattr(response, "model", None) or self.model,
+                ),
+            ]
+
+        choice = choices[0]
+        message = getattr(choice, "message", None)
+        content = self._extract_message_text(message)
+        tool_calls = getattr(message, "tool_calls", None) or []
+
+        events: list[StreamEvent] = []
+        for tc in tool_calls:
+            parsed_args = self._coerce_tool_input(
+                tc.function.arguments if tc.function else {},
+                tc.function.name if tc.function else "",
+            )
+            events.append(
+                ToolCallEvent(
+                    tool_use_id=getattr(tc, "id", ""),
+                    tool_name=tc.function.name if tc.function else "",
+                    tool_input=parsed_args,
+                )
+            )
+
+        if content:
+            events.append(TextDeltaEvent(content=content, snapshot=content))
+            events.append(TextEndEvent(full_text=content))
+
+        usage = getattr(response, "usage", None)
+        input_tokens = getattr(usage, "prompt_tokens", 0) or 0 if usage else 0
+        output_tokens = getattr(usage, "completion_tokens", 0) or 0 if usage else 0
+        cached_tokens = 0
+        if usage:
+            details = getattr(usage, "prompt_tokens_details", None)
+            cached_tokens = (
+                getattr(details, "cached_tokens", 0) or 0
+                if details is not None
+                else getattr(usage, "cache_read_input_tokens", 0) or 0
+            )
+
+        events.append(
+            FinishEvent(
+                stop_reason=getattr(choice, "finish_reason", None)
+                or ("tool_calls" if tool_calls else "stop"),
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                cached_tokens=cached_tokens,
+                model=getattr(response, "model", None) or self.model,
+            )
+        )
+        return events
+
+    @staticmethod
+    def _extract_message_text(message: Any) -> str:
+        """Extract text from a provider message object across response shapes."""
+        if message is None:
+            return ""
+        content = getattr(message, "content", "")
+        if isinstance(content, str):
+            return content
+        if isinstance(content, list):
+            parts: list[str] = []
+            for block in content:
+                if isinstance(block, str):
+                    parts.append(block)
+                elif isinstance(block, dict):
+                    text = block.get("text") or block.get("content") or ""
+                    if isinstance(text, str):
+                        parts.append(text)
+                else:
+                    text = getattr(block, "text", "") or getattr(block, "content", "")
+                    if isinstance(text, str):
+                        parts.append(text)
+            return "".join(parts)
+        return str(content or "")
+
+    def _coerce_tool_input(self, raw_arguments: Any, tool_name: str) -> dict[str, Any]:
+        """Normalize raw tool-call arguments from either string or object forms."""
+        if isinstance(raw_arguments, dict):
+            return raw_arguments
+        if raw_arguments in (None, ""):
+            return {}
+        return self._parse_tool_call_arguments(str(raw_arguments), tool_name)
+
+    async def _recover_empty_codex_stream(
+        self,
+        kwargs: dict[str, Any],
+        last_role: str | None,
+    ) -> list[StreamEvent] | None:
+        """Try a non-stream completion when Codex returns an empty stream."""
+        if not self._codex_backend:
+            return None
+        return await self._codex_adapter.recover_empty_stream(
+            kwargs,
+            last_role=last_role,
+            acompletion=litellm.acompletion,  # type: ignore[union-attr]
+        )
+
+    def _merge_tool_call_chunk(
+        self,
+        tool_calls_acc: dict[int, dict[str, str]],
+        tc: Any,
+        last_tool_idx: int,
+    ) -> int:
+        """Merge a streamed tool-call chunk, compensating for broken Codex indexes."""
+        return self._codex_adapter.merge_tool_call_chunk(tool_calls_acc, tc, last_tool_idx)
+
     async def stream(
         self,
         messages: list[dict[str, Any]],
@@ -1597,69 +1890,16 @@ class LiteLLMProvider(LLMProvider):
                     yield event
                 return
 
-        full_messages: list[dict[str, Any]] = []
-        if self._claude_code_oauth:
-            billing = _claude_code_billing_header(messages)
-            full_messages.append({"role": "system", "content": billing})
-        if system:
-            sys_msg: dict[str, Any] = {"role": "system", "content": system}
-            if _model_supports_cache_control(self.model):
-                sys_msg["cache_control"] = {"type": "ephemeral"}
-            full_messages.append(sys_msg)
-        full_messages.extend(messages)
-
-        # Codex Responses API requires an `instructions` field (system prompt).
-        # Inject a minimal one when callers don't provide a system message.
-        if self._codex_backend and not any(m["role"] == "system" for m in full_messages):
-            full_messages.insert(0, {"role": "system", "content": "You are a helpful assistant."})
-
-        # Add JSON mode via prompt engineering (works across all providers)
-        if json_mode:
-            json_instruction = "\n\nPlease respond with a valid JSON object."
-            if full_messages and full_messages[0]["role"] == "system":
-                full_messages[0]["content"] += json_instruction
-            else:
-                full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
-
-        # Remove ghost empty assistant messages (content="" and no tool_calls).
-        # These arise when a model returns an empty stream after a tool result
-        # (an "expected" no-op turn). Keeping them in history confuses some
-        # models (notably Codex/gpt-5.3) and causes cascading empty streams.
-        full_messages = [
-            m
-            for m in full_messages
-            if not (
-                m.get("role") == "assistant" and not m.get("content") and not m.get("tool_calls")
-            )
-        ]
-
-        kwargs: dict[str, Any] = {
-            "model": self.model,
-            "messages": full_messages,
-            "max_tokens": max_tokens,
-            "stream": True,
-            **self.extra_kwargs,
-        }
-        # stream_options is OpenAI-specific; Anthropic rejects it with 400.
-        # Only include it for providers that support it.
-        if not self._is_anthropic_model():
-            kwargs["stream_options"] = {"include_usage": True}
-        if self.api_key:
-            kwargs["api_key"] = self.api_key
-        if self.api_base:
-            kwargs["api_base"] = self.api_base
-        if tools:
-            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
-            if _is_ollama_model(self.model):
-                # Ollama requires explicit tool_choice=auto for function calling
-                # so future readers don't have to guess.
-                kwargs.setdefault("tool_choice", "auto")
-        if response_format:
-            kwargs["response_format"] = response_format
-        # The Codex ChatGPT backend (Responses API) rejects several params.
-        if self._codex_backend:
-            kwargs.pop("max_tokens", None)
-            kwargs.pop("stream_options", None)
+        kwargs = self._build_completion_kwargs(
+            messages,
+            system,
+            tools=tools,
+            max_tokens=max_tokens,
+            response_format=response_format,
+            json_mode=json_mode,
+            stream=True,
+        )
+        full_messages = kwargs["messages"]
 
         for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
             # Post-stream events (ToolCall, TextEnd, Finish) are buffered
@@ -1717,43 +1957,17 @@ class LiteLLMProvider(LLMProvider):
                     # argument deltas that arrive with id=None.
                     if delta and delta.tool_calls:
                         for tc in delta.tool_calls:
-                            idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
-
-                            if tc.id:
-                                # New tool call announced (or done event re-sent).
-                                # Check if this id already has a slot.
-                                existing_idx = next(
-                                    (k for k, v in tool_calls_acc.items() if v["id"] == tc.id),
-                                    None,
-                                )
-                                if existing_idx is not None:
-                                    idx = existing_idx
-                                elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in (
-                                    "",
-                                    tc.id,
-                                ):
-                                    # Slot taken by a different call — assign new index
-                                    idx = max(tool_calls_acc.keys()) + 1
-                                    _last_tool_idx = idx
-                            else:
-                                # Argument delta with no id — route to last opened slot
-                                idx = _last_tool_idx
-
-                            if idx not in tool_calls_acc:
-                                tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
-                            if tc.id:
-                                tool_calls_acc[idx]["id"] = tc.id
-                            if tc.function:
-                                if tc.function.name:
-                                    tool_calls_acc[idx]["name"] = tc.function.name
-                                if tc.function.arguments:
-                                    tool_calls_acc[idx]["arguments"] += tc.function.arguments
+                            _last_tool_idx = self._merge_tool_call_chunk(
+                                tool_calls_acc,
+                                tc,
+                                _last_tool_idx,
+                            )
 
                     # --- Finish ---
                     if choice.finish_reason:
                         stream_finish_reason = choice.finish_reason
                         for _idx, tc_data in sorted(tool_calls_acc.items()):
-                            parsed_args = self._parse_tool_call_arguments(
+                            parsed_args = self._coerce_tool_input(
                                 tc_data.get("arguments", ""),
                                 tc_data.get("name", ""),
                             )
@@ -1886,6 +2100,11 @@ class LiteLLMProvider(LLMProvider):
                 (m["role"] for m in reversed(full_messages) if m.get("role") != "system"),
                 None,
             )
+            recovered_events = await self._recover_empty_codex_stream(kwargs, last_role)
+            if recovered_events:
+                for event in recovered_events:
+                    yield event
+                return
             if attempt < EMPTY_STREAM_MAX_RETRIES:
                 token_count, token_method = _estimate_tokens(
                     self.model,
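A sketch of the recovery conversion for the degenerate Responses-API case (no `choices`, only `output_text`); the response object here is a stand-in namespace built for illustration, not a real LiteLLM type:

```python
from types import SimpleNamespace

# Stand-in for a non-stream response with no choices, only output_text.
response = SimpleNamespace(
    choices=[],
    output_text="Done.",
    usage=SimpleNamespace(prompt_tokens=120, completion_tokens=3),
    model="gpt-5.3-codex",  # model name taken from this PR's comments
)
events = provider._build_stream_events_from_nonstream_response(response)
# -> [TextDeltaEvent("Done."), TextEndEvent("Done."),
#     FinishEvent(stop_reason="stop", input_tokens=120, output_tokens=3, ...)]
```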
@@ -22,6 +22,7 @@ from framework.graph.edge import (
 )
 from framework.graph.executor import ExecutionResult
 from framework.graph.node import NodeSpec
+from framework.llm.codex_backend import CODEX_API_BASE, build_codex_litellm_kwargs
 from framework.llm.provider import LLMProvider, Tool
 from framework.runner.preload_validation import run_preload_validation
 from framework.runner.tool_registry import ToolRegistry
@@ -327,17 +328,68 @@ def _read_codex_auth_file() -> dict | None:
     return None
 
 
+def _get_jwt_claims(token: str) -> dict | None:
+    """Decode JWT claims without verification for local expiry/account inspection."""
+    import base64
+
+    try:
+        parts = token.split(".")
+        if len(parts) != 3:
+            return None
+        payload = parts[1]
+        padding = 4 - len(payload) % 4
+        if padding != 4:
+            payload += "=" * padding
+        decoded = base64.urlsafe_b64decode(payload)
+        claims = json.loads(decoded)
+        return claims if isinstance(claims, dict) else None
+    except Exception:
+        return None
+
+
+def _get_codex_token_expiry(auth_data: dict) -> float | None:
+    """Return the best-known expiry timestamp for a Codex access token."""
+    from datetime import datetime
+
+    tokens = auth_data.get("tokens", {})
+    access_token = tokens.get("access_token")
+    explicit = (
+        auth_data.get("expires_at")
+        or auth_data.get("expiresAt")
+        or tokens.get("expires_at")
+        or tokens.get("expiresAt")
+    )
+    if isinstance(explicit, (int, float)):
+        return float(explicit)
+    if isinstance(explicit, str):
+        try:
+            return datetime.fromisoformat(explicit.replace("Z", "+00:00")).timestamp()
+        except (ValueError, TypeError):
+            pass
+
+    if isinstance(access_token, str):
+        claims = _get_jwt_claims(access_token) or {}
+        exp = claims.get("exp")
+        if isinstance(exp, (int, float)):
+            return float(exp)
+    return None
+
+
 def _is_codex_token_expired(auth_data: dict) -> bool:
     """Check whether the Codex token is expired or close to expiry.
 
-    The Codex auth.json has no explicit ``expiresAt`` field, so we infer
-    expiry as ``last_refresh + _CODEX_TOKEN_LIFETIME_SECS``. Falls back
-    to the file mtime when ``last_refresh`` is absent.
+    Prefers an explicit expiry timestamp when one is stored; falls back
+    to JWT ``exp`` or file age heuristics when no explicit timestamp exists.
     """
     import time
     from datetime import datetime
 
     now = time.time()
+    explicit_expiry = _get_codex_token_expiry(auth_data)
+    if explicit_expiry is not None:
+        return now >= (explicit_expiry - _TOKEN_REFRESH_BUFFER_SECS)
+
     last_refresh = auth_data.get("last_refresh")
 
     if last_refresh is None:
@@ -431,6 +483,8 @@ def get_codex_token() -> str | None:
     Returns:
         The access token if available, None otherwise.
     """
+    import time
+
     # Try Keychain first, then file
     auth_data = _read_codex_keychain() or _read_codex_auth_file()
     if not auth_data:
@@ -441,15 +495,20 @@ def get_codex_token() -> str | None:
     if not access_token:
         return None
 
+    explicit_expiry = _get_codex_token_expiry(auth_data)
+    is_expired = _is_codex_token_expired(auth_data)
+
     # Check if token is still valid
-    if not _is_codex_token_expired(auth_data):
+    if not is_expired:
         return access_token
 
     # Token is expired or near expiry — attempt refresh
     refresh_token = tokens.get("refresh_token")
     if not refresh_token:
         logger.warning("Codex token expired and no refresh token available")
-        return access_token  # Return expired token; it may still work briefly
+        if explicit_expiry is not None and time.time() >= explicit_expiry:
+            return None
+        return access_token
 
     logger.info("Codex token expired or near expiry, refreshing...")
     token_data = _refresh_codex_token(refresh_token)
@@ -460,6 +519,8 @@ def get_codex_token() -> str | None:
 
     # Refresh failed — return the existing token and warn
     logger.warning("Codex token refresh failed. Run 'codex' to re-authenticate.")
+    if explicit_expiry is not None and time.time() >= explicit_expiry:
+        return None
     return access_token
 
 
@@ -471,26 +532,12 @@ def _get_account_id_from_jwt(access_token: str) -> str | None:
     This is used as a fallback when the auth.json doesn't store the
     account_id explicitly.
     """
-    import base64
-
-    try:
-        parts = access_token.split(".")
-        if len(parts) != 3:
-            return None
-        payload = parts[1]
-        # Add base64 padding
-        padding = 4 - len(payload) % 4
-        if padding != 4:
-            payload += "=" * padding
-        decoded = base64.urlsafe_b64decode(payload)
-        claims = json.loads(decoded)
-        auth = claims.get("https://api.openai.com/auth")
-        if isinstance(auth, dict):
-            account_id = auth.get("chatgpt_account_id")
-            if isinstance(account_id, str) and account_id:
-                return account_id
-    except Exception:
-        pass
+    claims = _get_jwt_claims(access_token) or {}
+    auth = claims.get("https://api.openai.com/auth")
+    if isinstance(auth, dict):
+        account_id = auth.get("chatgpt_account_id")
+        if isinstance(account_id, str) and account_id:
+            return account_id
     return None
 
 
@@ -1569,20 +1616,20 @@ class AgentRunner:
             # OpenAI Codex subscription routes through the ChatGPT backend
             # (chatgpt.com/backend-api/codex/responses), NOT the standard
             # OpenAI API. The consumer OAuth token lacks platform API scopes.
-            extra_headers: dict[str, str] = {
-                "Authorization": f"Bearer {api_key}",
-                "User-Agent": "CodexBar",
-            }
             account_id = get_codex_account_id()
-            if account_id:
-                extra_headers["ChatGPT-Account-Id"] = account_id
             self._llm = LiteLLMProvider(
                 model=self.model,
                 api_key=api_key,
-                api_base="https://chatgpt.com/backend-api/codex",
-                extra_headers=extra_headers,
-                store=False,
-                allowed_openai_params=["store"],
+                api_base=CODEX_API_BASE,
+                **build_codex_litellm_kwargs(api_key, account_id=account_id),
             )
         elif api_key and use_kimi_code:
             # Kimi Code subscription uses the Kimi coding API (OpenAI-compatible).
             # The api_base is set automatically by LiteLLMProvider for kimi/ models.
             self._llm = LiteLLMProvider(
                 model=self.model,
                 api_key=api_key,
                 api_base=api_base,
             )
         elif api_key and use_kimi_code:
             # Kimi Code subscription uses the Kimi coding API (OpenAI-compatible).
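`_get_jwt_claims` does a local, signature-free decode of the JWT payload segment (base64url with padding restored), which is all the expiry and account-id checks need. A toy token built for illustration (header and signature segments are deliberately junk, since only the middle segment is read):

```python
import base64
import json

payload = base64.urlsafe_b64encode(json.dumps({"exp": 1900000000}).encode()).rstrip(b"=")
token = b".".join([b"x", payload, b"y"]).decode()

claims = _get_jwt_claims(token)
assert claims == {"exp": 1900000000}
```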
@@ -535,8 +535,8 @@ class EventBus:
|
||||
async with self._semaphore:
|
||||
try:
|
||||
await handler(event)
|
||||
except Exception:
|
||||
logger.exception(f"Handler error for {event.type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Handler error for {event.type}: {e}")
|
||||
|
||||
# Run all handlers concurrently
|
||||
await asyncio.gather(*[run_handler(h) for h in handlers], return_exceptions=True)
|
||||
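For readers outside the EventBus internals: the hunk swaps `logger.error` for `logger.exception` so handler failures keep their tracebacks, while `asyncio.gather(..., return_exceptions=True)` keeps one failing handler from cancelling its siblings. A self-contained sketch of the same fan-out pattern (names are illustrative, not the framework API):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def fan_out(event, handlers, max_concurrency: int = 8) -> None:
    semaphore = asyncio.Semaphore(max_concurrency)  # bound concurrent handlers

    async def run_handler(handler):
        async with semaphore:
            try:
                await handler(event)
            except Exception:
                # logger.exception records the full traceback, unlike logger.error
                logger.exception("Handler error for %s", type(event).__name__)

    await asyncio.gather(*(run_handler(h) for h in handlers), return_exceptions=True)
```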
@@ -901,6 +901,9 @@ class EventBus:
        execution_id: str | None = None,
        options: list[str] | None = None,
        questions: list[dict] | None = None,
        auto_blocked: bool = False,
        assistant_text_present: bool = False,
        assistant_text_requires_input: bool = False,
    ) -> None:
        """Emit client input requested event (client_facing=True nodes).

@@ -917,6 +920,12 @@ class EventBus:
            data["options"] = options
        if questions:
            data["questions"] = questions
        if auto_blocked:
            data["auto_blocked"] = True
        if assistant_text_present:
            data["assistant_text_present"] = True
        if assistant_text_requires_input:
            data["assistant_text_requires_input"] = True
        await self.publish(
            AgentEvent(
                type=EventType.CLIENT_INPUT_REQUESTED,

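Taken together, the new keyword arguments give downstream consumers enough signal to tell a real question apart from a status-only pause. An illustrative event payload (field names from this hunk, values invented) for a queen auto-block that followed a question-like assistant turn:

```python
event_data = {
    "prompt": "",                           # empty: no explicit ask_user call
    "auto_blocked": True,                   # the turn ended without a tool call
    "assistant_text_present": True,         # the turn produced visible text...
    "assistant_text_requires_input": True,  # ...and that text invited a reply
}
```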
@@ -452,7 +452,9 @@ class ExecutionStream:
        node = executor.node_registry.get(node_id)
        if node is not None and hasattr(node, "inject_event"):
            await node.inject_event(
                content, is_client_input=is_client_input, image_content=image_content
                content,
                is_client_input=is_client_input,
                image_content=image_content,
            )
            return True
        return False
@@ -1030,6 +1032,22 @@ class ExecutionStream:
        else:
            status = SessionStatus.ACTIVE

        persisted_input_data = dict(ctx.input_data or {})
        entry_node_id = getattr(self.entry_spec, "entry_node", None) or getattr(
            self.graph, "entry_node", None
        )
        entry_input_keys: list[str] = []
        if entry_node_id and hasattr(self.graph, "get_node"):
            entry_node = self.graph.get_node(entry_node_id)
            entry_input_keys = list(getattr(entry_node, "input_keys", []) or [])

        if result and isinstance(result.output, dict):
            for key in entry_input_keys:
                if persisted_input_data.get(key) in (None, ""):
                    value = result.output.get(key)
                    if value not in (None, ""):
                        persisted_input_data[key] = value

        # Create SessionState
        if result:
            # Create from execution result
@@ -1040,7 +1058,7 @@ class ExecutionStream:
                stream_id=self.stream_id,
                correlation_id=ctx.correlation_id,
                started_at=ctx.started_at.isoformat(),
                input_data=ctx.input_data,
                input_data=persisted_input_data,
                agent_id=self.graph.id,
                entry_point=self.entry_spec.id,
            )
@@ -1075,7 +1093,7 @@ class ExecutionStream:
                ),
                progress=progress,
                memory=ss.get("memory", {}),
                input_data=ctx.input_data,
                input_data=persisted_input_data,
            )

        # Handle error case

@@ -48,7 +48,20 @@ def validate_agent_path(agent_path: str | Path) -> Path:
    Raises:
        ValueError: If the path is outside all allowed roots.
    """
    resolved = Path(agent_path).expanduser().resolve()
    raw_path = str(agent_path).strip()
    if not raw_path:
        raise ValueError(
            "agent_path must be inside an allowed directory "
            "(exports/, examples/, or ~/.hive/agents/)"
        )

    candidate = Path(agent_path).expanduser()
    if not candidate.is_absolute():
        # Resolve relative paths from the repository root so server-side
        # validation is independent of the process working directory.
        candidate = _REPO_ROOT / candidate

    resolved = candidate.resolve()
    for root in _get_allowed_agent_roots():
        if resolved.is_relative_to(root) and resolved != root:
            return resolved
@@ -281,6 +294,8 @@ def _setup_static_serving(app: web.Application) -> None:
    async def handle_spa(request: web.Request) -> web.FileResponse:
        """Serve static files with SPA fallback to index.html."""
        rel_path = request.match_info.get("path", "")
        if rel_path == "api" or rel_path.startswith("api/"):
            raise web.HTTPNotFound()
        file_path = (dist_dir / rel_path).resolve()

        if file_path.is_file() and file_path.is_relative_to(dist_dir):

@@ -17,6 +17,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


def _client_input_counts_as_planning_ask(event: Any) -> bool:
    """Return True when a queen input-request should satisfy planning ask rounds.

    Explicit ask_user / ask_user_multiple calls always count. We also count
    queen auto-blocks that followed assistant text which clearly invited a
    reply, which covers Codex-style plain-text planning questions that failed
    to call ask_user. Empty/status-only auto-blocks do not count.
    """
    data = getattr(event, "data", None) or {}
    if data.get("prompt") or data.get("questions") or data.get("options"):
        return True
    if not data.get("auto_blocked"):
        return False
    requires_input = data.get("assistant_text_requires_input", False)
    return bool(requires_input)

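Illustrative behaviour of `_client_input_counts_as_planning_ask` on the three event shapes its docstring describes (`SimpleNamespace` stands in for `AgentEvent`; only `.data` is read):

```python
from types import SimpleNamespace

explicit_ask = SimpleNamespace(data={"prompt": "Which folder should I scan?"})
question_autoblock = SimpleNamespace(
    data={"auto_blocked": True, "assistant_text_requires_input": True}
)
status_autoblock = SimpleNamespace(data={"auto_blocked": True})

assert _client_input_counts_as_planning_ask(explicit_ask) is True
assert _client_input_counts_as_planning_ask(question_autoblock) is True
assert _client_input_counts_as_planning_ask(status_autoblock) is False
```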
async def create_queen(
    session: Session,
    session_manager: Any,
@@ -62,7 +79,6 @@ async def create_queen(
    from framework.agents.queen.nodes.thinking_hook import select_expert_persona
    from framework.graph.event_loop_node import HookContext, HookResult
    from framework.graph.executor import GraphExecutor
    from framework.runner.mcp_registry import MCPRegistry
    from framework.runner.tool_registry import ToolRegistry
    from framework.runtime.core import Runtime
    from framework.runtime.event_bus import AgentEvent, EventType
@@ -88,6 +104,8 @@ async def create_queen(
        logger.warning("Queen: MCP config failed to load", exc_info=True)

    try:
        from framework.runner.mcp_registry import MCPRegistry

        registry = MCPRegistry()
        registry.initialize()
        if (queen_pkg_dir / "mcp_registry.json").is_file():
@@ -115,14 +133,7 @@ async def create_queen(
    async def _track_planning_asks(event: AgentEvent) -> None:
        if phase_state.phase != "planning":
            return
        # Only count explicit ask_user / ask_user_multiple calls, not
        # auto-block (text-only turns emit CLIENT_INPUT_REQUESTED with
        # an empty prompt and no options/questions).
        data = event.data or {}
        has_prompt = bool(data.get("prompt"))
        has_questions = bool(data.get("questions"))
        has_options = bool(data.get("options"))
        if has_prompt or has_questions or has_options:
        if _client_input_counts_as_planning_ask(event):
            phase_state.planning_ask_rounds += 1

    session.event_bus.subscribe(
@@ -240,15 +251,11 @@ async def create_queen(

    # ---- Default skill protocols -------------------------------------
    try:
        from framework.skills.manager import SkillsManager, SkillsManagerConfig
        from framework.skills.manager import SkillsManager

        # Pass project_root so user-scope skills (~/.hive/skills/, ~/.agents/skills/)
        # are discovered. Queen has no agent-specific project root, so we use its
        # own directory — the value just needs to be non-None to enable user-scope scanning.
        _queen_skills_mgr = SkillsManager(SkillsManagerConfig(project_root=Path(__file__).parent))
        _queen_skills_mgr = SkillsManager()
        _queen_skills_mgr.load()
        phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
        phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt
    except Exception:
        logger.debug("Queen skill loading failed (non-fatal)", exc_info=True)


@@ -8,11 +8,113 @@ from typing import Any
from aiohttp import web

from framework.credentials.validation import validate_agent_credentials
from framework.runtime.event_bus import AgentEvent, EventType
from framework.server.app import resolve_session, safe_path_segment, sessions_dir
from framework.server.routes_sessions import _credential_error_response
from framework.server.session_manager import (
    _run_validation_report_sync,
    _validation_blocks_stage_or_run,
    _validation_failures,
)

logger = logging.getLogger(__name__)

_TERMINAL_STOP_MARKERS = (
    "done for now",
    "stop here",
    "stop for now",
    "end session",
    "finish and close session",
    "finish and close",
)


def _normalize_choice_text(text: str) -> str:
    lowered = str(text or "").strip().lower()
    return " ".join(lowered.replace("_", " ").split())


def _looks_like_terminal_stop_reply(text: str) -> bool:
    normalized = _normalize_choice_text(text)
    return any(marker in normalized for marker in _TERMINAL_STOP_MARKERS)


def _queen_is_waiting_on_terminal_followup(session: Any) -> bool:
    """Return True when the latest queen question offered a terminal stop option."""
    bus = getattr(session, "event_bus", None)
    if bus is None or not hasattr(bus, "get_history"):
        return False

    events = bus.get_history(
        event_type=EventType.CLIENT_INPUT_REQUESTED,
        stream_id="queen",
        limit=5,
    )
    for event in events:
        data = getattr(event, "data", None) or {}
        options = [str(opt) for opt in (data.get("options") or []) if opt]
        for question in data.get("questions") or []:
            options.extend(str(opt) for opt in (question.get("options") or []) if opt)
        if options:
            return any(_looks_like_terminal_stop_reply(opt) for opt in options)
    return False

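How the two helpers above compose (option labels and free-text replies normalize to the same token stream, so either form matches the marker list):

```python
assert _normalize_choice_text("  Done_For_Now ") == "done for now"
assert _looks_like_terminal_stop_reply("No, stop here") is True
assert _looks_like_terminal_stop_reply("Run again with same input") is False
```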
async def _acknowledge_terminal_queen_choice(session: Any, message: str) -> None:
    """Emit a final acknowledgment when the user chooses to stop."""
    ack = "Okay, stopping here. I’ll wait for your next message."
    bus = getattr(session, "event_bus", None)
    if bus is None:
        return

    if hasattr(bus, "publish"):
        await bus.publish(
            AgentEvent(
                type=EventType.CLIENT_INPUT_RECEIVED,
                stream_id="queen",
                node_id="queen",
                execution_id=session.id,
                data={"content": message},
            )
        )
    if hasattr(bus, "emit_client_output_delta"):
        await bus.emit_client_output_delta(
            "queen",
            "queen",
            ack,
            ack,
            execution_id=session.id,
        )


async def _worker_validation_error(session) -> web.Response | None:
    """Return a 409 response when the loaded worker is invalid."""
    report = getattr(session, "worker_validation_report", None)
    if report is None and getattr(session, "worker_path", None):
        loop = asyncio.get_running_loop()
        report = await loop.run_in_executor(
            None, lambda: _run_validation_report_sync(str(session.worker_path))
        )
        session.worker_validation_report = report
        session.worker_validation_failures = _validation_failures(report)

    if _validation_blocks_stage_or_run(report):
        failures = getattr(session, "worker_validation_failures", None) or _validation_failures(
            report
        )
        worker_name = getattr(getattr(session, "worker_path", None), "name", "") or "current worker"
        return web.json_response(
            {
                "error": (
                    f"Worker '{worker_name}' failed validation and cannot be executed. "
                    "Fix the package and reload it before running or resuming."
                ),
                "validation_failures": failures,
            },
            status=409,
        )
    return None

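For reference, the 409 body produced above has this shape (values illustrative):

```python
response_body = {
    "error": (
        "Worker 'my_worker' failed validation and cannot be executed. "
        "Fix the package and reload it before running or resuming."
    ),
    "validation_failures": ["behavior_validation: placeholder prompt"],
}
```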
async def handle_trigger(request: web.Request) -> web.Response:
    """POST /api/sessions/{session_id}/trigger — start an execution.
@@ -26,6 +128,10 @@ async def handle_trigger(request: web.Request) -> web.Response:
    if not session.worker_runtime:
        return web.json_response({"error": "No worker loaded in this session"}, status=503)

    validation_err = await _worker_validation_error(session)
    if validation_err is not None:
        return validation_err

    # Validate credentials before running — deferred from load time to avoid
    # showing the modal before the user clicks Run. Runs in executor because
    # validate_agent_credentials makes blocking HTTP health-check calls.
@@ -53,11 +159,7 @@ async def handle_trigger(request: web.Request) -> web.Response:
    body = await request.json()
    entry_point_id = body.get("entry_point_id", "default")
    input_data = body.get("input_data", {})
    session_state = body.get("session_state") or {}

    # Scope the worker execution to the live session ID
    if "resume_session_id" not in session_state:
        session_state["resume_session_id"] = session.id
    session_state = body.get("session_state") or None

    execution_id = await session.worker_runtime.trigger(
        entry_point_id,
@@ -108,10 +210,7 @@ async def handle_chat(request: web.Request) -> web.Response:
    The input box is permanently connected to the queen agent.
    Worker input is handled separately via /worker-input.

    Body: {"message": "hello", "images": [{"type": "image_url", "image_url": {"url": "data:..."}}]}

    The optional ``images`` field accepts a list of OpenAI-format image_url
    content blocks. The frontend encodes images as base64 data URIs.
    Body: {"message": "hello"}
    """
    session, err = resolve_session(request)
    if err:
@@ -119,29 +218,34 @@ async def handle_chat(request: web.Request) -> web.Response:

    body = await request.json()
    message = body.get("message", "")
    image_content = body.get("images") or None  # list[dict] | None

    if not message and not image_content:
    if not message:
        return web.json_response({"error": "message is required"}, status=400)

    manager: Any = request.app["manager"]

    if _looks_like_terminal_stop_reply(message) and _queen_is_waiting_on_terminal_followup(session):
        await _acknowledge_terminal_queen_choice(session, message)
        await manager.suspend_queen(session)
        return web.json_response(
            {
                "status": "queen",
                "delivered": True,
            }
        )

    queen_executor = session.queen_executor
    if queen_executor is not None:
        node = queen_executor.node_registry.get("queen")
        if node is not None and hasattr(node, "inject_event"):
            await node.inject_event(message, is_client_input=True, image_content=image_content)
            # Publish to EventBus so the session event log captures user messages
            from framework.runtime.event_bus import AgentEvent, EventType

            await node.inject_event(message, is_client_input=True)
            await session.event_bus.publish(
                AgentEvent(
                    type=EventType.CLIENT_INPUT_RECEIVED,
                    stream_id="queen",
                    node_id="queen",
                    execution_id=session.id,
                    data={
                        "content": message,
                        "image_count": len(image_content) if image_content else 0,
                    },
                    data={"content": message},
                )
            )
            return web.json_response(
@@ -152,7 +256,6 @@ async def handle_chat(request: web.Request) -> web.Response:
    )

    # Queen is dead — try to revive her
    manager: Any = request.app["manager"]
    try:
        await manager.revive_queen(session, initial_prompt=message)
        return web.json_response(
@@ -274,6 +377,10 @@ async def handle_resume(request: web.Request) -> web.Response:
    if not session.worker_runtime:
        return web.json_response({"error": "No worker loaded in this session"}, status=503)

    validation_err = await _worker_validation_error(session)
    if validation_err is not None:
        return validation_err

    body = await request.json()
    worker_session_id = body.get("session_id")
    checkpoint_id = body.get("checkpoint_id")
@@ -419,9 +526,14 @@ async def handle_stop(request: web.Request) -> web.Response:
        if hasattr(node, "cancel_current_turn"):
            node.cancel_current_turn()

    cancelled = await stream.cancel_execution(
        execution_id, reason="Execution stopped by user"
    )
    try:
        cancelled = await stream.cancel_execution(
            execution_id, reason="Execution stopped by user"
        )
    except TypeError:
        # Backward compatibility for older stream/test doubles that
        # still expose cancel_execution(execution_id) only.
        cancelled = await stream.cancel_execution(execution_id)
    if cancelled:
        # Cancel queen's in-progress LLM turn
        if session.queen_executor:

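The stop-handler change above is a standard keyword-compatibility shim; isolated, the pattern is (note it also swallows a TypeError raised inside a new-style implementation, which is the usual trade-off of this approach):

```python
async def cancel_compat(stream, execution_id: str, reason: str) -> bool:
    """Prefer the newer cancel_execution(id, reason=...) signature, falling back."""
    try:
        return await stream.cancel_execution(execution_id, reason=reason)
    except TypeError:
        # Older cancel_execution(execution_id)-only implementations land here.
        return await stream.cancel_execution(execution_id)
```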
@@ -28,8 +28,6 @@ import contextlib
import json
import logging
import shutil
import subprocess
import sys
import time
from pathlib import Path

@@ -42,22 +40,54 @@ from framework.server.app import (
    sessions_dir,
    validate_agent_path,
)
from framework.server.session_manager import SessionManager
from framework.server.session_manager import (
    SessionManager,
    WorkerValidationError,
    _run_validation_report_sync,
    _validation_blocks_stage_or_run,
    _validation_failures,
)

logger = logging.getLogger(__name__)


async def _worker_validation_error(session) -> web.Response | None:
    """Return a 409 response when the loaded worker is invalid."""
    report = getattr(session, "worker_validation_report", None)
    if report is None and getattr(session, "worker_path", None):
        loop = asyncio.get_running_loop()
        report = await loop.run_in_executor(
            None, lambda: _run_validation_report_sync(str(session.worker_path))
        )
        session.worker_validation_report = report
        session.worker_validation_failures = _validation_failures(report)

    if _validation_blocks_stage_or_run(report):
        failures = getattr(session, "worker_validation_failures", None) or _validation_failures(
            report
        )
        worker_name = getattr(getattr(session, "worker_path", None), "name", "") or "current worker"
        return web.json_response(
            {
                "error": (
                    f"Worker '{worker_name}' failed validation and cannot be executed. "
                    "Fix the package and reload it before running or restoring."
                ),
                "validation_failures": failures,
            },
            status=409,
        )
    return None


def _get_manager(request: web.Request) -> SessionManager:
    return request.app["manager"]


def _session_to_live_dict(session) -> dict:
    """Serialize a live Session to the session-primary JSON shape."""
    from framework.llm.capabilities import supports_image_tool_results

    info = session.worker_info
    phase_state = getattr(session, "phase_state", None)
    queen_model: str = getattr(getattr(session, "runner", None), "model", "") or ""
    return {
        "session_id": session.id,
        "worker_id": session.worker_id,
@@ -73,7 +103,6 @@ def _session_to_live_dict(session) -> dict:
        "queen_phase": phase_state.phase
        if phase_state
        else ("staging" if session.worker_runtime else "planning"),
        "queen_supports_images": supports_image_tool_results(queen_model) if queen_model else True,
    }


@@ -311,6 +340,11 @@ async def handle_load_worker(request: web.Request) -> web.Response:
            model=model,
        )
    except ValueError as e:
        if isinstance(e, WorkerValidationError):
            return web.json_response(
                {"error": str(e), "validation_failures": e.failures},
                status=409,
            )
        return web.json_response({"error": str(e)}, status=409)
    except FileNotFoundError:
        return web.json_response({"error": f"Agent not found: {agent_path}"}, status=404)
@@ -729,6 +763,10 @@ async def handle_restore_checkpoint(request: web.Request) -> web.Response:
    if not session.worker_runtime:
        return web.json_response({"error": "No worker loaded in this session"}, status=503)

    validation_err = await _worker_validation_error(session)
    if validation_err is not None:
        return validation_err

    ws_id = request.match_info.get("ws_id") or request.match_info.get("session_id", "")
    ws_id = safe_path_segment(ws_id)
    checkpoint_id = safe_path_segment(request.match_info["checkpoint_id"])
@@ -984,29 +1022,6 @@ async def handle_discover(request: web.Request) -> web.Response:
    return web.json_response(result)


async def handle_reveal_session_folder(request: web.Request) -> web.Response:
    """POST /api/sessions/{session_id}/reveal — open session data folder in the OS file manager."""
    manager: SessionManager = request.app["manager"]
    session_id = request.match_info["session_id"]

    session = manager.get_session(session_id)
    storage_session_id = (session.queen_resume_from or session.id) if session else session_id
    folder = Path.home() / ".hive" / "queen" / "session" / storage_session_id
    folder.mkdir(parents=True, exist_ok=True)

    try:
        if sys.platform == "darwin":
            subprocess.Popen(["open", str(folder)])
        elif sys.platform == "win32":
            subprocess.Popen(["explorer", str(folder)])
        else:
            subprocess.Popen(["xdg-open", str(folder)])
    except Exception as exc:
        return web.json_response({"error": str(exc)}, status=500)

    return web.json_response({"path": str(folder)})


# ------------------------------------------------------------------
# Route registration
# ------------------------------------------------------------------
@@ -1031,7 +1046,6 @@ def register_routes(app: web.Application) -> None:
    app.router.add_delete("/api/sessions/{session_id}/worker", handle_unload_worker)

    # Session info
    app.router.add_post("/api/sessions/{session_id}/reveal", handle_reveal_session_folder)
    app.router.add_get("/api/sessions/{session_id}/stats", handle_session_stats)
    app.router.add_get("/api/sessions/{session_id}/entry-points", handle_session_entry_points)
    app.router.add_patch(

@@ -12,6 +12,8 @@ Architecture:
import asyncio
import json
import logging
import subprocess
import textwrap
import time
import uuid
from dataclasses import dataclass, field
@@ -22,6 +24,134 @@ from typing import Any
from framework.runtime.triggers import TriggerDefinition

logger = logging.getLogger(__name__)
REPO_ROOT = Path(__file__).resolve().parents[3]
CODER_TOOLS_SERVER = REPO_ROOT / "tools" / "coder_tools_server.py"


def _parse_validation_report(raw: str | dict[str, Any] | None) -> dict[str, Any]:
    """Best-effort parse of validate_agent_package output."""
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {}

    cleaned = raw.strip()
    if "\n\n[Saved to " in cleaned:
        cleaned = cleaned.split("\n\n[Saved to ", 1)[0].strip()
    if not cleaned:
        return {}

    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        start = cleaned.find("{")
        end = cleaned.rfind("}")
        if start != -1 and end != -1 and end > start:
            try:
                return json.loads(cleaned[start : end + 1])
            except json.JSONDecodeError:
                return {}
        return {}

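Example of the tool output `_parse_validation_report` tolerates: a JSON report with the trailing `[Saved to ...]` suffix stripped before parsing, and non-JSON text degrading to an empty dict:

```python
raw = '{"valid": true, "steps": {}}\n\n[Saved to exports/report.json]'
assert _parse_validation_report(raw) == {"valid": True, "steps": {}}
assert _parse_validation_report("not json at all") == {}
```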
def _validation_failures(report: dict[str, Any] | None) -> list[str]:
    """Extract readable failure summaries from a validation report."""
    if not isinstance(report, dict):
        return []

    steps = report.get("steps") or {}
    failures: list[str] = []
    for step_name, step in steps.items():
        if not isinstance(step, dict) or step.get("passed", False):
            continue
        if step.get("errors"):
            errors = step["errors"]
            if isinstance(errors, list):
                failures.extend(f"{step_name}: {err}" for err in errors)
            continue
        if step.get("missing_tools"):
            missing = step["missing_tools"]
            if isinstance(missing, list):
                failures.extend(f"{step_name}: missing tool {tool}" for tool in missing)
            continue
        detail = step.get("error") or step.get("output") or "validation failed"
        failures.append(f"{step_name}: {detail}")
    if not failures and report.get("summary"):
        failures.append(str(report["summary"]))
    return failures


def _validation_blocks_stage_or_run(report: dict[str, Any] | None) -> bool:
    """Return True when a validation report contains any failed step."""
    if not isinstance(report, dict):
        return False
    steps = report.get("steps")
    if not isinstance(steps, dict):
        return bool(report.get("valid") is False)
    return any(isinstance(step, dict) and not step.get("passed", False) for step in steps.values())


def _run_validation_report_sync(agent_ref: str | Path) -> dict[str, Any]:
    """Run validate_agent_package in an isolated subprocess.

    Accepts either a built-agent package name (for exports/) or a full
    allowed agent path such as examples/templates/<agent>.
    """
    if not agent_ref:
        return {}
    agent_ref_str = str(agent_ref)

    script = textwrap.dedent(
        """
        import importlib.util
        import sys

        server_path = sys.argv[1]
        agent_ref = sys.argv[2]
        spec = importlib.util.spec_from_file_location("coder_tools_server", server_path)
        module = importlib.util.module_from_spec(spec)
        assert spec.loader is not None
        spec.loader.exec_module(module)
        import json
        print(json.dumps(module._validate_agent_package_impl(agent_ref), default=str))
        """
    )
    try:
        proc = subprocess.run(
            ["uv", "run", "python", "-c", script, str(CODER_TOOLS_SERVER), agent_ref_str],
            capture_output=True,
            text=True,
            timeout=120,
            cwd=REPO_ROOT,
            stdin=subprocess.DEVNULL,
        )
    except (OSError, subprocess.SubprocessError) as exc:
        return {
            "valid": False,
            "summary": f"validate_agent_package failed for '{agent_ref_str}'",
            "steps": {"validator_subprocess": {"passed": False, "error": str(exc)[:2000]}},
        }
    if proc.returncode != 0:
        detail = proc.stderr.strip() or proc.stdout.strip() or "validation subprocess failed"
        return {
            "valid": False,
            "summary": f"validate_agent_package failed for '{agent_ref_str}'",
            "steps": {"validator_subprocess": {"passed": False, "error": detail[:2000]}},
        }
    return _parse_validation_report(proc.stdout)

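Typical call shape for the sync validator from async server code (this mirrors how the route handlers above use it; the wrapper name is illustrative):

```python
import asyncio


async def validation_failures_for(agent_path: str) -> list[str]:
    loop = asyncio.get_running_loop()
    # Run the blocking subprocess call off the event loop.
    report = await loop.run_in_executor(
        None, lambda: _run_validation_report_sync(agent_path)
    )
    if _validation_blocks_stage_or_run(report):
        return _validation_failures(report)
    return []
```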
class WorkerValidationError(ValueError):
    """Raised when a worker package fails validation before load/run."""

    def __init__(self, agent_name: str, report: dict[str, Any]):
        self.agent_name = agent_name
        self.report = report
        self.failures = _validation_failures(report)
        super().__init__(
            f"Worker '{agent_name}' failed validation: "
            + ("; ".join(self.failures) if self.failures else "validation failed")
        )


@dataclass
@@ -41,6 +171,8 @@ class Session:
    runner: Any | None = None  # AgentRunner
    worker_runtime: Any | None = None  # AgentRuntime
    worker_info: Any | None = None  # AgentInfo
    worker_validation_report: dict[str, Any] | None = None
    worker_validation_failures: list[str] = field(default_factory=list)
    # Queen phase state (building/staging/running)
    phase_state: Any = None  # QueenPhaseState
    # Worker handoff subscription
@@ -83,6 +215,47 @@ class SessionManager:
        self._credential_store = credential_store
        self._lock = asyncio.Lock()

    async def suspend_queen(self, session: Session) -> None:
        """Park the queen until the user sends a fresh message.

        This is lighter than stopping the full session: it tears down the
        queen executor and its subscriptions, but preserves the live session,
        loaded worker, and persisted history. The next `/chat` call will
        revive the queen via the normal code path.
        """
        if session.worker_handoff_sub is not None:
            try:
                session.event_bus.unsubscribe(session.worker_handoff_sub)
            except Exception:
                pass
            session.worker_handoff_sub = None

        if session.memory_consolidation_sub is not None:
            try:
                session.event_bus.unsubscribe(session.memory_consolidation_sub)
            except Exception:
                pass
            session.memory_consolidation_sub = None

        executor = session.queen_executor
        if executor is not None:
            node = executor.node_registry.get("queen")
            if node is not None and hasattr(node, "signal_shutdown"):
                node.signal_shutdown()

        if session.queen_task is not None:
            task = session.queen_task
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception:
                logger.debug("Queen task exited with error during suspend", exc_info=True)
            session.queen_task = None

        session.queen_executor = None

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------
@@ -96,7 +269,8 @@ class SessionManager:

        Internal helper — use create_session() or create_session_with_worker().
        """
        from framework.config import RuntimeConfig, get_hive_config
        from framework.config import RuntimeConfig
        from framework.llm.litellm import LiteLLMProvider
        from framework.runtime.event_bus import EventBus

        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -110,20 +284,12 @@ class SessionManager:
        rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model)

        # Session owns these — shared with queen and worker
        llm_config = get_hive_config().get("llm", {})
        if llm_config.get("use_antigravity_subscription"):
            from framework.llm.antigravity import AntigravityProvider

            llm = AntigravityProvider(model=rc.model)
        else:
            from framework.llm.litellm import LiteLLMProvider

            llm = LiteLLMProvider(
                model=rc.model,
                api_key=rc.api_key,
                api_base=rc.api_base,
                **rc.extra_kwargs,
            )
        llm = LiteLLMProvider(
            model=rc.model,
            api_key=rc.api_key,
            api_base=rc.api_base,
            **rc.extra_kwargs,
        )
        event_bus = EventBus()

        session = Session(
@@ -294,6 +460,11 @@ class SessionManager:
        try:
            # Blocking I/O — load in executor
            loop = asyncio.get_running_loop()
            validation_report = await loop.run_in_executor(
                None, lambda: _run_validation_report_sync(str(agent_path))
            )
            if _validation_blocks_stage_or_run(validation_report):
                raise WorkerValidationError(agent_path.name, validation_report)

            # Prioritize: explicit model arg > worker-specific model > session default
            from framework.config import (
@@ -320,25 +491,17 @@ class SessionManager:
            # with the correct worker credentials so _setup() doesn't fall back
            # to the queen's llm config (which may be a different provider).
            if worker_model and not model:
                from framework.config import get_hive_config
                from framework.llm.litellm import LiteLLMProvider

                worker_llm_cfg = get_hive_config().get("worker_llm", {})
                if worker_llm_cfg.get("use_antigravity_subscription"):
                    from framework.llm.antigravity import AntigravityProvider

                    runner._llm = AntigravityProvider(model=resolved_model)
                else:
                    from framework.llm.litellm import LiteLLMProvider

                    worker_api_key = get_worker_api_key()
                    worker_api_base = get_worker_api_base()
                    worker_extra = get_worker_llm_extra_kwargs()
                    runner._llm = LiteLLMProvider(
                        model=resolved_model,
                        api_key=worker_api_key,
                        api_base=worker_api_base,
                        **worker_extra,
                    )
                worker_api_key = get_worker_api_key()
                worker_api_base = get_worker_api_base()
                worker_extra = get_worker_llm_extra_kwargs()
                runner._llm = LiteLLMProvider(
                    model=resolved_model,
                    api_key=worker_api_key,
                    api_base=worker_api_base,
                    **worker_extra,
                )

            # Setup with session's event bus
            if runner._agent_runtime is None:
@@ -383,6 +546,8 @@ class SessionManager:
            session.runner = runner
            session.worker_runtime = runtime
            session.worker_info = info
            session.worker_validation_report = validation_report
            session.worker_validation_failures = _validation_failures(validation_report)

            # Subscribe to execution completion for per-run digest generation
            self._subscribe_worker_digest(session)
@@ -637,6 +802,8 @@ class SessionManager:
        session.runner = None
        session.worker_runtime = None
        session.worker_info = None
        session.worker_validation_report = None
        session.worker_validation_failures = []

        # Notify queen
        await self._notify_queen_worker_unloaded(session)
@@ -820,12 +987,25 @@ class SessionManager:
                return
            await node.inject_event(f"[WORKER_DIGEST]\n{content}")

        async def _consolidate_and_notify(run_id: str, outcome_event: Any) -> None:
            """Write the digest then push it to the queen."""
        async def _consolidate_and_notify(
            run_id: str,
            outcome_event: Any,
            *,
            inject_to_queen: bool,
        ) -> None:
            """Write the digest and optionally push it into the queen.

            Final worker completion/failure already emits a richer
            [WORKER_TERMINAL] handoff with the real primary result. Injecting the
            final digest as a second queen event causes Codex to replace that
            result with a bland generic follow-up prompt. Keep writing digests to
            disk for memory/history, but only inject mid-run snapshots.
            """
            from framework.agents.worker_memory import consolidate_worker_run

            await consolidate_worker_run(_agent_name, run_id, outcome_event, _bus, _llm)
            await _inject_digest_to_queen(run_id)
            if inject_to_queen:
                await _inject_digest_to_queen(run_id)

        async def _on_worker_event(event: Any) -> None:
            if event.stream_id == "queen":
@@ -851,7 +1031,7 @@ class SessionManager:
            run_id = getattr(event, "run_id", None) or _resolve_run_id(exec_id)
            if run_id:
                asyncio.create_task(
                    _consolidate_and_notify(run_id, event),
                    _consolidate_and_notify(run_id, event, inject_to_queen=False),
                    name=f"worker-digest-final-{run_id}",
                )

@@ -872,7 +1052,7 @@ class SessionManager:
            if run_id:
                _last_digest[exec_id] = now
                asyncio.create_task(
                    _consolidate_and_notify(run_id, None),
                    _consolidate_and_notify(run_id, None, inject_to_queen=True),
                    name=f"worker-digest-{run_id}",
                )

@@ -1047,17 +1227,10 @@ class SessionManager:
        _consolidation_session_dir = queen_dir

        async def _on_compaction(_event) -> None:
            # Only consolidate on queen compactions — worker and subagent
            # compactions are frequent and don't warrant a memory update.
            if getattr(_event, "stream_id", None) != "queen":
                return
            from framework.agents.queen.queen_memory import consolidate_queen_memory

            asyncio.create_task(
                consolidate_queen_memory(
                    session.id, _consolidation_session_dir, _consolidation_llm
                ),
                name=f"queen-memory-consolidation-{session.id}",
            await consolidate_queen_memory(
                session.id, _consolidation_session_dir, _consolidation_llm
            )

        from framework.runtime.event_bus import EventType as _ET

@@ -14,6 +14,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from aiohttp.test_utils import TestClient, TestServer

from framework.runtime.event_bus import AgentEvent, EventType
from framework.runtime.triggers import TriggerDefinition
from framework.server.app import create_app
from framework.server.session_manager import Session
@@ -190,6 +191,8 @@ def _make_session(
        runner=runner,
        worker_runtime=rt,
        worker_info=MockAgentInfo(),
        worker_validation_report={"valid": True, "steps": {}},
        worker_validation_failures=[],
    )


@@ -556,6 +559,7 @@ class TestExecution:
    @pytest.mark.asyncio
    async def test_trigger(self):
        session = _make_session()
        session.worker_runtime.trigger = AsyncMock(return_value="exec_test_123")
        app = _make_app_with_session(session)
        async with TestClient(TestServer(app)) as client:
            resp = await client.post(
@@ -565,6 +569,11 @@ class TestExecution:
            assert resp.status == 200
            data = await resp.json()
            assert data["execution_id"] == "exec_test_123"
        session.worker_runtime.trigger.assert_awaited_once_with(
            "default",
            {"msg": "hi"},
            session_state=None,
        )

    @pytest.mark.asyncio
    async def test_trigger_not_found(self):
@@ -576,6 +585,25 @@ class TestExecution:
            )
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_trigger_blocks_invalid_loaded_worker(self):
        session = _make_session()
        session.worker_validation_report = {
            "valid": False,
            "steps": {"behavior_validation": {"passed": False}},
        }
        session.worker_validation_failures = ["behavior_validation: placeholder prompt"]
        app = _make_app_with_session(session)
        async with TestClient(TestServer(app)) as client:
            resp = await client.post(
                "/api/sessions/test_agent/trigger",
                json={"entry_point_id": "default", "input_data": {"msg": "hi"}},
            )
            assert resp.status == 409
            data = await resp.json()
            assert "failed validation" in data["error"]
            assert data["validation_failures"] == ["behavior_validation: placeholder prompt"]

    @pytest.mark.asyncio
    async def test_inject(self):
        session = _make_session()
@@ -616,8 +644,8 @@ class TestExecution:
            assert data["delivered"] is True

    @pytest.mark.asyncio
    async def test_chat_injects_when_node_waiting(self):
        """When a node is awaiting input, /chat should inject instead of trigger."""
    async def test_chat_still_goes_to_queen_when_node_waiting(self):
        """The main chat channel stays wired to Queen even if a worker is waiting."""
        session = _make_session()
        session.worker_runtime.find_awaiting_node = lambda: ("chat_node", "primary")
        app = _make_app_with_session(session)
@@ -628,6 +656,83 @@ class TestExecution:
            )
            assert resp.status == 200
            data = await resp.json()
            assert data["status"] == "queen"
            assert data["delivered"] is True

    @pytest.mark.asyncio
    async def test_chat_done_for_now_parks_queen_without_new_followup(self):
        """Terminal stop choices should acknowledge once and park the queen."""
        session = _make_session()
        session.event_bus.get_history.return_value = [
            AgentEvent(
                type=EventType.CLIENT_INPUT_REQUESTED,
                stream_id="queen",
                node_id="queen",
                execution_id=session.id,
                data={"options": ["Run again with same input", "Done for now"]},
            )
        ]
        session.event_bus.emit_client_output_delta = AsyncMock()
        app = _make_app_with_session(session)

        async with TestClient(TestServer(app)) as client:
            resp = await client.post(
                "/api/sessions/test_agent/chat",
                json={"message": "No, stop here"},
            )
            assert resp.status == 200
            data = await resp.json()
            assert data["status"] == "queen"
            assert data["delivered"] is True

        queen_node = session.queen_executor
        assert queen_node is None
        session.event_bus.emit_client_output_delta.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_chat_non_terminal_choice_still_goes_to_queen(self):
        """Non-terminal follow-up choices should still be injected into the queen."""
        session = _make_session()
        session.event_bus.get_history.return_value = [
            AgentEvent(
                type=EventType.CLIENT_INPUT_REQUESTED,
                stream_id="queen",
                node_id="queen",
                execution_id=session.id,
                data={"options": ["Run again with same input", "Done for now"]},
            )
        ]
        session.event_bus.emit_client_output_delta = AsyncMock()
        app = _make_app_with_session(session)

        async with TestClient(TestServer(app)) as client:
            resp = await client.post(
                "/api/sessions/test_agent/chat",
                json={"message": "Run again with same input"},
            )
            assert resp.status == 200
            data = await resp.json()
            assert data["status"] == "queen"
            assert data["delivered"] is True
        queen_node = session.queen_executor.node_registry["queen"]
        queen_node.inject_event.assert_awaited_once_with(
            "Run again with same input",
            is_client_input=True,
        )
        session.event_bus.emit_client_output_delta.assert_not_called()

    @pytest.mark.asyncio
    async def test_worker_input_injects_when_node_waiting(self):
        session = _make_session()
        session.worker_runtime.find_awaiting_node = lambda: ("chat_node", "primary")
        app = _make_app_with_session(session)
        async with TestClient(TestServer(app)) as client:
            resp = await client.post(
                "/api/sessions/test_agent/worker-input",
                json={"message": "user reply"},
            )
            assert resp.status == 200
            data = await resp.json()
            assert data["status"] == "injected"
            assert data["node_id"] == "chat_node"
            assert data["delivered"] is True
@@ -715,8 +820,6 @@ class TestResume:
            assert resp.status == 200
            data = await resp.json()
            assert data["execution_id"] == "exec_test_123"
            assert data["resumed_from"] == session_id
            assert data["checkpoint_id"] is None

    @pytest.mark.asyncio
    async def test_resume_with_checkpoint(self, sample_session, tmp_agent_dir):
@@ -761,6 +864,31 @@ class TestResume:
            )
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_resume_blocks_invalid_loaded_worker(self, sample_session, tmp_agent_dir):
        session_id, session_dir, state = sample_session
        tmp_path, agent_name, base = tmp_agent_dir

        session = _make_session(tmp_dir=tmp_path / ".hive" / "agents" / agent_name)
        session.worker_validation_report = {
            "valid": False,
            "steps": {"tool_validation": {"passed": False}},
        }
        session.worker_validation_failures = ["tool_validation: missing tool execute_command_tool"]
        app = _make_app_with_session(session)

        async with TestClient(TestServer(app)) as client:
            resp = await client.post(
                "/api/sessions/test_agent/resume",
                json={"session_id": session_id},
            )
            assert resp.status == 409
            data = await resp.json()
            assert "failed validation" in data["error"]
            assert data["validation_failures"] == [
                "tool_validation: missing tool execute_command_tool"
            ]


class TestStop:
    @pytest.mark.asyncio
@@ -800,6 +928,19 @@ class TestStop:
            )
            assert resp.status == 400

    @pytest.mark.asyncio
    async def test_stop_ignores_worker_validation_failure(self):
        session = _make_session()
        session.worker_validation_failures = ["behavior_validation: broken"]
        session.worker_runtime._mock_streams["default"]._execution_tasks["exec_abc"] = MagicMock()
        app = _make_app_with_session(session)
        async with TestClient(TestServer(app)) as client:
            resp = await client.post(
                "/api/sessions/test_agent/stop",
                json={"execution_id": "exec_abc"},
            )
            assert resp.status == 200


class TestReplay:
    @pytest.mark.asyncio

@@ -36,6 +36,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
@@ -53,6 +54,7 @@ from framework.tools.flowchart_utils import (
    save_flowchart_file,
    synthesize_draft_from_runtime,
)
from framework.tools.worker_monitoring_tools import read_worker_health_snapshot

if TYPE_CHECKING:
    from framework.runner.tool_registry import ToolRegistry
@@ -61,6 +63,16 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

_NON_ACCEPT_JUDGE_ACTIONS = frozenset({"RETRY", "CONTINUE", "ESCALATE"})
_HEALTH_SIGNAL_DESCRIPTIONS: dict[str, str] = {
    "failed_session": "worker session is marked failed",
    "stalled": "worker appears stalled with no meaningful progress for 5+ minutes",
    "slow_progress": "worker progress has slowed for 2+ minutes without completing",
    "long_non_accept_streak": "worker has a sustained non-ACCEPT judge streak",
    "judge_pressure": "worker is under repeated non-ACCEPT judge pressure",
    "recent_non_accept_churn": "recent judge verdicts are all non-ACCEPT, indicating churn",
}


@dataclass
class WorkerSessionAdapter:
@@ -118,8 +130,6 @@ class QueenPhaseState:

    # Default skill operational protocols — appended to every phase prompt
    protocols_prompt: str = ""
    # Community skills catalog (XML) — appended after protocols
    skills_catalog_prompt: str = ""

    def get_current_tools(self) -> list:
        """Return tools for the current phase."""
@@ -146,8 +156,6 @@ class QueenPhaseState:

        memory = format_for_injection()
        parts = [base]
        if self.skills_catalog_prompt:
            parts.append(self.skills_catalog_prompt)
        if self.protocols_prompt:
            parts.append(self.protocols_prompt)
        if memory:
@@ -750,6 +758,438 @@ def _update_meta_json(session_manager, manager_session_id, updates: dict) -> Non
        pass


def _parse_validation_report(raw: Any) -> dict | None:
    """Best-effort parse of validate_agent_package output."""
    if isinstance(raw, dict):
        return raw
    if hasattr(raw, "content"):
        raw = raw.content
    if raw is None:
        return None
    text = str(raw).strip()
    if not text:
        return None
    candidates = [text]
    if "\n\n[Saved to" in text:
        candidates.append(text.split("\n\n[Saved to", 1)[0].strip())
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        candidates.append(text[start : end + 1])
    try:
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except (TypeError, json.JSONDecodeError):
                continue
    except TypeError:
        pass
    return None


def _validation_failures(report: dict | None) -> list[str]:
    """Flatten failed validation steps into readable messages."""
    if not report:
        return []
    failures: list[str] = []
    for step_name, step in (report.get("steps") or {}).items():
        if step.get("passed"):
            continue
        detail = step.get("output") or step.get("error") or "failed"
        failures.append(f"{step_name}: {detail}")
    return failures


def _validation_blocks_stage_or_run(report: dict | None) -> bool:
    """Return True when validation results should block staging or execution."""
    if not report:
        return False
    return any(
        isinstance(step, dict) and not step.get("passed", False)
        for step in (report.get("steps") or {}).values()
    )


def _invalid_validation_report(reason: str) -> dict:
    """Build a structured validation failure when validator output is unusable."""
    return {
        "valid": False,
        "summary": reason,
        "steps": {
            "validator_subprocess": {
                "passed": False,
                "error": reason,
            }
        },
    }


_STRUCTURED_TASK_PAIR_RE = re.compile(
    r"\[?(?P<key>[A-Za-z_][A-Za-z0-9_]*)\]?\s*(?::|=)\s*(?P<value>.*?)(?=(?:\s+\[?[A-Za-z_][A-Za-z0-9_]*\]?\s*(?::|=))|$)"
)
_STRUCTURED_TASK_LINE_RE = re.compile(
    r"^\s*(?:[-*]\s*)?\[?(?P<key>[A-Za-z_][A-Za-z0-9_]*)\]?\s*(?::|=)\s*(?P<value>.*?)\s*$"
)
_NUMERIC_WITH_SUFFIX_RE = re.compile(r"^(?P<number>-?\d+(?:\.\d+)?)\s*\([^)]*\)\s*$")
_LEADING_NUMERIC_RE = re.compile(r"^\s*(?P<number>-?\d+(?:\.\d+)?)\b")
_RERUN_WITH_DEFAULTS_RE = re.compile(
    r"\b(?:run\s+again|rerun|continue)\b.*\b(?:same\s+(?:default|defaults|settings|inputs)|"
    r"defaults?)\b",
    re.IGNORECASE,
)
_PATH_INPUT_KEY_HINTS = (
    "_dir",
    "_path",
    "_folder",
    "_root",
)
_NUMERIC_INPUT_KEY_HINTS = (
    "_threshold",
    "_count",
    "_limit",
    "_max",
    "_min",
    "_ratio",
    "_size",
)


def _coerce_task_value(raw: str) -> Any:
    """Best-effort coerce simple structured task values from text."""
    text = raw.strip().rstrip(",")
    if not text:
        return ""
    numeric_match = _NUMERIC_WITH_SUFFIX_RE.match(text)
    if numeric_match:
        text = numeric_match.group("number")
    try:
        return json.loads(text)
    except (TypeError, json.JSONDecodeError):
        return text

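What `_coerce_task_value` accepts, per the regexes above: JSON scalars parse through `json.loads`, a parenthesised hint after a bare number is stripped, and anything else stays text:

```python
assert _coerce_task_value("0.8 (recommended)") == 0.8
assert _coerce_task_value("42") == 42
assert _coerce_task_value("true") is True
assert _coerce_task_value("~/data/input") == "~/data/input"  # not JSON: kept as text
```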
def _parse_structured_task_payload(task: str) -> dict[str, Any]:
    """Extract ``key: value`` pairs or JSON objects from a task string."""
    text = (task or "").strip()
    if not text:
        return {}

    try:
        parsed = json.loads(text)
        if isinstance(parsed, dict):
            return parsed
    except (TypeError, json.JSONDecodeError):
        pass

    payload: dict[str, Any] = {}
    current_key: str | None = None
    for raw_line in text.splitlines():
        if not raw_line.strip():
            current_key = None
            continue
        inline_matches = list(_STRUCTURED_TASK_PAIR_RE.finditer(raw_line.strip()))
        if len(inline_matches) > 1:
            for match in inline_matches:
                payload[match.group("key")] = _coerce_task_value(match.group("value"))
            current_key = inline_matches[-1].group("key")
            continue
        line_match = _STRUCTURED_TASK_LINE_RE.match(raw_line)
        if line_match:
            key = line_match.group("key")
            value_text = line_match.group("value")
            payload[key] = _coerce_task_value(value_text)
            current_key = key
            continue
        if current_key and (raw_line.startswith(" ") or raw_line.startswith("\t")):
            existing = payload.get(current_key, "")
            continuation = raw_line.strip()
            if isinstance(existing, str):
                payload[current_key] = f"{existing}\n{continuation}".strip()
            continue
        current_key = None

    if payload:
        return payload

    matches = list(_STRUCTURED_TASK_PAIR_RE.finditer(text))
    for match in matches:
        key = match.group("key")
        value_text = match.group("value")
        if not value_text.strip():
            continue
        payload[key] = _coerce_task_value(value_text)
    return payload

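A line-oriented example of what `_parse_structured_task_payload` extracts (key names invented for illustration):

```python
task = """
photos_dir: ~/Pictures/inbox
blur_threshold: 0.8 (recommended)
notes: keep RAW files
"""
assert _parse_structured_task_payload(task) == {
    "photos_dir": "~/Pictures/inbox",
    "blur_threshold": 0.8,
    "notes": "keep RAW files",
}
```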
def _looks_like_path_input_key(key: str) -> bool:
    lowered = key.lower()
    return lowered.endswith(_PATH_INPUT_KEY_HINTS) or lowered in {
        "path",
        "dir",
        "folder",
        "root",
    }


def _looks_like_numeric_input_key(key: str) -> bool:
    lowered = key.lower()
    return lowered.endswith(_NUMERIC_INPUT_KEY_HINTS) or lowered in {
        "threshold",
        "count",
        "limit",
        "max",
        "min",
        "ratio",
        "size",
    }


def _normalize_worker_input_value(key: str, value: Any) -> Any:
    """Normalize structured worker inputs before handing them to the runtime."""
    if not isinstance(value, str):
        return value

    text = value.strip()
    if not text:
        return ""

    if _looks_like_numeric_input_key(key):
        numeric_match = _LEADING_NUMERIC_RE.match(text)
        if numeric_match:
            number = numeric_match.group("number")
            return float(number) if "." in number else int(number)

    if _looks_like_path_input_key(key):
        candidate = Path(text).expanduser()
        if not candidate.is_absolute():
            candidate = (Path.cwd() / candidate).resolve()
        return str(candidate)

    return text

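Behaviour of `_normalize_worker_input_value` on the two hinted key families (key and path values invented; relative paths resolve against the current working directory, `~` expands to the home directory):

```python
assert _normalize_worker_input_value("blur_threshold", "0.8 (recommended)") == 0.8
assert _normalize_worker_input_value("retry_count", "3 attempts") == 3

normalized = _normalize_worker_input_value("photos_dir", "~/Pictures/inbox")
assert normalized.endswith("Pictures/inbox") and "~" not in normalized
```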
def _should_backfill_from_recent_input(key: str, value: Any) -> bool:
|
||||
"""Return True when a recent session input should replace the current value."""
|
||||
if value is None:
|
||||
return True
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return True
|
||||
if _looks_like_numeric_input_key(key):
|
||||
return _LEADING_NUMERIC_RE.fullmatch(text) is None
|
||||
return False
|
||||
|
||||
|
||||
def _load_recent_worker_input_defaults(
    runtime: Any,
    input_keys: list[str],
    session_id: str | None = None,
) -> dict[str, Any]:
    """Load the best recent worker input payload from unified session state.

    When a current session is known, only that session is considered so reruns
    cannot inherit structured defaults from an unrelated historical session.
    Otherwise we fall back to the latest available session as a best-effort
    compatibility path for legacy callers.
    """
    store = getattr(runtime, "_session_store", None)
    sessions_dir = Path(getattr(store, "sessions_dir", "")) if store is not None else None
    if sessions_dir is None or not sessions_dir.exists():
        return {}
    allowed_key_set = set(input_keys)

    candidate_state_paths: list[Path]
    if session_id:
        state_path = sessions_dir / session_id / "state.json"
        candidate_state_paths = [state_path] if state_path.exists() else []
    else:
        candidate_state_paths = []

    if not candidate_state_paths:
        candidate_state_paths = sorted(
            sessions_dir.glob("session_*/state.json"),
            key=lambda path: path.stat().st_mtime if path.exists() else 0,
            reverse=True,
        )

    work_keys = [key for key in input_keys if key not in {"user_request", "task", "feedback"}]
    if not work_keys:
        best_payload: dict[str, Any] = {}
        best_updated_at = ""
        for state_path in candidate_state_paths:
            try:
                raw_state = json.loads(state_path.read_text(encoding="utf-8"))
            except (OSError, json.JSONDecodeError):
                continue
            input_data = raw_state.get("input_data") or {}
            if not isinstance(input_data, dict) or not input_data:
                continue
            input_data = {key: value for key, value in input_data.items() if key in allowed_key_set}
            if not input_data:
                continue
            updated_at = str((raw_state.get("timestamps") or {}).get("updated_at") or "")
            if updated_at >= best_updated_at:
                best_payload = dict(input_data)
                best_updated_at = updated_at
        return best_payload

    best_payload: dict[str, Any] = {}
    best_score = -1
    best_updated_at = ""

    for state_path in candidate_state_paths:
        try:
            raw_state = json.loads(state_path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            continue

        input_data = raw_state.get("input_data") or {}
        if not isinstance(input_data, dict):
            continue

        merged_input = {key: value for key, value in input_data.items() if key in allowed_key_set}
        result_output = (raw_state.get("result") or {}).get("output") or {}
        if isinstance(result_output, dict):
            for key in work_keys:
                if merged_input.get(key) in (None, "") and result_output.get(key) not in (None, ""):
                    merged_input[key] = result_output[key]

        score = sum(1 for key in work_keys if merged_input.get(key) not in (None, ""))
        if score <= 0:
            continue

        updated_at = str((raw_state.get("timestamps") or {}).get("updated_at") or "")
        if score > best_score or (score == best_score and updated_at > best_updated_at):
            best_payload = merged_input
            best_score = score
            best_updated_at = updated_at

    return best_payload
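For reference, a hypothetical state.json payload this loader would score highly. Only the keys it actually reads (input_data, result.output, timestamps.updated_at) are shown, and the values are invented:

example_state = {
    "input_data": {"target_dir": "/tmp/project", "word_threshold": 800},
    "result": {"output": {"review_dir": "/tmp/project/reviews"}},
    "timestamps": {"updated_at": "2025-01-10T12:00:00+00:00"},
}
# With input_keys=["target_dir", "review_dir", "word_threshold"], all three
# work keys resolve (review_dir is filled from result.output), so this
# session scores 3 and beats any partially filled candidate.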
async def _preflight_worker_run(session: Any, runtime: Any, timeout_seconds: int) -> None:
    """Validate credentials and refresh MCP servers before a worker run."""
    loop = asyncio.get_running_loop()

    async def _preflight():
        cred_error: CredentialError | None = None
        try:
            await loop.run_in_executor(
                None,
                lambda: validate_credentials(
                    runtime.graph.nodes,
                    interactive=False,
                    skip=False,
                ),
            )
        except CredentialError as e:
            cred_error = e

        runner = getattr(session, "runner", None)
        if runner:
            try:
                await loop.run_in_executor(
                    None,
                    lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
                )
            except Exception as e:
                logger.warning("MCP resync failed: %s", e)

        if cred_error is not None:
            raise cred_error

    try:
        await asyncio.wait_for(_preflight(), timeout=timeout_seconds)
    except TimeoutError:
        logger.warning(
            "worker run preflight timed out after %ds — proceeding",
            timeout_seconds,
        )
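The shape here stands alone: blocking checks pushed into the default thread-pool executor, the whole preflight capped by wait_for, and timeouts downgraded to a warning. A self-contained sketch with a stand-in blocking check (blocking_check and preflight_with_budget are invented names):

import asyncio
import time


def blocking_check() -> None:
    time.sleep(0.1)  # stands in for an HTTP health-check or subprocess spawn


async def preflight_with_budget(timeout_seconds: float) -> bool:
    loop = asyncio.get_running_loop()

    async def _preflight() -> None:
        await loop.run_in_executor(None, blocking_check)

    try:
        await asyncio.wait_for(_preflight(), timeout=timeout_seconds)
        return True
    except TimeoutError:
        return False  # warn-and-proceed, mirroring the helper above


print(asyncio.run(preflight_with_budget(1.0)))  # True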
def _get_default_entry_input_keys(runtime: Any) -> list[str]:
    """Return the loaded worker's default entry node input keys, if available."""
    try:
        entry_points = runtime.get_entry_points()
    except Exception:
        return []
    if not entry_points:
        return []

    graph = getattr(runtime, "graph", None)
    if graph is None or not hasattr(graph, "get_node"):
        return []

    entry_spec = entry_points[0]
    entry_node_id = getattr(entry_spec, "entry_node", None) or getattr(graph, "entry_node", None)
    if not entry_node_id:
        return []

    node = graph.get_node(entry_node_id)
    return list(getattr(node, "input_keys", []) or []) if node is not None else []


def _build_worker_input_data(
    runtime: Any,
    task: str,
    session_id: str | None = None,
) -> dict[str, Any]:
    """Shape queen task text into the loaded worker's expected entry inputs."""
    structured = _parse_structured_task_payload(task)
    allowed_keys = _get_default_entry_input_keys(runtime)

    # Backwards compatibility for older workers that still expect a single
    # free-form task string, while allowing newer workers to receive
    # structured fields directly.
    if not allowed_keys:
        payload = {"user_request": task, "task": task}
        payload.update(structured)
        return payload

    shaped: dict[str, Any] = {}
    if "user_request" in allowed_keys:
        shaped["user_request"] = task
    if "task" in allowed_keys:
        shaped["task"] = task

    work_keys = [key for key in allowed_keys if key not in {"user_request", "task", "feedback"}]
    structured_work_keys = {key for key in work_keys if key in structured}
    should_merge_recent_defaults = bool(structured_work_keys) or bool(
        _RERUN_WITH_DEFAULTS_RE.search(task or "")
    )
    recent_defaults = (
        _load_recent_worker_input_defaults(runtime, allowed_keys, session_id=session_id)
        if should_merge_recent_defaults
        else {}
    )

    for key in allowed_keys:
        if key in {"user_request", "task", "feedback"}:
            continue
        if key in structured:
            shaped[key] = _normalize_worker_input_value(key, structured[key])
            if (
                key in recent_defaults
                and _should_backfill_from_recent_input(key, shaped[key])
                and recent_defaults.get(key) not in (None, "")
            ):
                shaped[key] = _normalize_worker_input_value(key, recent_defaults[key])
        elif recent_defaults.get(key) not in (None, ""):
            shaped[key] = _normalize_worker_input_value(key, recent_defaults[key])

    work_keys = [key for key in allowed_keys if key != "feedback"]
    if not any(key in shaped for key in work_keys):
        if len(work_keys) == 1:
            shaped[work_keys[0]] = task
        elif "user_request" in allowed_keys and "user_request" not in shaped:
            shaped["user_request"] = task
        elif "task" in allowed_keys and "task" not in shaped:
            shaped["task"] = task

    return shaped
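A conceptual sketch of the shaping. The structured-task syntax that _parse_structured_task_payload accepts is defined elsewhere, so the parsed dict is stated rather than derived, and the key names are illustrative:

# Suppose the entry node declares
#     input_keys = ["user_request", "target_dir", "word_threshold"]
# and the queen's task text parsed into
#     structured = {"target_dir": "docs", "word_threshold": "800"}.
# _build_worker_input_data would then yield roughly:
shaped = {
    "user_request": "Review docs with an 800-word threshold",   # raw task text
    "target_dir": str((Path.cwd() / "docs").resolve()),         # path-normalized
    "word_threshold": 800,                                      # numeric-normalized
}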
def register_queen_lifecycle_tools(
    registry: ToolRegistry,
    session: Any = None,
@@ -803,6 +1243,23 @@ def register_queen_lifecycle_tools(
        """Get current worker runtime from session (late-binding)."""
        return getattr(session, "worker_runtime", None)

    async def _run_package_validation(agent_ref: str) -> dict | None:
        """Run validate_agent_package if available in the registry."""
        validator = registry._tools.get("validate_agent_package")
        if validator is None or not agent_ref:
            return None
        # The validator accepts either a built-agent package name or a
        # fully resolved allowed agent path.
        result = validator.executor({"agent_name": agent_ref})
        if asyncio.iscoroutine(result) or asyncio.isfuture(result):
            result = await result
        parsed = _parse_validation_report(result)
        if parsed is None:
            return _invalid_validation_report(
                "validate_agent_package returned an invalid or undecodable report"
            )
        return parsed
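The sync-or-async executor call in the middle is a reusable pattern: a tool executor may return a plain value or an awaitable, so the caller awaits only when needed. A minimal standalone version (call_executor is an invented name):

import asyncio
from typing import Any


async def call_executor(executor, payload: dict[str, Any]) -> Any:
    result = executor(payload)
    if asyncio.iscoroutine(result) or asyncio.isfuture(result):
        result = await result
    return result


async def demo() -> None:
    print(await call_executor(lambda p: p["x"] + 1, {"x": 1}))  # 2

    async def async_executor(p):
        return p["x"] * 2

    print(await call_executor(async_executor, {"x": 2}))  # 4


asyncio.run(demo())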
    # --- start_worker ---------------------------------------------------------

    # How long to wait for credential validation + MCP resync before
@@ -821,66 +1278,17 @@ def register_queen_lifecycle_tools(
            return json.dumps({"error": "No worker loaded in this session."})

        try:
            # Pre-flight: validate credentials and resync MCP servers.
            # Both are blocking I/O (HTTP health-checks, subprocess spawns)
            # so they run in a thread-pool executor. We cap the total
            # preflight time so the queen never hangs waiting.
            loop = asyncio.get_running_loop()

            async def _preflight():
                cred_error: CredentialError | None = None
                try:
                    await loop.run_in_executor(
                        None,
                        lambda: validate_credentials(
                            runtime.graph.nodes,
                            interactive=False,
                            skip=False,
                        ),
                    )
                except CredentialError as e:
                    cred_error = e

                runner = getattr(session, "runner", None)
                if runner:
                    try:
                        await loop.run_in_executor(
                            None,
                            lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
                        )
                    except Exception as e:
                        logger.warning("MCP resync failed: %s", e)

                # Re-raise CredentialError after MCP resync so both steps
                # get a chance to run before we bail.
                if cred_error is not None:
                    raise cred_error

            try:
                await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT)
            except TimeoutError:
                logger.warning(
                    "start_worker preflight timed out after %ds — proceeding with trigger",
                    _START_PREFLIGHT_TIMEOUT,
                )
            except CredentialError:
                raise  # handled below
            await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT)

            # Resume timers in case they were paused by a previous stop_worker
            runtime.resume_timers()

            # Get session state from any prior execution for memory continuity
            session_state = runtime._get_primary_session_state("default") or {}

            # Use the shared session ID so queen, judge, and worker all
            # scope their conversations to the same session.
            if session_id:
                session_state["resume_session_id"] = session_id

            exec_id = await runtime.trigger(
                entry_point_id="default",
                input_data={"user_request": task},
                session_state=session_state,
                input_data=_build_worker_input_data(runtime, task, session_id=session_id),
                # Worker runs should start from the explicit input payload for
                # this run, not inherit another execution's shared session.
                session_state=None,
            )
            return json.dumps(
                {
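Note on the hunk above: the free-form payload ({"user_request": task}) and the shared session_state are the lines being removed; the shaped payload and session_state=None are their replacements. The post-change call in isolation:

# Sketch; runtime, task, and session_id come from the enclosing tool.
exec_id = await runtime.trigger(
    entry_point_id="default",
    input_data=_build_worker_input_data(runtime, task, session_id=session_id),
    # Start from this run's explicit payload, not a shared session.
    session_state=None,
)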
@@ -2547,19 +2955,76 @@ def register_queen_lifecycle_tools(

        return preamble

    def _detect_red_flags(bus: EventBus) -> int:
    def _get_worker_health_snapshot() -> dict[str, Any] | None:
        worker_path = getattr(session, "worker_path", None)
        if not worker_path:
            return None
        try:
            snapshot = read_worker_health_snapshot(
                Path(worker_path),
                session_id=session_id,
                default_session_id=session_id,
            )
        except Exception:
            logger.exception("Failed to read worker health snapshot for queen status")
            return None
        if snapshot.get("error"):
            return None
        return snapshot

    def _detect_red_flags(bus: EventBus, health_snapshot: dict[str, Any] | None = None) -> int:
        """Count issue categories with cheap limit=1 queries."""
        if health_snapshot:
            issue_signals = health_snapshot.get("issue_signals", [])
            if isinstance(issue_signals, list) and issue_signals:
                return len(issue_signals)

        count = 0
        for evt_type in (
            EventType.NODE_RETRY,
            EventType.NODE_STALLED,
            EventType.NODE_TOOL_DOOM_LOOP,
            EventType.CONSTRAINT_VIOLATION,
        ):
            if bus.get_history(event_type=evt_type, limit=1):
                count += 1
        if _get_recent_judge_pressure(bus)[0]:
            count += 1
        return count

    def _format_summary(preamble: dict[str, Any], red_flags: int) -> str:
    def _get_recent_judge_pressure(bus: EventBus, streak_threshold: int = 4) -> tuple[bool, str]:
        """Detect sustained judge churn even when no hard stall event exists yet."""
        verdict_events = bus.get_history(event_type=EventType.JUDGE_VERDICT, limit=8)
        if len(verdict_events) < streak_threshold:
            return False, ""

        streak: list[str] = []
        for evt in verdict_events:
            action = str(evt.data.get("action", "")).upper()
            if action == "ACCEPT":
                break
            if action in _NON_ACCEPT_JUDGE_ACTIONS:
                streak.append(action)
                continue
            break

        if len(streak) < streak_threshold:
            return False, ""

        compressed: list[str] = []
        for action in streak:
            if not compressed or compressed[-1] != action:
                compressed.append(action)
        return (
            True,
            f"{len(streak)} consecutive non-ACCEPT judge verdict(s): {' -> '.join(compressed)}",
        )

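A worked trace of the streak logic with plain data in place of the event bus (verdicts listed in the order the detector walks the get_history results, assumed newest first):

verdicts = ["RETRY", "RETRY", "CONTINUE", "RETRY", "ACCEPT", "RETRY"]

streak = []
for action in verdicts:
    if action == "ACCEPT":
        break
    streak.append(action)  # assumes these all appear in _NON_ACCEPT_JUDGE_ACTIONS

compressed = []
for action in streak:
    if not compressed or compressed[-1] != action:
        compressed.append(action)

# Four consecutive non-ACCEPT verdicts meet the default threshold, so this
# reports "4 consecutive non-ACCEPT judge verdict(s): RETRY -> CONTINUE -> RETRY".
print(len(streak), " -> ".join(compressed))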
    def _format_summary(
        preamble: dict[str, Any],
        red_flags: int,
        health_snapshot: dict[str, Any] | None = None,
    ) -> str:
        """Generate a 1-2 sentence prose summary from the preamble."""
        status = preamble["status"]

@@ -2586,10 +3051,17 @@ def register_queen_lifecycle_tools(
                node_part += f", iteration {iteration}"
            parts.append(node_part)

        health_signals = health_snapshot.get("issue_signals", []) if health_snapshot else []
        if red_flags:
            parts.append(f"{red_flags} issue type(s) detected — use focus='issues' for details")
            if isinstance(health_signals, list) and health_signals:
                parts.append(
                    f"{red_flags} issue signal(s) detected "
                    f"({', '.join(health_signals)}) — use focus='issues' for details"
                )
            else:
                parts.append(f"{red_flags} issue type(s) detected — use focus='issues' for details")
        else:
            parts.append("No issues detected")
            parts.append("No issue signals detected")

        # Latest subagent progress (if any delegation is in flight)
        bus = _get_event_bus()
@@ -2737,7 +3209,7 @@ def register_queen_lifecycle_tools(

        return "\n".join(lines)

    def _format_issues(bus: EventBus) -> str:
    def _format_issues(bus: EventBus, health_snapshot: dict[str, Any] | None = None) -> str:
        """Format retries, stalls, doom loops, and constraint violations."""
        lines = []
        total = 0
@@ -2787,8 +3259,42 @@ def register_queen_lifecycle_tools(
                ago = _format_time_ago(evt.timestamp)
                lines.append(f" {cid} ({ago}): {desc}")

        has_judge_pressure, judge_pressure_desc = _get_recent_judge_pressure(bus)
        if has_judge_pressure:
            total += 1
            lines.append("Judge pressure detected:")
            lines.append(f" {judge_pressure_desc}")

        if health_snapshot:
            issue_signals = health_snapshot.get("issue_signals", [])
            if isinstance(issue_signals, list) and issue_signals:
                total += len(issue_signals)
                lines.append("Health signals:")
                for signal in issue_signals:
                    desc = _HEALTH_SIGNAL_DESCRIPTIONS.get(signal, signal.replace("_", " "))
                    if (
                        signal in {"stalled", "slow_progress"}
                        and health_snapshot.get("stall_minutes") is not None
                    ):
                        desc += f" ({health_snapshot['stall_minutes']} min since last step)"
                    elif (
                        signal in {"long_non_accept_streak", "judge_pressure"}
                        and health_snapshot.get("steps_since_last_accept") is not None
                    ):
                        desc += (
                            " ("
                            f"{health_snapshot['steps_since_last_accept']} non-ACCEPT step(s)"
                            " since last ACCEPT)"
                        )
                    elif signal == "recent_non_accept_churn" and health_snapshot.get(
                        "recent_verdicts"
                    ):
                        verdicts = ", ".join(health_snapshot["recent_verdicts"][-4:])
                        desc += f" ({verdicts})"
                    lines.append(f" {signal}: {desc}")

        if total == 0:
            return "No issues detected. No retries, stalls, or constraint violations."
            return "No issues detected. No runtime issue signals were found."

        header = f"{total} issue(s) detected."
        return header + "\n\n" + "\n".join(lines)
@@ -3086,8 +3592,9 @@ def register_queen_lifecycle_tools(
        try:
            if focus is None:
                # Default: brief prose summary
                red_flags = _detect_red_flags(bus) if bus else 0
                return _format_summary(preamble, red_flags)
                health_snapshot = _get_worker_health_snapshot()
                red_flags = _detect_red_flags(bus, health_snapshot) if bus else 0
                return _format_summary(preamble, red_flags, health_snapshot)

            if bus is None:
                return (
@@ -3102,7 +3609,7 @@ def register_queen_lifecycle_tools(
            elif focus == "tools":
                return _format_tools(bus, last_n)
            elif focus == "issues":
                return _format_issues(bus)
                return _format_issues(bus, _get_worker_health_snapshot())
            elif focus == "progress":
                return await _format_progress(runtime, bus)
            elif focus == "full":
@@ -3399,14 +3906,6 @@ def register_queen_lifecycle_tools(
        available immediately. The user will see the agent's graph and
        can interact with it without opening a new tab.
        """
        runtime = _get_runtime()
        if runtime is not None:
            try:
                await session_manager.unload_worker(manager_session_id)
            except Exception as e:
                logger.error("Failed to unload existing worker: %s", e, exc_info=True)
                return json.dumps({"error": f"Failed to unload existing worker: {e}"})

        try:
            resolved_path = validate_agent_path(agent_path)
        except ValueError as e:
@@ -3414,6 +3913,19 @@ def register_queen_lifecycle_tools(
        if not resolved_path.exists():
            return json.dumps({"error": f"Agent path does not exist: {agent_path}"})

        validation_report = await _run_package_validation(str(resolved_path))
        if _validation_blocks_stage_or_run(validation_report):
            failures = _validation_failures(validation_report)
            return json.dumps(
                {
                    "error": (
                        f"Cannot load agent '{resolved_path.name}' because validation failed. "
                        "Fix the package and re-run validate_agent_package() before loading."
                    ),
                    "validation_failures": failures,
                }
            )

        # Pre-check: verify the module exports goal/nodes/edges before
        # attempting the full load. This gives the queen an actionable
        # error message instead of a cryptic ImportError or TypeError.
@@ -3459,6 +3971,14 @@ def register_queen_lifecycle_tools(
                }
            )

        runtime = _get_runtime()
        if runtime is not None:
            try:
                await session_manager.unload_worker(manager_session_id)
            except Exception as e:
                logger.error("Failed to unload existing worker: %s", e, exc_info=True)
                return json.dumps({"error": f"Failed to unload existing worker: {e}"})

        try:
            updated_session = await session_manager.load_worker(
                manager_session_id,
@@ -3607,60 +4127,35 @@ def register_queen_lifecycle_tools(
        if runtime is None:
            return json.dumps({"error": "No worker loaded in this session."})

        worker_path = getattr(session, "worker_path", None)
        worker_name = Path(worker_path).name if worker_path else ""
        validation_report = await _run_package_validation(
            str(worker_path) if worker_path else worker_name
        )
        if _validation_blocks_stage_or_run(validation_report):
            failures = _validation_failures(validation_report)
            return json.dumps(
                {
                    "error": (
                        f"Cannot run agent '{worker_name or 'current worker'}' because validation "
                        "is failing. Fix the package and reload it before running."
                    ),
                    "validation_failures": failures,
                }
            )

        try:
            # Pre-flight: validate credentials and resync MCP servers.
            loop = asyncio.get_running_loop()

            async def _preflight():
                cred_error: CredentialError | None = None
                try:
                    await loop.run_in_executor(
                        None,
                        lambda: validate_credentials(
                            runtime.graph.nodes,
                            interactive=False,
                            skip=False,
                        ),
                    )
                except CredentialError as e:
                    cred_error = e

                runner = getattr(session, "runner", None)
                if runner:
                    try:
                        await loop.run_in_executor(
                            None,
                            lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
                        )
                    except Exception as e:
                        logger.warning("MCP resync failed: %s", e)

                if cred_error is not None:
                    raise cred_error

            try:
                await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT)
            except TimeoutError:
                logger.warning(
                    "run_agent_with_input preflight timed out after %ds — proceeding",
                    _START_PREFLIGHT_TIMEOUT,
                )
            except CredentialError:
                raise  # handled below
            await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT)

            # Resume timers in case they were paused by a previous stop
            runtime.resume_timers()

            # Get session state from any prior execution for memory continuity
            session_state = runtime._get_primary_session_state("default") or {}

            if session_id:
                session_state["resume_session_id"] = session_id

            exec_id = await runtime.trigger(
                entry_point_id="default",
                input_data={"user_request": task},
                session_state=session_state,
                input_data=_build_worker_input_data(runtime, task, session_id=session_id),
                # Fresh manual worker runs avoid stale state leaking from a
                # previous execution into Codex's next tool/planning turn.
                session_state=None,
            )

            # Switch to running phase
@@ -3693,6 +4188,91 @@ def register_queen_lifecycle_tools(
        except Exception as e:
            return json.dumps({"error": f"Failed to start worker: {e}"})

    async def rerun_worker_with_last_input() -> str:
        """Rerun the loaded worker using the last complete structured input payload."""
        runtime = _get_runtime()
        if runtime is None:
            return json.dumps({"error": "No worker loaded in this session."})

        worker_path = getattr(session, "worker_path", None)
        worker_name = Path(worker_path).name if worker_path else ""
        validation_report = await _run_package_validation(
            str(worker_path) if worker_path else worker_name
        )
        if _validation_blocks_stage_or_run(validation_report):
            failures = _validation_failures(validation_report)
            return json.dumps(
                {
                    "error": (
                        f"Cannot rerun agent '{worker_name or 'current worker'}' "
                        "because validation "
                        "is failing. Fix the package and reload it before running."
                    ),
                    "validation_failures": failures,
                }
            )

        allowed_keys = _get_default_entry_input_keys(runtime)
        input_data = {
            key: _normalize_worker_input_value(key, value)
            for key, value in _load_recent_worker_input_defaults(
                runtime,
                allowed_keys,
                session_id=session_id,
            ).items()
        }
        work_keys = [key for key in allowed_keys if key not in {"user_request", "task", "feedback"}]
        if work_keys:
            missing = [key for key in work_keys if input_data.get(key) in (None, "")]
            if missing:
                return json.dumps(
                    {
                        "error": "No complete previous worker input is available for a "
                        "same-defaults rerun.",
                        "missing_inputs": missing,
                    }
                )

        try:
            await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT)

            runtime.resume_timers()

            exec_id = await runtime.trigger(
                entry_point_id="default",
                input_data=input_data,
                session_state=None,
            )

            if phase_state is not None:
                await phase_state.switch_to_running()
            _update_meta_json(session_manager, manager_session_id, {"phase": "running"})

            return json.dumps(
                {
                    "status": "started",
                    "phase": "running",
                    "execution_id": exec_id,
                    "input_data": input_data,
                }
            )
        except CredentialError as e:
            error_payload = credential_errors_to_json(e)
            error_payload["agent_path"] = str(getattr(session, "worker_path", "") or "")

            bus = getattr(session, "event_bus", None)
            if bus is not None:
                await bus.publish(
                    AgentEvent(
                        type=EventType.CREDENTIALS_REQUIRED,
                        stream_id="queen",
                        data=error_payload,
                    )
                )
            return json.dumps(error_payload)
        except Exception as e:
            return json.dumps({"error": f"Failed to rerun worker: {e}"})

    _run_input_tool = Tool(
        name="run_agent_with_input",
        description=(
@@ -3716,6 +4296,21 @@ def register_queen_lifecycle_tools(
    )
    tools_registered += 1

    _rerun_tool = Tool(
        name="rerun_worker_with_last_input",
        description=(
            "Rerun the loaded worker using the most recent complete structured input payload. "
            "Use this when the user asks to run again with the same defaults or same input."
        ),
        parameters={"type": "object", "properties": {}},
    )
    registry.register(
        "rerun_worker_with_last_input",
        _rerun_tool,
        lambda _inputs: rerun_worker_with_last_input(),
    )
    tools_registered += 1

    # --- set_trigger -----------------------------------------------------------

    async def set_trigger(

@@ -36,6 +36,172 @@ logger = logging.getLogger(__name__)

# How many tool_log steps to include in the health summary
_DEFAULT_LAST_N_STEPS = 40
_NON_ACCEPT_VERDICTS = frozenset({"RETRY", "CONTINUE", "ESCALATE"})


def classify_worker_health(
    *,
    session_status: str,
    recent_verdicts: list[str],
    steps_since_last_accept: int,
    stall_minutes: float | None,
) -> tuple[str, list[str]]:
    """Classify worker health from persisted run evidence.

    Keeping this logic at module scope lets Queen-facing status views reuse the
    exact same health signals as the monitoring tool instead of drifting into a
    separate, weaker interpretation of worker state.
    """
    issue_signals: list[str] = []

    if session_status == "failed":
        issue_signals.append("failed_session")

    if stall_minutes is not None:
        if stall_minutes >= 5:
            issue_signals.append("stalled")
        elif stall_minutes >= 2:
            issue_signals.append("slow_progress")

    if steps_since_last_accept >= 6:
        issue_signals.append("long_non_accept_streak")
    elif steps_since_last_accept >= 4:
        issue_signals.append("judge_pressure")

    if len(recent_verdicts) >= 4 and all(v in _NON_ACCEPT_VERDICTS for v in recent_verdicts[-4:]):
        issue_signals.append("recent_non_accept_churn")

    issue_signals = list(dict.fromkeys(issue_signals))

    if any(sig in issue_signals for sig in ("failed_session", "stalled", "long_non_accept_streak")):
        return "critical", issue_signals
    if issue_signals:
        return "warning", issue_signals
    return "healthy", issue_signals
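Worked examples of the classifier; these follow directly from the thresholds above and can be checked by hand:

status, signals = classify_worker_health(
    session_status="in_progress",
    recent_verdicts=["ACCEPT", "RETRY", "RETRY", "CONTINUE", "RETRY"],
    steps_since_last_accept=4,
    stall_minutes=2.5,
)
# Four trailing non-ACCEPT verdicts, four steps since the last ACCEPT, and a
# 2.5-minute gap trip three warning-level signals but nothing critical:
assert status == "warning"
assert signals == ["slow_progress", "judge_pressure", "recent_non_accept_churn"]

status, signals = classify_worker_health(
    session_status="failed",
    recent_verdicts=[],
    steps_since_last_accept=0,
    stall_minutes=None,
)
assert (status, signals) == ("critical", ["failed_session"])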
def read_worker_health_snapshot(
    storage_path: Path,
    *,
    session_id: str | None = None,
    last_n_steps: int = _DEFAULT_LAST_N_STEPS,
    default_session_id: str | None = None,
    worker_agent_id: str | None = None,
    worker_graph_id: str | None = None,
) -> dict[str, object]:
    """Read persisted worker logs and return the structured health snapshot.

    This is the shared source of truth for worker-health reporting. The
    monitoring tool returns it as JSON, while Queen's user-facing summaries can
    consume the same dict directly to avoid underreporting issue signals.
    """
    storage_path = Path(storage_path)
    resolved_worker_agent_id = worker_agent_id or storage_path.name
    resolved_worker_graph_id = worker_graph_id or storage_path.name

    # Auto-discover the most recent session if not specified.
    if not session_id or session_id == "auto":
        sessions_dir = storage_path / "sessions"
        if not sessions_dir.exists():
            return {"error": "No sessions found — worker has not started yet"}

        if default_session_id and (sessions_dir / default_session_id).is_dir():
            session_id = default_session_id
        else:
            candidates = [
                d for d in sessions_dir.iterdir() if d.is_dir() and (d / "state.json").exists()
            ]
            if not candidates:
                return {"error": "No sessions found — worker has not started yet"}

            def _sort_key(d: Path):
                try:
                    state = json.loads((d / "state.json").read_text(encoding="utf-8"))
                    priority = 0 if state.get("status", "") in ("in_progress", "running") else 1
                    return (priority, -d.stat().st_mtime)
                except Exception:
                    return (2, 0)

            candidates.sort(key=_sort_key)
            session_id = candidates[0].name

    session_dir = storage_path / "sessions" / str(session_id)
    tool_logs_path = session_dir / "logs" / "tool_logs.jsonl"
    state_path = session_dir / "state.json"
    if not session_dir.exists() or not state_path.exists():
        return {"error": f"No persisted worker state found for session '{session_id}'"}

    session_status = "unknown"
    if state_path.exists():
        try:
            state = json.loads(state_path.read_text(encoding="utf-8"))
            session_status = state.get("status", "unknown")
        except Exception:
            pass

    steps: list[dict] = []
    if tool_logs_path.exists():
        try:
            with open(tool_logs_path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        try:
                            steps.append(json.loads(line))
                        except json.JSONDecodeError:
                            continue
        except OSError as e:
            return {"error": f"Could not read tool logs: {e}"}

    total_steps = len(steps)
    recent = steps[-last_n_steps:] if len(steps) > last_n_steps else steps
    recent_verdicts = [s.get("verdict", "") for s in recent if s.get("verdict")]

    steps_since_last_accept = 0
    for verdict in reversed(recent_verdicts):
        if verdict == "ACCEPT":
            break
        steps_since_last_accept += 1

    last_step_time_iso: str | None = None
    stall_minutes: float | None = None
    if steps and tool_logs_path.exists():
        try:
            mtime = tool_logs_path.stat().st_mtime
            last_step_time_iso = datetime.fromtimestamp(mtime, UTC).isoformat()
            elapsed = (datetime.now(UTC).timestamp() - mtime) / 60
            stall_minutes = round(elapsed, 1) if elapsed >= 1.0 else None
        except OSError:
            pass

    evidence_snippet = ""
    for step in reversed(recent):
        text = step.get("llm_text", "")
        if text:
            evidence_snippet = text[:500]
            break

    health_status, issue_signals = classify_worker_health(
        session_status=session_status,
        recent_verdicts=recent_verdicts,
        steps_since_last_accept=steps_since_last_accept,
        stall_minutes=stall_minutes,
    )

    return {
        "worker_agent_id": resolved_worker_agent_id,
        "worker_graph_id": resolved_worker_graph_id,
        "session_id": session_id,
        "session_status": session_status,
        "health_status": health_status,
        "issue_signals": issue_signals,
        "total_steps": total_steps,
        "recent_verdicts": recent_verdicts,
        "steps_since_last_accept": steps_since_last_accept,
        "last_step_time_iso": last_step_time_iso,
        "stall_minutes": stall_minutes,
        "evidence_snippet": evidence_snippet,
    }
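A usage sketch against a synthetic session tree, exercising the auto-discovery path. It assumes read_worker_health_snapshot is imported from this module; the directory names and log payloads are invented but match the layout the function reads:

import json
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp()) / "policy_diff_worker"
session = root / "sessions" / "session_001"
(session / "logs").mkdir(parents=True)
(session / "state.json").write_text(json.dumps({"status": "in_progress"}))
(session / "logs" / "tool_logs.jsonl").write_text(
    "\n".join(
        json.dumps({"verdict": v, "llm_text": f"step {i}"})
        for i, v in enumerate(["ACCEPT", "RETRY", "RETRY"])
    )
)

snapshot = read_worker_health_snapshot(root, session_id="auto")
print(snapshot["session_id"], snapshot["health_status"], snapshot["recent_verdicts"])
# -> session_001 healthy ['ACCEPT', 'RETRY', 'RETRY']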
def register_worker_monitoring_tools(
@@ -91,6 +257,8 @@ def register_worker_monitoring_tools(
        Returns a JSON object with:
        - session_id: the session inspected (useful when auto-discovered)
        - session_status: "running"|"completed"|"failed"|"in_progress"|"unknown"
        - health_status: "healthy"|"warning"|"critical"
        - issue_signals: list of detected warning/attention categories
        - total_steps: total number of log steps recorded so far
        - recent_verdicts: list of last N verdict strings (ACCEPT/RETRY/CONTINUE/ESCALATE)
        - steps_since_last_accept: consecutive non-ACCEPT steps from the end
@@ -98,120 +266,22 @@ def register_worker_monitoring_tools(
        - stall_minutes: wall-clock minutes since last step (null if < 1 min)
        - evidence_snippet: last LLM text from the most recent step (truncated)
        """
        # Auto-discover the most recent session if not specified
        if not session_id or session_id == "auto":
            sessions_dir = storage_path / "sessions"
            if not sessions_dir.exists():
                return json.dumps({"error": "No sessions found — worker has not started yet"})

            # Prefer the queen's own session ID (set at registration time) over
            # mtime-based discovery, which can pick a stale orphaned session after
            # a cold-restore when a newer-but-empty session directory exists.
            if default_session_id and (sessions_dir / default_session_id).is_dir():
                session_id = default_session_id
            else:
                candidates = [
                    d for d in sessions_dir.iterdir() if d.is_dir() and (d / "state.json").exists()
                ]
                if not candidates:
                    return json.dumps({"error": "No sessions found — worker has not started yet"})

                def _sort_key(d: Path):
                    try:
                        state = json.loads((d / "state.json").read_text(encoding="utf-8"))
                        # in_progress/running sorts before completed/failed
                        priority = 0 if state.get("status", "") in ("in_progress", "running") else 1
                        return (priority, -d.stat().st_mtime)
                    except Exception:
                        return (2, 0)

                candidates.sort(key=_sort_key)
                session_id = candidates[0].name

        # Resolve log paths
        session_dir = storage_path / "sessions" / session_id
        tool_logs_path = session_dir / "logs" / "tool_logs.jsonl"
        state_path = session_dir / "state.json"

        # Read session status
        session_status = "unknown"
        if state_path.exists():
            try:
                state = json.loads(state_path.read_text(encoding="utf-8"))
                session_status = state.get("status", "unknown")
            except Exception:
                pass

        # Read tool logs
        steps: list[dict] = []
        if tool_logs_path.exists():
            try:
                with open(tool_logs_path, encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            try:
                                steps.append(json.loads(line))
                            except json.JSONDecodeError:
                                continue
            except OSError as e:
                return json.dumps({"error": f"Could not read tool logs: {e}"})

        total_steps = len(steps)
        recent = steps[-last_n_steps:] if len(steps) > last_n_steps else steps

        # Extract verdict sequence
        recent_verdicts = [s.get("verdict", "") for s in recent if s.get("verdict")]

        # Count consecutive non-ACCEPT from the end
        steps_since_last_accept = 0
        for v in reversed(recent_verdicts):
            if v == "ACCEPT":
                break
            steps_since_last_accept += 1

        # Timing: use tool_logs file mtime as proxy for last step time
        last_step_time_iso: str | None = None
        stall_minutes: float | None = None
        if steps and tool_logs_path.exists():
            try:
                mtime = tool_logs_path.stat().st_mtime
                last_step_time_iso = datetime.fromtimestamp(mtime, UTC).isoformat()
                elapsed = (datetime.now(UTC).timestamp() - mtime) / 60
                stall_minutes = round(elapsed, 1) if elapsed >= 1.0 else None
            except OSError:
                pass

        # Evidence snippet: last LLM text
        evidence_snippet = ""
        for step in reversed(recent):
            text = step.get("llm_text", "")
            if text:
                evidence_snippet = text[:500]
                break

        return json.dumps(
            {
                "worker_agent_id": _worker_agent_id,
                "worker_graph_id": _worker_graph_id,
                "session_id": session_id,
                "session_status": session_status,
                "total_steps": total_steps,
                "recent_verdicts": recent_verdicts,
                "steps_since_last_accept": steps_since_last_accept,
                "last_step_time_iso": last_step_time_iso,
                "stall_minutes": stall_minutes,
                "evidence_snippet": evidence_snippet,
            },
            ensure_ascii=False,
        snapshot = read_worker_health_snapshot(
            storage_path,
            session_id=session_id,
            last_n_steps=last_n_steps,
            default_session_id=default_session_id,
            worker_agent_id=_worker_agent_id,
            worker_graph_id=_worker_graph_id,
        )
        return json.dumps(snapshot, ensure_ascii=False)

    _health_summary_tool = Tool(
        name="get_worker_health_summary",
        description=(
            "Read the worker agent's execution logs and return a compact health snapshot. "
            "Returns worker_agent_id and worker_graph_id (use these for ticket identity fields), "
            "recent verdicts, step count, time since last step, and "
            "health_status, issue_signals, recent verdicts, step count, time since last step, and "
            "a snippet of the most recent LLM output. "
            "session_id is optional — omit it to auto-discover the most recent active session."
        ),

@@ -0,0 +1,135 @@
import { describe, expect, it } from "vitest";

import type { NodeSpec } from "@/api/types";
import type { GraphNode } from "@/components/graph-types";

import {
  buildStructuredRunQuestions,
  canShowRunButton,
  getStructuredRunInputKeys,
  hasAllStructuredRunInputs,
  trimStructuredRunInputs,
} from "./run-inputs";

function makeNodeSpec(overrides: Partial<NodeSpec>): NodeSpec {
  return {
    id: "node-1",
    name: "Node 1",
    description: "",
    node_type: "event_loop",
    input_keys: [],
    output_keys: [],
    nullable_output_keys: [],
    tools: [],
    routes: {},
    max_retries: 0,
    max_node_visits: 0,
    client_facing: false,
    success_criteria: null,
    system_prompt: "",
    sub_agents: [],
    ...overrides,
  };
}

function makeGraphNode(overrides: Partial<GraphNode>): GraphNode {
  return {
    id: "node-1",
    label: "Node 1",
    status: "pending",
    ...overrides,
  };
}

describe("getStructuredRunInputKeys", () => {
  it("returns structured input keys from the first non-trigger graph node", () => {
    const nodeSpecs = [
      makeNodeSpec({
        id: "receive-runtime-inputs",
        input_keys: ["target_dir", "review_dir", "word_threshold"],
      }),
    ];
    const graphNodes = [
      makeGraphNode({ id: "__trigger_default", nodeType: "trigger" }),
      makeGraphNode({ id: "receive-runtime-inputs", nodeType: "execution" }),
    ];

    expect(getStructuredRunInputKeys(nodeSpecs, graphNodes)).toEqual([
      "target_dir",
      "review_dir",
      "word_threshold",
    ]);
  });

  it("filters out generic task-style entry keys", () => {
    const nodeSpecs = [
      makeNodeSpec({
        id: "entry",
        input_keys: ["user_request", "task", "feedback", "target_dir"],
      }),
    ];

    expect(getStructuredRunInputKeys(nodeSpecs, [])).toEqual(["target_dir"]);
  });
});

describe("hasAllStructuredRunInputs", () => {
  it("requires every structured key to be present and non-blank", () => {
    expect(
      hasAllStructuredRunInputs(["target_dir", "word_threshold"], {
        target_dir: "/tmp/project",
        word_threshold: "800",
      }),
    ).toBe(true);

    expect(
      hasAllStructuredRunInputs(["target_dir", "word_threshold"], {
        target_dir: " ",
        word_threshold: "800",
      }),
    ).toBe(false);

    expect(
      hasAllStructuredRunInputs(["target_dir", "word_threshold"], {
        target_dir: "/tmp/project",
      }),
    ).toBe(false);
  });
});

describe("buildStructuredRunQuestions", () => {
  it("creates free-text prompts for each required run input", () => {
    expect(buildStructuredRunQuestions(["target_dir", "review_dir"])).toEqual([
      { id: "target_dir", prompt: "Provide target_dir for this run." },
      { id: "review_dir", prompt: "Provide review_dir for this run." },
    ]);
  });
});

describe("canShowRunButton", () => {
  it("only exposes Run when a worker session is ready and staged/running", () => {
    expect(canShowRunButton("sess-1", true, "staging", true)).toBe(true);
    expect(canShowRunButton("sess-1", true, "running", true)).toBe(true);

    expect(canShowRunButton("sess-1", true, "planning", true)).toBe(false);
    expect(canShowRunButton("sess-1", true, "building", true)).toBe(false);
    expect(canShowRunButton("sess-1", false, "staging", true)).toBe(false);
    expect(canShowRunButton("sess-1", true, "staging", false)).toBe(false);
    expect(canShowRunButton(null, true, "staging", true)).toBe(false);
  });
});

describe("trimStructuredRunInputs", () => {
  it("drops stale keys that are no longer part of the current schema", () => {
    expect(
      trimStructuredRunInputs(["target_dir", "word_threshold"], {
        target_dir: "/tmp/project",
        word_threshold: 800,
        stale_key: "old",
      }),
    ).toEqual({
      target_dir: "/tmp/project",
      word_threshold: 800,
    });
  });
});
@@ -0,0 +1,56 @@
import type { NodeSpec } from "@/api/types";
import type { GraphNode } from "@/components/graph-types";

const GENERIC_ENTRY_KEYS = new Set(["task", "user_request", "feedback"]);
const RUNNABLE_PHASES = new Set(["staging", "running"]);

type QueenPhase = "planning" | "building" | "staging" | "running";

function isMeaningfulValue(value: unknown): boolean {
  if (typeof value === "string") return value.trim().length > 0;
  return value !== undefined && value !== null;
}

export function getStructuredRunInputKeys(
  nodeSpecs: NodeSpec[],
  graphNodes: GraphNode[],
): string[] {
  const entryNodeId =
    graphNodes.find((node) => node.nodeType !== "trigger")?.id ?? nodeSpecs[0]?.id;
  if (!entryNodeId) return [];

  const entrySpec = nodeSpecs.find((node) => node.id === entryNodeId) ?? nodeSpecs[0];
  return (entrySpec?.input_keys ?? []).filter((key) => !GENERIC_ENTRY_KEYS.has(key));
}

export function hasAllStructuredRunInputs(
  keys: string[],
  inputData: Record<string, unknown> | null | undefined,
): inputData is Record<string, unknown> {
  if (!inputData) return false;
  return keys.every((key) => isMeaningfulValue(inputData[key]));
}

export function buildStructuredRunQuestions(keys: string[]) {
  return keys.map((key) => ({
    id: key,
    prompt: `Provide ${key} for this run.`,
  }));
}

export function trimStructuredRunInputs(
  keys: string[],
  inputData: Record<string, unknown> | null | undefined,
): Record<string, unknown> {
  if (!inputData) return {};
  return Object.fromEntries(keys.flatMap((key) => (key in inputData ? [[key, inputData[key]]] : [])));
}

export function canShowRunButton(
  sessionId: string | null | undefined,
  ready: boolean | null | undefined,
  queenPhase: QueenPhase | null | undefined,
  topologyReady: boolean,
): boolean {
  return Boolean(sessionId && ready && topologyReady && queenPhase && RUNNABLE_PHASES.has(queenPhase));
}
@@ -18,6 +18,13 @@ import type { LiveSession, AgentEvent, DiscoverEntry, NodeSpec, DraftGraph as Dr
import { sseEventToChatMessage, formatAgentDisplayName } from "@/lib/chat-helpers";
import { topologyToGraphNodes } from "@/lib/graph-converter";
import { cronToLabel } from "@/lib/graphUtils";
import {
  buildStructuredRunQuestions,
  canShowRunButton,
  getStructuredRunInputKeys,
  hasAllStructuredRunInputs,
  trimStructuredRunInputs,
} from "@/lib/run-inputs";
import { ApiError } from "@/api/client";

const makeId = () => Math.random().toString(36).slice(2, 9);
@@ -351,7 +358,9 @@ interface AgentBackendState {
  /** Multiple questions from ask_user_multiple */
  pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
  /** Whether the pending question came from queen or worker */
  pendingQuestionSource: "queen" | "worker" | null;
  pendingQuestionSource: "queen" | "worker" | "run" | null;
  /** Last structured input payload successfully used to start the worker. */
  lastRunInputData: Record<string, unknown> | null;
  /** Per-node context window usage (from context_usage_updated events) */
  contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
  /** Whether the queen's LLM supports image content — false disables the attach button */
@@ -393,6 +402,7 @@ function defaultAgentState(): AgentBackendState {
    pendingOptions: null,
    pendingQuestions: null,
    pendingQuestionSource: null,
    lastRunInputData: null,
    contextUsage: {},
    queenSupportsImages: true,
  };
@@ -693,15 +703,71 @@ export default function Workspace() {
    }
  }, [sessionsByAgent, activeSessionByAgent, activeWorker, agentStates]);

  const appendSystemMessage = useCallback((agentType: string, content: string) => {
    setSessionsByAgent((prev) => {
      const sessions = prev[agentType] || [];
      const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
      return {
        ...prev,
        [agentType]: sessions.map((s) => {
          if (s.id !== activeId) return s;
          const errorMsg: ChatMessage = {
            id: makeId(),
            agent: "System",
            agentColor: "",
            content,
            timestamp: "",
            type: "system",
            thread: agentType,
            createdAt: Date.now(),
          };
          return { ...s, messages: [...s.messages, errorMsg] };
        }),
      };
    });
  }, []);

  const handleRun = useCallback(async () => {
    const state = agentStates[activeWorker];
    if (!state?.sessionId || !state?.ready) return;

    const sessions = sessionsRef.current[activeWorker] || [];
    const activeId = activeSessionRef.current[activeWorker] || sessions[0]?.id;
    const activeSession = sessions.find((s) => s.id === activeId) || sessions[0];
    const requiredRunKeys = getStructuredRunInputKeys(
      state.nodeSpecs,
      activeSession?.graphNodes || [],
    );

    if (
      requiredRunKeys.length > 0 &&
      !hasAllStructuredRunInputs(requiredRunKeys, state.lastRunInputData)
    ) {
      updateAgentState(activeWorker, {
        awaitingInput: true,
        pendingQuestion: null,
        pendingOptions: null,
        pendingQuestions: buildStructuredRunQuestions(requiredRunKeys),
        pendingQuestionSource: "run",
        workerRunState: "idle",
      });
      return;
    }

    const inputData =
      requiredRunKeys.length > 0
        ? trimStructuredRunInputs(requiredRunKeys, state.lastRunInputData)
        : {};

    // Reset dismissed banner so a repeated 424 re-shows it
    setDismissedBanner(null);
    try {
      updateAgentState(activeWorker, { workerRunState: "deploying" });
      const result = await executionApi.trigger(state.sessionId, "default", {});
      updateAgentState(activeWorker, { currentExecutionId: result.execution_id });
      const result = await executionApi.trigger(state.sessionId, "default", inputData);
      updateAgentState(activeWorker, {
        currentExecutionId: result.execution_id,
        lastRunInputData: inputData,
      });
    } catch (err) {
      // 424 = credentials required — open the credentials modal
      if (err instanceof ApiError && err.status === 424) {
@@ -714,25 +780,23 @@ export default function Workspace() {
      }

      const errMsg = err instanceof Error ? err.message : String(err);
      setSessionsByAgent((prev) => {
        const sessions = prev[activeWorker] || [];
        const activeId = activeSessionRef.current[activeWorker] || sessions[0]?.id;
        return {
          ...prev,
          [activeWorker]: sessions.map((s) => {
            if (s.id !== activeId) return s;
            const errorMsg: ChatMessage = {
              id: makeId(), agent: "System", agentColor: "",
              content: `Failed to trigger run: ${errMsg}`,
              timestamp: "", type: "system", thread: activeWorker, createdAt: Date.now(),
            };
            return { ...s, messages: [...s.messages, errorMsg] };
          }),
        };
      });
      appendSystemMessage(activeWorker, `Failed to trigger run: ${errMsg}`);
      updateAgentState(activeWorker, { workerRunState: "idle" });
    }
  }, [agentStates, activeWorker, updateAgentState]);
  }, [agentStates, activeWorker, appendSystemMessage, updateAgentState]);

  const canRunLoadedWorker = canShowRunButton(
    activeAgentState?.sessionId,
    activeAgentState?.ready,
    activeAgentState?.queenPhase,
    Boolean(
      activeAgentState?.nodeSpecs?.length ||
        sessionsByAgent[activeWorker]?.some(
          (session) =>
            session.id === activeAgentState?.sessionId && session.graphNodes.length > 0,
        ),
    ),
  );

  // --- Fetch discovered agents for NewTabPopover ---
  const [discoverAgents, setDiscoverAgents] = useState<DiscoverEntry[]>([]);
@@ -2826,6 +2890,55 @@ export default function Workspace() {

  // --- handleMultiQuestionAnswer: submit answers to ask_user_multiple ---
  const handleMultiQuestionAnswer = useCallback((answers: Record<string, string>) => {
    const state = agentStates[activeWorker];
    if (state?.pendingQuestionSource === "run") {
      if (!state.sessionId || !state.ready) return;
      updateAgentState(activeWorker, {
        pendingQuestion: null,
        pendingOptions: null,
        pendingQuestions: null,
        pendingQuestionSource: null,
        awaitingInput: false,
        workerRunState: "deploying",
      });
      const requiredRunKeys = getStructuredRunInputKeys(
        state.nodeSpecs,
        sessionsRef.current[activeWorker]?.find((s) => s.id === state.sessionId)?.graphNodes || [],
      );
      const trimmedAnswers = trimStructuredRunInputs(requiredRunKeys, answers);
      executionApi.trigger(state.sessionId, "default", trimmedAnswers).then((result) => {
        updateAgentState(activeWorker, {
          currentExecutionId: result.execution_id,
          lastRunInputData: trimmedAnswers,
        });
      }).catch((err: unknown) => {
        if (err instanceof ApiError && err.status === 424) {
          const errBody = err.body as Record<string, unknown>;
          const credPath = (errBody?.agent_path as string) || null;
          if (credPath) setCredentialAgentPath(credPath);
          updateAgentState(activeWorker, {
            workerRunState: "idle",
            error: "credentials_required",
            lastRunInputData: trimmedAnswers,
          });
          setCredentialsOpen(true);
          return;
        }

        const errMsg = err instanceof Error ? err.message : String(err);
        appendSystemMessage(activeWorker, `Failed to trigger run: ${errMsg}`);
        updateAgentState(activeWorker, {
          workerRunState: "idle",
          awaitingInput: true,
          pendingQuestion: null,
          pendingOptions: null,
          pendingQuestions: buildStructuredRunQuestions(requiredRunKeys),
          pendingQuestionSource: "run",
        });
      });
      return;
    }

    updateAgentState(activeWorker, {
      pendingQuestion: null, pendingOptions: null,
      pendingQuestions: null, pendingQuestionSource: null,
@@ -2835,7 +2948,7 @@ export default function Workspace() {
      ([id, answer]) => `[${id}]: ${answer}`,
    );
    handleSend(lines.join("\n"), activeWorker);
  }, [activeWorker, handleSend, updateAgentState]);
  }, [activeWorker, agentStates, appendSystemMessage, handleSend, updateAgentState]);

  // --- handleQuestionDismiss: user closed the question widget without answering ---
  // Injects a dismiss signal so the blocked node can continue.
@@ -2854,6 +2967,11 @@ export default function Workspace() {
      awaitingInput: false,
    });

    if (source === "run") {
      updateAgentState(activeWorker, { workerRunState: "idle" });
      return;
    }

    // Unblock the waiting node with a dismiss signal
    const dismissMsg = `[User dismissed the question: "${question}"]`;
    if (source === "worker") {
@@ -3145,7 +3263,7 @@ export default function Workspace() {
            : null
        }
        building={activeAgentState?.queenBuilding}
        onRun={handleRun}
        onRun={canRunLoadedWorker ? handleRun : undefined}
        onPause={handlePause}
        runState={activeAgentState?.workerRunState ?? "idle"}
        flowchartMap={activeAgentState?.flowchartMap ?? undefined}

@@ -33,7 +33,7 @@ async def test_codex_stream():
    print(f"extra_kwargs keys: {list(extra_kwargs.keys())}")
    print(f"extra_headers: {list(extra_kwargs.get('extra_headers', {}).keys())}")

    model = "openai/gpt-5.3-codex"
    model = "openai/gpt-5.4"

    # Create the provider
    provider = LiteLLMProvider(

@@ -33,7 +33,7 @@ async def main():
        return

    provider = LiteLLMProvider(
        model="openai/gpt-5.3-codex",
        model="openai/gpt-5.4",
        api_key=api_key,
        api_base=api_base,
        **extra_kwargs,

@@ -0,0 +1,306 @@
from __future__ import annotations

import asyncio
from typing import Any
from unittest.mock import MagicMock

import pytest

from framework.graph.event_loop_node import EventLoopNode, LoopConfig
from framework.graph.node import NodeContext, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.queen_orchestrator import _client_input_counts_as_planning_ask
from framework.tools.queen_lifecycle_tools import QueenPhaseState


class MockStreamingLLM(LLMProvider):
    """Minimal streaming LLM for Codex-vs-control parity checks."""

    def __init__(self, scenarios: list[list[Any]] | None = None):
        self.scenarios = scenarios or []
        self._call_index = 0

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools=None,
        max_tokens: int = 4096,
    ):
        if not self.scenarios:
            return
        events = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for event in events:
            yield event

    def complete(self, messages, system="", **kwargs):
        raise NotImplementedError


def text_scenario(text: str) -> list[Any]:
    return [
        TextDeltaEvent(content=text, snapshot=text),
        FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
    ]


def tool_call_scenario(
    tool_name: str,
    tool_input: dict[str, Any],
    *,
    tool_use_id: str = "call_1",
    preamble_text: str = "",
) -> list[Any]:
    events: list[Any] = []
    if preamble_text:
        events.append(TextDeltaEvent(content=preamble_text, snapshot=preamble_text))
    events.append(
        ToolCallEvent(
            tool_use_id=tool_use_id,
            tool_name=tool_name,
            tool_input=tool_input,
        )
    )
    events.append(
        FinishEvent(
            stop_reason="tool_calls",
            input_tokens=10,
            output_tokens=5,
            model="mock",
        )
    )
    return events


def build_ctx(spec: NodeSpec, llm: LLMProvider, *, stream_id: str) -> NodeContext:
    runtime = MagicMock()
    runtime.start_run = MagicMock(return_value=f"session_{stream_id}")
    runtime.decide = MagicMock(return_value="dec_1")
    runtime.record_outcome = MagicMock()
    runtime.end_run = MagicMock()
    runtime.report_problem = MagicMock()
    runtime.set_node = MagicMock()
    return NodeContext(
        runtime=runtime,
        node_id=spec.id,
        node_spec=spec,
        memory=SharedMemory(),
        input_data={},
        llm=llm,
        available_tools=[],
        stream_id=stream_id,
    )


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("style", "first_turn"),
    [
        (
            "control",
            tool_call_scenario(
                "ask_user",
                {
                    "question": "What kind of agent should I design for you?",
                    "options": ["Summarizer"],
                },
                tool_use_id="ask_1",
            ),
        ),
        (
            "codex",
            text_scenario("What kind of agent should I design for you?"),
        ),
    ],
)
async def test_codex_and_control_styles_both_count_toward_planning_gate(
    style: str,
    first_turn: list[Any],
) -> None:
    bus = EventBus()
    phase_state = QueenPhaseState(phase="planning", event_bus=bus)
    received: list[AgentEvent] = []

    async def capture(event: AgentEvent) -> None:
        received.append(event)
        if _client_input_counts_as_planning_ask(event):
            phase_state.planning_ask_rounds += 1

    bus.subscribe([EventType.CLIENT_INPUT_REQUESTED], capture, filter_stream="queen")

    spec = NodeSpec(
        id="queen",
        name="Queen",
        description="planning orchestrator",
        node_type="event_loop",
        client_facing=True,
        output_keys=[],
        skip_judge=True,
    )
    llm = MockStreamingLLM(scenarios=[first_turn])
    node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
    ctx = build_ctx(spec, llm, stream_id="queen")

    async def shutdown_after_first_block() -> None:
        await asyncio.sleep(0.05)
        node.signal_shutdown()

    task = asyncio.create_task(shutdown_after_first_block())
    result = await node.execute(ctx)
    await task

    assert result.success is True
    assert phase_state.planning_ask_rounds == 1
    assert received
    if style == "control":
        assert received[0].data["prompt"] == "What kind of agent should I design for you?"
        assert received[0].data.get("auto_blocked") is not True
    else:
        assert received[0].data["prompt"] == ""
        assert received[0].data["auto_blocked"] is True


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("style", "scenarios"),
    [
        (
            "control",
            [
                tool_call_scenario(
                    "ask_user",
                    {
                        "question": "Paste old and new policy text.",
                        "options": ["I'll paste both now"],
                    },
                    tool_use_id="ask_1",
                ),
                tool_call_scenario(
                    "set_output",
                    {
                        "key": "important_changes",
                        "value": "- Remote days increased from 2 to 4",
                    },
                    tool_use_id="set_1",
                ),
            ],
        ),
        (
            "codex",
            [
                text_scenario("Paste old and new policy text."),
                tool_call_scenario(
                    "set_output",
                    {
                        "key": "important_changes",
                        "value": "- Remote days increased from 2 to 4",
                    },
                    tool_use_id="set_1",
                ),
            ],
        ),
    ],
)
async def test_codex_and_control_styles_complete_same_human_in_loop_run(
    style: str,
    scenarios: list[list[Any]],
) -> None:
    spec = NodeSpec(
        id=f"policy_diff_{style}",
        name="Policy Diff Worker",
        description="Compare two policy versions",
        node_type="event_loop",
        output_keys=["important_changes"],
        client_facing=True,
    )
    llm = MockStreamingLLM(scenarios=scenarios)
    node = EventLoopNode(config=LoopConfig(max_iterations=6))
    ctx = build_ctx(spec, llm, stream_id=f"worker_{style}")

    async def user_responds() -> None:
        await asyncio.sleep(0.05)
        await node.inject_event("Old policy ... New policy ...")

    task = asyncio.create_task(user_responds())
    result = await node.execute(ctx)
    await task

    assert result.success is True
    assert result.output["important_changes"] == "- Remote days increased from 2 to 4"


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("style", "scenario"),
    [
        (
            "control",
            tool_call_scenario(
                "ask_user",
                {"question": "What would you like to do next?", "options": ["Rerun", "Stop"]},
                tool_use_id="ask_1",
                preamble_text="Root cause: checkout is failing because the DB pool is exhausted.",
            ),
        ),
        (
            "codex",
            tool_call_scenario(
                "ask_user",
                {
                    "question": (
                        "Root cause: checkout is failing because the DB pool is exhausted.\n\n"
                        "What would you like to do next?"
                    ),
                    "options": ["Rerun", "Stop"],
                },
                tool_use_id="ask_1",
            ),
        ),
    ],
)
async def test_codex_and_control_styles_surface_result_before_followup_widget(
    style: str,
    scenario: list[Any],
) -> None:
    spec = NodeSpec(
        id=f"queen_{style}",
        name="Queen",
        description="orchestrator",
        node_type="event_loop",
        client_facing=True,
        output_keys=[],
        skip_judge=True,
    )
    llm = MockStreamingLLM(scenarios=[scenario])
    bus = EventBus()
    received: list[AgentEvent] = []

    async def capture(event: AgentEvent) -> None:
        received.append(event)

    bus.subscribe(
        event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED],
        handler=capture,
    )

    node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
    ctx = build_ctx(spec, llm, stream_id="queen")

    async def shutdown() -> None:
        await asyncio.sleep(0.05)
        node.signal_shutdown()

    task = asyncio.create_task(shutdown())
    await node.execute(ctx)
    await task

    output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA]
    input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED]

    assert output_events
    assert input_events
    assert "DB pool is exhausted" in output_events[0].data["snapshot"]
    assert input_events[0].data["prompt"] == "What would you like to do next?"
@@ -0,0 +1,448 @@
from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass, field
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest
from aiohttp.test_utils import TestClient, TestServer

import framework.tools.queen_lifecycle_tools as qlt
from framework.graph.event_loop_node import EventLoopNode, LoopConfig
from framework.graph.node import NodeContext, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider, Tool
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.app import create_app, validate_agent_path
from framework.server.session_manager import (
    Session,
    _run_validation_report_sync,
    _validation_blocks_stage_or_run,
)
from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools

REPO_ROOT = Path(__file__).resolve().parents[2]


class MockStreamingLLM(LLMProvider):
    """Minimal streaming LLM for parity-gate regressions."""

    def __init__(self, scenarios: list[list[Any]] | None = None):
        self.scenarios = scenarios or []
        self._call_index = 0

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools=None,
        max_tokens: int = 4096,
    ):
        if not self.scenarios:
            return
        events = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for event in events:
            yield event

    def complete(self, messages, system="", **kwargs):
        raise NotImplementedError


def text_scenario(text: str) -> list[Any]:
    return [
        TextDeltaEvent(content=text, snapshot=text),
        FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
    ]


def tool_call_scenario(
    tool_name: str,
    tool_input: dict[str, Any],
    *,
    tool_use_id: str = "call_1",
) -> list[Any]:
    return [
        ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input),
        FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
    ]


def build_ctx(
    spec: NodeSpec,
    llm: LLMProvider,
    *,
    stream_id: str = "worker",
    input_data: dict[str, Any] | None = None,
) -> NodeContext:
    runtime = MagicMock()
    runtime.start_run = MagicMock(return_value="session_codex_parity")
    runtime.decide = MagicMock(return_value="dec_1")
    runtime.record_outcome = MagicMock()
    runtime.end_run = MagicMock()
    runtime.report_problem = MagicMock()
    runtime.set_node = MagicMock()
    return NodeContext(
        runtime=runtime,
        node_id=spec.id,
        node_spec=spec,
        memory=SharedMemory(),
        input_data=input_data or {},
        llm=llm,
        available_tools=[],
        stream_id=stream_id,
    )


@pytest.mark.parametrize(
    "agent_ref",
    [
        "examples/templates/tech_news_reporter",
        "examples/templates/vulnerability_assessment",
    ],
)
def test_codex_parity_existing_templates_validate_for_stage_run(agent_ref: str) -> None:
    """Existing checked-in agents should pass the shared stage/run gate."""
    resolved = validate_agent_path(agent_ref)
    report = _run_validation_report_sync(agent_ref)

    assert resolved.is_dir()
    assert report.get("valid") is True
    assert _validation_blocks_stage_or_run(report) is False


@pytest.mark.asyncio
async def test_codex_parity_local_only_human_in_loop_run_completes() -> None:
    """A local-only client-facing worker flow should complete end to end."""
    spec = NodeSpec(
        id="policy_diff_worker",
        name="Policy Diff Worker",
        description="Compare two policy versions",
        node_type="event_loop",
        output_keys=["important_changes"],
        client_facing=True,
    )
    llm = MockStreamingLLM(
        scenarios=[
            tool_call_scenario(
                "ask_user",
                {"question": "Paste old and new policy text.", "options": ["I'll paste both now"]},
                tool_use_id="ask_1",
            ),
            tool_call_scenario(
                "set_output",
                {
                    "key": "important_changes",
                    "value": (
                        "- Remote days increased from 2 to 4\n"
                        "- Security training increased from annual to twice yearly"
                    ),
                },
                tool_use_id="set_1",
            ),
        ]
    )

    node = EventLoopNode(config=LoopConfig(max_iterations=6))
    ctx = build_ctx(spec, llm, stream_id="worker")

    async def user_responds() -> None:
        await asyncio.sleep(0.05)
        await node.inject_event("Old policy ... New policy ...")

    task = asyncio.create_task(user_responds())
    result = await node.execute(ctx)
    await task

    assert result.success is True
    assert "Remote days increased" in result.output["important_changes"]


@pytest.mark.asyncio
async def test_codex_parity_result_is_visible_before_followup_widget() -> None:
    """Long result-bearing queen prompts should stream the result before the widget."""
    spec = NodeSpec(
        id="queen",
        name="Queen",
        description="orchestrator",
        node_type="event_loop",
        client_facing=True,
        output_keys=[],
        skip_judge=True,
    )
    llm = MockStreamingLLM(
        scenarios=[
            tool_call_scenario(
                "ask_user",
                {
                    "question": (
                        "Root cause: checkout is failing because the DB pool is exhausted.\n\n"
                        "What would you like to do next?"
                    ),
                    "options": ["Rerun", "Stop"],
                },
                tool_use_id="ask_1",
            )
        ]
    )
    bus = EventBus()
    received: list[AgentEvent] = []

    async def capture(event: AgentEvent) -> None:
        received.append(event)

    bus.subscribe(
        event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED],
        handler=capture,
    )

    node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
    ctx = build_ctx(spec, llm, stream_id="queen")

    async def shutdown() -> None:
        await asyncio.sleep(0.05)
        node.signal_shutdown()

    task = asyncio.create_task(shutdown())
    await node.execute(ctx)
    await task

    output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA]
    input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED]

    assert output_events
    assert input_events
    assert "DB pool is exhausted" in output_events[0].data["snapshot"]
    assert input_events[0].data["prompt"] == "What would you like to do next?"


@pytest.mark.asyncio
async def test_codex_parity_rerun_reuses_complete_recent_defaults(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Rerun should keep structured inputs stable instead of relying on text reconstruction."""
    registry = ToolRegistry()
    registry.register(
        "validate_agent_package",
        Tool(
            name="validate_agent_package",
            description="fake validator",
            parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
        ),
        lambda _inputs: json.dumps({"valid": True, "steps": {}}),
    )

    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    sessions_dir = tmp_path / "agent_store" / "sessions"
    sessions_dir.mkdir(parents=True)

    valid_prior_state = {
        "timestamps": {"updated_at": "2026-03-24T20:44:00"},
        "input_data": {
            "target_dir": "docs",
            "review_dir": "docs_reviews",
            "word_threshold": 800,
        },
    }
    malformed_recent_state = {
        "timestamps": {"updated_at": "2026-03-24T21:20:23"},
        "input_data": {
            "review_dir": "docs_reviews",
            "word_threshold": "800. Validate inputs and continue.",
        },
    }

    for session_name, state in {
        "session_20260324_204400_good": valid_prior_state,
        "session_20260324_212023_bad": malformed_recent_state,
    }.items():
        session_dir = sessions_dir / session_name
        session_dir.mkdir()
        (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")

    runtime = SimpleNamespace(
        _session_store=SimpleNamespace(sessions_dir=sessions_dir),
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-rerun"),
        graph=SimpleNamespace(
            nodes=[],
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
                if node_id == "process"
                else None
            ),
        ),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/local_markdown_review_probe_2"),
        runner=None,
    )
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-rerun",
        phase_state=QueenPhaseState(phase="staging"),
    )

    result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
    result = json.loads(result_raw)

    assert result["status"] == "started"
    runtime.trigger.assert_awaited_once()
    assert runtime.trigger.await_args.kwargs["input_data"] == {
        "target_dir": str((tmp_path / "docs").resolve()),
        "review_dir": str((tmp_path / "docs_reviews").resolve()),
        "word_threshold": 800,
    }
    assert runtime.trigger.await_args.kwargs["session_state"] is None


@dataclass
class _MockEntryPoint:
    id: str = "default"
    name: str = "Default"
    entry_node: str = "start"
    trigger_type: str = "manual"
    trigger_config: dict = field(default_factory=dict)


@dataclass
class _MockStream:
    is_awaiting_input: bool = False
    _execution_tasks: dict = field(default_factory=dict)
    _active_executors: dict = field(default_factory=dict)
    active_execution_ids: set = field(default_factory=set)

    async def cancel_execution(self, execution_id: str) -> bool:
        return execution_id in self._execution_tasks


@dataclass
class _MockGraphRegistration:
    graph: Any = field(default_factory=lambda: SimpleNamespace(nodes=[], edges=[], entry_node=""))
    streams: dict = field(default_factory=dict)
    entry_points: dict = field(default_factory=dict)


class _MockRuntime:
    def __init__(self):
        self._entry_points = [_MockEntryPoint()]
        self._mock_streams = {"default": _MockStream()}
        self._registration = _MockGraphRegistration(
            streams=self._mock_streams,
            entry_points={"default": self._entry_points[0]},
        )

    def list_graphs(self):
        return ["primary"]

    def get_graph_registration(self, graph_id):
        if graph_id == "primary":
            return self._registration
        return None

    def get_entry_points(self):
        return self._entry_points

    async def trigger(self, ep_id, input_data=None, session_state=None):
        return "exec_test_123"

    async def inject_input(self, node_id, content, graph_id=None, *, is_client_input=False):
        return True

    def pause_timers(self):
        pass

    async def get_goal_progress(self):
        return {"progress": 0.5, "criteria": []}

    def find_awaiting_node(self):
        return None, None

    def get_stats(self):
        return {"running": True, "executions": 1}

    def get_timer_next_fire_in(self, ep_id):
        return None


def _make_queen_executor():
    mock_node = MagicMock()
    mock_node.inject_event = AsyncMock()
    executor = MagicMock()
    executor.node_registry = {"queen": mock_node}
    return executor


def _make_session(agent_id="test_agent") -> Session:
    runner = MagicMock()
    runner.intro_message = "Test intro"
    return Session(
        id=agent_id,
        event_bus=EventBus(),
        llm=MagicMock(),
        loaded_at=1000000.0,
        queen_executor=_make_queen_executor(),
        worker_id=agent_id,
        worker_path=Path("/tmp/test_agent"),
        runner=runner,
        worker_runtime=_MockRuntime(),
        worker_info=SimpleNamespace(
            name="test_agent",
            description="A test agent",
            goal_name="test_goal",
            node_count=2,
        ),
        worker_validation_report={"valid": True, "steps": {}},
        worker_validation_failures=[],
    )


def _make_app_with_session(session: Session):
    app = create_app()
    mgr = app["manager"]
    mgr._sessions[session.id] = session
    return app


@pytest.mark.asyncio
async def test_codex_parity_done_for_now_parks_queen_without_new_followup() -> None:
    """Terminal stop choices should acknowledge once and park the queen."""
    session = _make_session()
    session.event_bus.get_history = MagicMock(
        return_value=[
            AgentEvent(
                type=EventType.CLIENT_INPUT_REQUESTED,
                stream_id="queen",
                node_id="queen",
                execution_id=session.id,
                data={"options": ["Run again with same input", "Done for now"]},
            )
        ]
    )
    session.event_bus.emit_client_output_delta = AsyncMock()
    app = _make_app_with_session(session)

    async with TestClient(TestServer(app)) as client:
        resp = await client.post(
            "/api/sessions/test_agent/chat",
            json={"message": "No, stop here"},
        )
        assert resp.status == 200
        data = await resp.json()
        assert data["status"] == "queen"
        assert data["delivered"] is True

    assert session.queen_executor is None
    session.event_bus.emit_client_output_delta.assert_awaited_once()
@@ -0,0 +1,194 @@
from __future__ import annotations

import asyncio
import json
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

import pytest

from framework.graph.event_loop_node import EventLoopNode, LoopConfig
from framework.graph.node import NodeContext, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider
from framework.llm.stream_events import FinishEvent, TextDeltaEvent
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.queen_orchestrator import _client_input_counts_as_planning_ask
from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools


class MockStreamingLLM(LLMProvider):
    """Minimal streaming LLM for planning-phase regression tests."""

    def __init__(self, scenarios: list[list[Any]] | None = None):
        self.scenarios = scenarios or []
        self._call_index = 0

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools=None,
        max_tokens: int = 4096,
    ):
        if not self.scenarios:
            return
        events = self.scenarios[self._call_index % len(self.scenarios)]
        self._call_index += 1
        for event in events:
            yield event

    def complete(self, messages, system="", **kwargs):
        raise NotImplementedError


def text_scenario(text: str) -> list[Any]:
    return [
        TextDeltaEvent(content=text, snapshot=text),
        FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
    ]


def build_ctx(spec: NodeSpec, llm: LLMProvider) -> NodeContext:
    runtime = MagicMock()
    runtime.start_run = MagicMock(return_value="session_codex_planning")
    runtime.decide = MagicMock(return_value="dec_1")
    runtime.record_outcome = MagicMock()
    runtime.end_run = MagicMock()
    runtime.report_problem = MagicMock()
    runtime.set_node = MagicMock()
    return NodeContext(
        runtime=runtime,
        node_id=spec.id,
        node_spec=spec,
        memory=SharedMemory(),
        input_data={"greeting": "Session started."},
        llm=llm,
        available_tools=[],
        stream_id="queen",
    )


@pytest.mark.asyncio
async def test_codex_style_text_only_planning_turn_counts_toward_ask_rounds() -> None:
    """Plain-text planning questions should satisfy the ask_rounds gate.

    This reproduces the Codex failure mode: the queen asks a planning question
    in plain text instead of calling ask_user(), which triggers an auto-blocked
    CLIENT_INPUT_REQUESTED event with an empty prompt.
    """
    bus = EventBus()
    phase_state = QueenPhaseState(phase="planning", event_bus=bus)
    received: list[AgentEvent] = []

    async def capture(event: AgentEvent) -> None:
        received.append(event)
        if _client_input_counts_as_planning_ask(event):
            phase_state.planning_ask_rounds += 1

    bus.subscribe([EventType.CLIENT_INPUT_REQUESTED], capture, filter_stream="queen")

    spec = NodeSpec(
        id="queen",
        name="Queen",
        description="planning orchestrator",
        node_type="event_loop",
        client_facing=True,
        output_keys=[],
        skip_judge=True,
    )
    llm = MockStreamingLLM(scenarios=[text_scenario("What kind of agent should I design for you?")])
    node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
    ctx = build_ctx(spec, llm)

    async def shutdown_after_first_block() -> None:
        await asyncio.sleep(0.05)
        node.signal_shutdown()

    task = asyncio.create_task(shutdown_after_first_block())
    result = await node.execute(ctx)
    await task

    assert result.success is True
    assert len(received) >= 1
    assert received[0].data["prompt"] == ""
    assert received[0].data["auto_blocked"] is True
    assert received[0].data["assistant_text_present"] is True
    assert received[0].data["assistant_text_requires_input"] is True
    assert phase_state.planning_ask_rounds == 1


@pytest.mark.asyncio
async def test_save_agent_draft_accepts_two_codex_style_planning_rounds() -> None:
    """Two counted auto-blocked planning turns should unlock save_agent_draft()."""
    phase_state = QueenPhaseState(phase="planning")

    codex_style_event = AgentEvent(
        type=EventType.CLIENT_INPUT_REQUESTED,
        stream_id="queen",
        data={
            "prompt": "",
            "auto_blocked": True,
            "assistant_text_present": True,
            "assistant_text_requires_input": True,
        },
    )
    for _ in range(2):
        if _client_input_counts_as_planning_ask(codex_style_event):
            phase_state.planning_ask_rounds += 1

    registry = ToolRegistry()
    session = SimpleNamespace(
        worker_runtime=None,
        event_bus=None,
        worker_path=None,
        runner=None,
    )
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="session_codex_planning",
        phase_state=phase_state,
    )

    save_draft = registry._tools["save_agent_draft"].executor
    result_raw = await save_draft(
        {
            "agent_name": "codex_planning_repro",
            "goal": "Reproduce the planning gate.",
            "nodes": [
                {"id": "start"},
                {"id": "discover"},
                {"id": "plan"},
                {"id": "review"},
                {"id": "finish"},
            ],
            "edges": [
                {"source": "start", "target": "discover"},
                {"source": "discover", "target": "plan"},
                {"source": "plan", "target": "review"},
                {"source": "review", "target": "finish"},
            ],
        }
    )
    result = json.loads(result_raw)

    assert phase_state.planning_ask_rounds == 2
    assert result["status"] == "draft_saved"


def test_status_only_auto_block_does_not_count_toward_planning_ask_rounds() -> None:
    """Auto-blocked acknowledgements should not satisfy the planning ask gate."""
    event = AgentEvent(
        type=EventType.CLIENT_INPUT_REQUESTED,
        stream_id="queen",
        data={
            "prompt": "",
            "auto_blocked": True,
            "assistant_text_present": True,
            "assistant_text_requires_input": False,
        },
    )

    assert _client_input_counts_as_planning_ask(event) is False
@@ -1,8 +1,15 @@
"""Tests for framework/config.py - Hive configuration loading."""

import logging
from unittest.mock import patch

from framework.config import get_api_base, get_hive_config, get_preferred_model
from framework.config import (
    get_api_base,
    get_hive_config,
    get_llm_extra_kwargs,
    get_preferred_model,
)
from framework.llm.codex_backend import CODEX_API_BASE, is_codex_api_base, normalize_codex_api_base


class TestGetHiveConfig:
@@ -59,9 +66,65 @@ class TestOpenRouterConfig:
    def test_get_api_base_keeps_explicit_openrouter_api_base(self, tmp_path, monkeypatch):
        config_file = tmp_path / "configuration.json"
        config_file.write_text(
            '{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta","api_base":"https://proxy.example/v1"}}',
            (
                '{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta",'
                '"api_base":"https://proxy.example/v1"}}'
            ),
            encoding="utf-8",
        )
        monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)

        assert get_api_base() == "https://proxy.example/v1"


class TestCodexConfig:
    """Codex config helpers should share the same transport defaults."""

    def test_get_api_base_uses_shared_codex_backend(self, tmp_path, monkeypatch):
        config_file = tmp_path / "configuration.json"
        config_file.write_text(
            '{"llm":{"provider":"openai","model":"gpt-5.4","use_codex_subscription":true}}',
            encoding="utf-8",
        )
        monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)

        assert get_api_base() == CODEX_API_BASE

    def test_get_llm_extra_kwargs_uses_shared_codex_transport(self, tmp_path, monkeypatch):
        config_file = tmp_path / "configuration.json"
        config_file.write_text(
            '{"llm":{"provider":"openai","model":"gpt-5.4","use_codex_subscription":true}}',
            encoding="utf-8",
        )
        monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)

        with (
            patch("framework.runner.runner.get_codex_token", return_value="tok_test"),
            patch("framework.runner.runner.get_codex_account_id", return_value="acct_123"),
        ):
            kwargs = get_llm_extra_kwargs()

        assert kwargs["store"] is False
        assert kwargs["allowed_openai_params"] == ["store"]
        assert kwargs["extra_headers"] == {
            "Authorization": "Bearer tok_test",
            "User-Agent": "CodexBar",
            "ChatGPT-Account-Id": "acct_123",
        }

    def test_codex_api_base_detection_requires_real_chatgpt_origin(self):
        assert is_codex_api_base("https://chatgpt.com/backend-api/codex")
        assert is_codex_api_base("https://chatgpt.com/backend-api/codex/responses")
        assert not is_codex_api_base(
            "https://proxy.example/v1?target=https://chatgpt.com/backend-api/codex"
        )

    def test_normalize_codex_api_base_strips_only_real_responses_suffix(self):
        assert (
            normalize_codex_api_base("https://chatgpt.com/backend-api/codex/responses")
            == CODEX_API_BASE
        )
        assert (
            normalize_codex_api_base("https://proxy.example/v1/responses")
            == "https://proxy.example/v1/responses"
        )

@@ -224,6 +224,12 @@ class TestUpdateSystemPrompt:
        conv.update_system_prompt("updated")
        assert conv.system_prompt == "updated"

    def test_update_replaces_output_keys(self):
        conv = NodeConversation(system_prompt="original", output_keys=["brief"])
        conv.update_system_prompt("updated", output_keys=["articles_data"])
        assert conv.system_prompt == "updated"
        assert conv._output_keys == ["articles_data"]


# ===========================================================================
# Conversation threading through executor
@@ -372,6 +378,61 @@ class TestContinuousConversation:
        )
        assert "PHASE TRANSITION" in all_content

    @pytest.mark.asyncio
    async def test_transition_marker_uses_next_node_tools_not_stale_previous_tools(self):
        runtime = _make_runtime()
        web_scrape = _make_tool("web_scrape")
        file_tool = _make_tool("read_file")

        llm = MockStreamingLLM(
            scenarios=[
                _text_then_set_output("Intake done.", "brief", "enterprise ai"),
                _text_finish(""),
                _text_then_set_output("Research done.", "articles_data", '{"articles": []}'),
                _text_finish(""),
            ]
        )

        node_a = NodeSpec(
            id="a",
            name="Intake",
            description="Collect user preference",
            node_type="event_loop",
            client_facing=True,
            output_keys=["brief"],
        )
        node_b = NodeSpec(
            id="b",
            name="Research",
            description="Scrape recent articles",
            node_type="event_loop",
            input_keys=["brief"],
            output_keys=["articles_data"],
            tools=["web_scrape"],
        )

        graph = GraphSpec(
            id="g1",
            goal_id="g1",
            entry_node="a",
            nodes=[node_a, node_b],
            edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
            terminal_nodes=["b"],
            conversation_mode="continuous",
        )

        executor = GraphExecutor(runtime=runtime, llm=llm, tools=[file_tool, web_scrape])
        result = await executor.execute(graph=graph, goal=_make_goal())
        assert result.success

        node_b_messages = llm.stream_calls[2]["messages"]
        all_content = " ".join(
            m.get("content", "") for m in node_b_messages if isinstance(m.get("content"), str)
        )
        assert "Available tools:" in all_content
        assert "web_scrape" in all_content
        assert "set_output" in all_content


# ===========================================================================
# Cumulative tools

@@ -7,6 +7,7 @@ that yields pre-programmed StreamEvents to control the loop deterministically.
from __future__ import annotations

import asyncio
import contextvars
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock
@@ -348,6 +349,143 @@ class TestSetOutput:
        assert result.output["result"] == "ok"
        assert "bad_key" not in result.output

    def test_set_output_rejects_identical_duplicate_value(self):
        """Identical repeated set_output calls should be treated as an error, not progress."""
        node = EventLoopNode()

        result = node._handle_set_output(
            {"key": "result", "value": "42"},
            ["result"],
            missing_keys=["result", "summary"],
            current_value=42,
            normalized_value=42,
        )

        assert result.is_error is True
        assert "already set to the same value" in result.content
        assert "summary" in result.content

    @pytest.mark.asyncio
    async def test_set_output_auto_completes_non_client_facing_node(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """A worker node should finish immediately once required outputs are set."""
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "result", "value": "done"}),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "done"
        assert len(llm.stream_calls) == 1

    @pytest.mark.asyncio
    async def test_set_output_auto_completes_client_facing_node(
        self,
        runtime,
        memory,
    ):
        """Client-facing nodes should also finish once required outputs are set."""
        spec = NodeSpec(
            id="review",
            name="Review",
            description="client-facing review node",
            node_type="event_loop",
            output_keys=["decision"],
            client_facing=True,
        )
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario("set_output", {"key": "decision", "value": "approve"}),
            ]
        )

        ctx = build_ctx(runtime, spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=5))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["decision"] == "approve"
        assert len(llm.stream_calls) == 1

    @pytest.mark.asyncio
    async def test_client_facing_completes_immediately_after_user_reply_sets_all_outputs(
        self,
        runtime,
        memory,
    ):
        """Client-facing nodes should finish once a post-user-reply turn sets all outputs."""
        spec = NodeSpec(
            id="findings-review",
            name="Findings Review",
            description="review findings",
            node_type="event_loop",
            output_keys=["continue_scanning", "feedback", "all_findings"],
            client_facing=True,
        )
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario(
                    "ask_user",
                    {
                        "question": "Continue scanning or generate report?",
                        "options": ["Continue", "Report"],
                    },
                    tool_use_id="ask_1",
                ),
                [
                    ToolCallEvent(
                        tool_use_id="set_continue",
                        tool_name="set_output",
                        tool_input={"key": "continue_scanning", "value": "false"},
                    ),
                    ToolCallEvent(
                        tool_use_id="set_feedback",
                        tool_name="set_output",
                        tool_input={"key": "feedback", "value": "generate final report"},
                    ),
                    ToolCallEvent(
                        tool_use_id="set_all",
                        tool_name="set_output",
                        tool_input={"key": "all_findings", "value": '{"ok": true}'},
                    ),
                    FinishEvent(
                        stop_reason="tool_calls",
                        input_tokens=10,
                        output_tokens=5,
                        model="mock",
                    ),
                ],
            ]
        )

        node = EventLoopNode(config=LoopConfig(max_iterations=10))
        ctx = build_ctx(runtime, spec, memory, llm)

        async def user_responds():
            await asyncio.sleep(0.05)
            await node.inject_event("Generate the report")

        task = asyncio.create_task(user_responds())
        result = await node.execute(ctx)
        await task

        assert result.success is True
        assert result.output == {
            "continue_scanning": False,
            "feedback": "generate final report",
            "all_findings": {"ok": True},
        }
        assert len(llm.stream_calls) == 2

    @pytest.mark.asyncio
    async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
        """Judge accepts but output keys are missing -> retry with hint."""
@@ -399,6 +537,47 @@ class TestStallDetection:
        assert result.success is False
        assert "stalled" in result.error.lower()

    @pytest.mark.asyncio
    async def test_no_progress_churn_detection(self, runtime, node_spec, memory):
        """Different text with missing outputs should still fail if nothing progresses."""
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("Reviewing the logs and thinking through the issue."),
                text_scenario("I am narrowing down the likely cause."),
                text_scenario("I have more context and am still analyzing."),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=10, stall_detection_threshold=3))
        result = await node.execute(ctx)

        assert result.success is False
        assert "no-progress loop detected" in (result.error or "").lower()

    @pytest.mark.asyncio
    async def test_no_progress_counter_resets_after_output_progress(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """A real output-setting turn should reset the no-progress churn counter."""
        llm = MockStreamingLLM(
            scenarios=[
                text_scenario("Parsing the problem statement."),
                text_scenario("Extracting the key facts now."),
                tool_call_scenario("set_output", {"key": "result", "value": "triaged"}),
            ]
        )

        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(config=LoopConfig(max_iterations=10, stall_detection_threshold=3))
        result = await node.execute(ctx)

        assert result.success is True
        assert result.output["result"] == "triaged"


# ===========================================================================
# EventBus lifecycle events
@@ -647,6 +826,57 @@ class TestClientFacingBlocking:
        assert len(received) >= 1
        assert received[0].type == EventType.CLIENT_INPUT_REQUESTED

    @pytest.mark.asyncio
    async def test_queen_long_ask_user_prompt_surfaces_result_before_widget(
        self, runtime, memory, client_spec
    ):
        """Long queen prompts should stream visible result text before options."""
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario(
                    "ask_user",
                    {
                        "question": (
                            "Root cause: checkout requests are failing because the DB pool is "
                            "exhausted and cart reads are timing out.\n\n"
                            "What would you like to do next?"
                        ),
                        "options": ["Rerun", "Stop"],
                    },
                    tool_use_id="ask_1",
                ),
            ]
        )
        bus = EventBus()
        received = []

        async def capture(e):
            received.append(e)

        bus.subscribe(
            event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED],
            handler=capture,
        )

        node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
        ctx = build_ctx(runtime, client_spec, memory, llm, stream_id="queen")

        async def shutdown():
            await asyncio.sleep(0.05)
            node.signal_shutdown()

        task = asyncio.create_task(shutdown())
        await node.execute(ctx)
        await task

        output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA]
        input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED]

        assert output_events
        assert input_events
        assert "Root cause: checkout requests are failing" in output_events[0].data["snapshot"]
        assert input_events[0].data["prompt"] == "What would you like to do next?"

    @pytest.mark.asyncio
    @pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
    async def test_ask_user_with_real_tools(self, runtime, memory):
@@ -905,9 +1135,79 @@ class TestEscalate:

        assert result.success is True
        assert result.output["result"] == "resolved after queen guidance"
        assert judge.evaluate.await_count >= 1
        assert judge.evaluate.await_count == 0
        assert len(client_input_events) == 0

    @pytest.mark.asyncio
    async def test_escalate_then_complete_outputs_autocompletes_without_extra_turn(
        self, runtime, memory
    ):
        """Worker nodes should still auto-complete after queen guidance once outputs are set."""
        spec = NodeSpec(
            id="csv-intake",
            name="CSV Intake",
            description="parse csv",
            node_type="event_loop",
            output_keys=["original_headers", "parsed_rows", "raw_csv_text"],
        )
        llm = MockStreamingLLM(
            scenarios=[
                tool_call_scenario(
                    "escalate",
                    {"reason": "need csv", "context": "input malformed"},
                    tool_use_id="esc_1",
                ),
                [
                    ToolCallEvent(
                        tool_use_id="set_headers",
                        tool_name="set_output",
                        tool_input={"key": "original_headers", "value": '["name","email"]'},
                    ),
                    ToolCallEvent(
                        tool_use_id="set_rows",
                        tool_name="set_output",
                        tool_input={
                            "key": "parsed_rows",
                            "value": '[{"name":"Alice","email":"alice@example.com"}]',
                        },
                    ),
                    ToolCallEvent(
                        tool_use_id="set_raw",
                        tool_name="set_output",
                        tool_input={
                            "key": "raw_csv_text",
                            "value": "name,email\\nAlice,alice@example.com",
                        },
                    ),
                    FinishEvent(
                        stop_reason="tool_calls",
                        input_tokens=10,
                        output_tokens=5,
                        model="mock",
                    ),
                ],
            ]
        )

        ctx = build_ctx(runtime, spec, memory, llm, stream_id="worker")
        node = EventLoopNode(config=LoopConfig(max_iterations=5))

        async def queen_reply():
            await asyncio.sleep(0.05)
            await node.inject_event("Use the sample CSV and continue.")

        task = asyncio.create_task(queen_reply())
        result = await node.execute(ctx)
        await task

        assert result.success is True
        assert result.output == {
            "original_headers": ["name", "email"],
            "parsed_rows": [{"name": "Alice", "email": "alice@example.com"}],
            "raw_csv_text": "name,email\\nAlice,alice@example.com",
        }
        assert len(llm.stream_calls) == 2


# ===========================================================================
# Client-facing: _cf_expecting_work state machine
@@ -1382,6 +1682,39 @@ class TestPauseResume:
        assert llm._call_index == 0


class TestToolExecutionContext:
    @pytest.mark.asyncio
    async def test_execute_tool_preserves_contextvars_in_threadpool(self):
        marker = contextvars.ContextVar("marker", default="missing")

        def tool_exec(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content=marker.get(),
                is_error=False,
            )

        node = EventLoopNode(
            tool_executor=tool_exec,
            config=LoopConfig(tool_call_timeout_seconds=5),
        )

        token = marker.set("present")
        try:
            result = await node._execute_tool(
                ToolCallEvent(
                    tool_use_id="tool-1",
                    tool_name="echo_marker",
                    tool_input={},
                )
            )
        finally:
            marker.reset(token)

        assert result.is_error is False
        assert result.content == "present"


# ===========================================================================
# Stream errors
# ===========================================================================
@@ -1783,6 +2116,17 @@ class TestFingerprintToolCalls:
        )


class TestFingerprintSetOutputCalls:
    """Unit tests for _fingerprint_set_output_calls()."""

    def test_basic_fingerprint(self):
        results = [
            {"tool_name": "set_output", "tool_input": {"key": "result", "value": {"a": 1}}},
        ]
        fps = EventLoopNode._fingerprint_set_output_calls(results)
        assert fps == [("result", '{"a": 1}')]


class TestIsToolDoomLoop:
    """Unit tests for _is_tool_doom_loop()."""
@@ -1821,6 +2165,25 @@ class TestIsToolDoomLoop:
        assert is_doom is False


class TestIsOutputDoomLoop:
    """Unit tests for _is_output_doom_loop()."""

    def test_at_threshold_identical(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp = [("result", '"done"')]
        is_doom, desc = node._is_output_doom_loop([fp, fp, fp])
        assert is_doom is True
        assert "set_output" in desc

    def test_different_values_no_doom(self):
        node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
        fp1 = [("result", '"a"')]
        fp2 = [("result", '"b"')]
        fp3 = [("result", '"c"')]
        is_doom, _ = node._is_output_doom_loop([fp1, fp2, fp3])
        assert is_doom is False


class ToolRepeatLLM(LLMProvider):
    """LLM that produces identical tool calls across outer iterations.

@@ -1879,6 +2242,95 @@ class ToolRepeatLLM(LLMProvider):
        )


class SetOutputRepeatLLM(LLMProvider):
    """LLM that repeats the same set_output-only turn across iterations."""

    def __init__(self, key: str, value: str, tool_turns: int, final_text: str = "done"):
        self.key = key
        self.value = value
        self.tool_turns = tool_turns
        self.final_text = final_text
        self._call_index = 0

    async def stream(self, messages, system="", tools=None, max_tokens=4096):
        idx = self._call_index
        self._call_index += 1
        outer_iter = idx // 2
        is_tool_call = (idx % 2 == 0) and outer_iter < self.tool_turns
        if is_tool_call:
            yield ToolCallEvent(
                tool_use_id=f"set_{outer_iter}",
                tool_name="set_output",
                tool_input={"key": self.key, "value": self.value},
            )
            yield FinishEvent(
                stop_reason="tool_calls",
                input_tokens=10,
                output_tokens=5,
                model="mock",
            )
        else:
            text = f"{self.final_text} (call {idx})"
            yield TextDeltaEvent(content=text, snapshot=text)
            yield FinishEvent(
                stop_reason="stop",
                input_tokens=10,
                output_tokens=5,
                model="mock",
            )

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(
            content="ok",
            model="mock",
            stop_reason="stop",
        )


class VaryingSetOutputRepeatLLM(LLMProvider):
    """LLM that repeats set_output turns with different values across iterations."""

    def __init__(self, key: str, values: list[str], final_text: str = "done"):
        self.key = key
        self.values = values
        self.final_text = final_text
        self._call_index = 0

    async def stream(self, messages, system="", tools=None, max_tokens=4096):
        idx = self._call_index
        self._call_index += 1
        outer_iter = idx // 2
        is_tool_call = (idx % 2 == 0) and outer_iter < len(self.values)
        if is_tool_call:
            yield ToolCallEvent(
                tool_use_id=f"set_{outer_iter}",
                tool_name="set_output",
                tool_input={"key": self.key, "value": self.values[outer_iter]},
            )
            yield FinishEvent(
                stop_reason="tool_calls",
                input_tokens=10,
                output_tokens=5,
                model="mock",
            )
        else:
            text = f"{self.final_text} (call {idx})"
            yield TextDeltaEvent(content=text, snapshot=text)
            yield FinishEvent(
                stop_reason="stop",
                input_tokens=10,
                output_tokens=5,
                model="mock",
            )

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(
            content="ok",
            model="mock",
            stop_reason="stop",
        )


class TestToolDoomLoopIntegration:
    """Integration tests for doom loop detection in execute().

@@ -2263,7 +2715,97 @@
        assert result.success is True
        # Doom loop MUST fire for repeatedly-failing tool calls
        assert len(doom_events) >= 1
        assert "failing_tool" in doom_events[0].data["description"]

    @pytest.mark.asyncio
    async def test_repeated_identical_set_output_turns_fail_fast(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """Repeated identical set_output-only turns should fail instead of spinning forever."""
        node_spec.output_keys = ["result", "review_manifest"]
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))

        llm = SetOutputRepeatLLM("result", "same summary", tool_turns=4)
        bus = EventBus()
        doom_events: list = []
        bus.subscribe(
            event_types=[EventType.NODE_TOOL_DOOM_LOOP],
            handler=lambda e: doom_events.append(e),
        )

        ctx = build_ctx(runtime, node_spec, memory, llm, tools=[])
        node = EventLoopNode(
            judge=judge,
            event_bus=bus,
            config=LoopConfig(
                max_iterations=10,
                tool_doom_loop_threshold=3,
                stall_similarity_threshold=1.0,
            ),
        )
        result = await node.execute(ctx)

        assert result.success is False
        assert "Output doom loop detected" in (result.error or "")
        assert doom_events
        assert "set_output" in doom_events[0].data["description"]
        assert "result" in doom_events[0].data["description"]

    @pytest.mark.asyncio
    async def test_meta_reset_set_output_turns_fail_fast(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """Fresh-payload reset chatter should trip the output doom loop guard."""
        node_spec.output_keys = ["rules", "candidates", "scan_stats"]
        judge = AsyncMock(spec=JudgeProtocol)
        judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))

        llm = VaryingSetOutputRepeatLLM(
            "rules",
            [
                (
                    "New event acknowledged. Awaiting fresh request payload "
                    "(phase transition details + structured inputs) to proceed."
                ),
                (
                    "Context reset complete. Awaiting fresh phase transition payload "
                    "and structured inputs to proceed."
                ),
                (
                    "Ready for fresh request payload with phase transition "
                    "instructions and structured inputs."
                ),
            ],
        )
        bus = EventBus()
        doom_events: list = []
        bus.subscribe(
            event_types=[EventType.NODE_TOOL_DOOM_LOOP],
            handler=lambda e: doom_events.append(e),
        )

        ctx = build_ctx(runtime, node_spec, memory, llm, tools=[])
        node = EventLoopNode(
            judge=judge,
            event_bus=bus,
            config=LoopConfig(
                max_iterations=10,
                tool_doom_loop_threshold=3,
                stall_similarity_threshold=1.0,
            ),
        )
        result = await node.execute(ctx)

        assert result.success is False
        assert "fresh payload" in (result.error or "").lower()
        assert doom_events
        assert "fresh payload" in doom_events[0].data["description"].lower()


# ===========================================================================

@@ -241,6 +241,48 @@ class TestToolConversion:
        with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
            provider._parse_tool_call_arguments('{"question": foo', "ask_user")

    def test_parse_tool_call_arguments_recovers_pythonish_payloads(self):
        """Single-quoted and trailing-comma argument payloads should be recovered."""
        provider = LiteLLMProvider(model="openai/gpt-5.4", api_key="test-key")

        parsed = provider._parse_tool_call_arguments(
            "{'question': 'Continue?', 'options': ['Yes', 'No'],}",
            "ask_user",
        )

        assert parsed == {
            "question": "Continue?",
            "options": ["Yes", "No"],
        }

    def test_parse_tool_call_arguments_keeps_null_inside_strings(self):
        """Literal normalization should not mutate quoted text values."""
        provider = LiteLLMProvider(model="openai/gpt-5.4", api_key="test-key")

        parsed = provider._parse_tool_call_arguments(
            "{'hypothesis': 'null hypothesis', 'approved': false}",
            "summarize",
        )

        assert parsed == {
            "hypothesis": "null hypothesis",
            "approved": False,
        }

    def test_parse_tool_call_arguments_strips_json_code_fences(self):
        """Fence stripping should remove the language tag before JSON parsing."""
        provider = LiteLLMProvider(model="openai/gpt-5.4", api_key="test-key")

        parsed = provider._parse_tool_call_arguments(
            '```json\n{"question":"Continue?","options":["Yes","No"]}\n```',
            "ask_user",
        )

        assert parsed == {
            "question": "Continue?",
            "options": ["Yes", "No"],
        }


class TestAnthropicProviderBackwardCompatibility:
    """Test AnthropicProvider backward compatibility with LiteLLM backend."""
@@ -731,6 +773,221 @@ class TestMiniMaxStreamFallback:
|
||||
assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()
|
||||
|
||||
|
||||
class TestCodexEmptyStreamRecovery:
|
||||
"""Codex empty streams should fall back before surfacing ghost-stream retries."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_stream_recovers_empty_codex_stream_via_nonstream_completion(
|
||||
self,
|
||||
mock_acompletion,
|
||||
):
|
||||
"""An empty Codex stream should be salvaged with a non-stream completion."""
|
||||
from framework.llm.stream_events import FinishEvent, TextDeltaEvent
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openai/gpt-5.4",
|
||||
api_key="test-key",
|
||||
api_base="https://chatgpt.com/backend-api/codex",
|
||||
)
|
||||
|
||||
class EmptyStreamResponse:
|
||||
chunks: list = []
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
raise StopAsyncIteration
|
||||
|
||||
recovered = MagicMock()
|
||||
recovered.choices = [MagicMock()]
|
||||
recovered.choices[0].message.content = "Recovered via fallback"
|
||||
recovered.choices[0].message.tool_calls = []
|
||||
recovered.choices[0].finish_reason = "stop"
|
||||
recovered.model = provider.model
|
||||
recovered.usage.prompt_tokens = 12
|
||||
recovered.usage.completion_tokens = 4
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
return EmptyStreamResponse()
|
||||
return recovered
|
||||
|
||||
mock_acompletion.side_effect = side_effect
|
||||
|
||||
events = []
|
||||
async for event in provider.stream(messages=[{"role": "user", "content": "hi"}]):
|
||||
events.append(event)
|
||||
|
||||
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
|
||||
assert len(text_events) == 1
|
||||
assert text_events[0].snapshot == "Recovered via fallback"
|
||||
|
||||
finish_events = [event for event in events if isinstance(event, FinishEvent)]
|
||||
assert len(finish_events) == 1
|
||||
assert finish_events[0].stop_reason == "stop"
|
||||
assert finish_events[0].input_tokens == 12
|
||||
assert finish_events[0].output_tokens == 4
|
||||
|
||||
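        # The second acompletion call is the non-stream salvage request; it is
        # expected to drop the stream flag entirely rather than send stream=False.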
        assert mock_acompletion.call_count == 2
        assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
        assert "stream" not in mock_acompletion.call_args_list[1].kwargs

    @pytest.mark.asyncio
    @patch("litellm.acompletion")
    async def test_stream_recovers_empty_codex_stream_with_tool_calls(
        self,
        mock_acompletion,
    ):
        """Non-stream fallback should preserve tool calls, not just text."""
        from framework.llm.stream_events import FinishEvent, ToolCallEvent

        provider = LiteLLMProvider(
            model="openai/gpt-5.4",
            api_key="test-key",
            api_base="https://chatgpt.com/backend-api/codex/responses",
        )

        class EmptyStreamResponse:
            chunks: list = []

            def __aiter__(self):
                return self

            async def __anext__(self):
                raise StopAsyncIteration

        tc = MagicMock()
        tc.id = "tool_1"
        tc.function.name = "ask_user"
        tc.function.arguments = '{"question":"Continue?","options":["Yes","No"]}'

        recovered = MagicMock()
        recovered.choices = [MagicMock()]
        recovered.choices[0].message.content = ""
        recovered.choices[0].message.tool_calls = [tc]
        recovered.choices[0].finish_reason = "tool_calls"
        recovered.model = provider.model
        recovered.usage.prompt_tokens = 14
        recovered.usage.completion_tokens = 5

        async def side_effect(*args, **kwargs):
            if kwargs.get("stream"):
                return EmptyStreamResponse()
            return recovered

        mock_acompletion.side_effect = side_effect

        events = []
        async for event in provider.stream(
            messages=[{"role": "user", "content": "Should we continue?"}],
            tools=[
                Tool(
                    name="ask_user",
                    description="Ask the user",
                    parameters={"properties": {"question": {"type": "string"}}},
                )
            ],
        ):
            events.append(event)

        tool_events = [event for event in events if isinstance(event, ToolCallEvent)]
        assert len(tool_events) == 1
        assert tool_events[0].tool_name == "ask_user"
        assert tool_events[0].tool_input == {
            "question": "Continue?",
            "options": ["Yes", "No"],
        }

        finish_events = [event for event in events if isinstance(event, FinishEvent)]
        assert len(finish_events) == 1
        assert finish_events[0].stop_reason == "tool_calls"


class TestCodexRequestHardening:
    def test_codex_build_completion_kwargs_splits_prompt_and_forces_tool_choice(self):
        """Codex requests should chunk large system prompts and require tools when needed."""
        provider = LiteLLMProvider(
            model="openai/gpt-5.4",
            api_key="test-key",
            api_base="https://chatgpt.com/backend-api/codex/responses",
        )
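        # The builder is expected to prepend its own "# Codex Execution Contract"
        # preamble and split the oversized system prompt across several system
        # messages instead of sending one giant block.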
        kwargs = provider._build_completion_kwargs(
            messages=[{"role": "user", "content": "hi"}],
            system="# Identity\n" + ("rule\n" * 2000),
            tools=[
                Tool(
                    name="ask_user",
                    description="Ask the user",
                    parameters={"properties": {"question": {"type": "string"}}},
                )
            ],
            max_tokens=256,
            response_format=None,
            json_mode=False,
            stream=True,
        )

        system_messages = [m for m in kwargs["messages"] if m["role"] == "system"]
        assert len(system_messages) >= 2
        assert system_messages[0]["content"].startswith("# Codex Execution Contract")
        assert kwargs["tool_choice"] == "required"
        assert kwargs["store"] is False
        assert "max_tokens" not in kwargs
        assert "stream_options" not in kwargs
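        # The /responses suffix passed at construction time is expected to be
        # trimmed back to the bare Codex backend base for the request itself.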
assert kwargs["api_base"] == "https://chatgpt.com/backend-api/codex"
|
||||
assert "store" in kwargs["allowed_openai_params"]
|
||||
|
||||
def test_codex_merge_tool_call_chunk_handles_parallel_calls_with_broken_indexes(self):
|
||||
"""Codex chunk merging should survive index=0 for multiple parallel tool calls."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
provider = LiteLLMProvider(
|
||||
model="openai/gpt-5.4",
|
||||
api_key="test-key",
|
||||
api_base="https://chatgpt.com/backend-api/codex",
|
||||
)
|
||||
acc: dict[int, dict[str, str]] = {}
|
||||
last_idx = 0
|
||||
|
||||
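        # Codex can emit index=0 for every parallel call; the merger is expected
        # to open a new slot per chunk id and route id-less argument fragments
        # to the most recently opened slot.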
        chunks = [
            SimpleNamespace(
                id="tool_1",
                index=0,
                function=SimpleNamespace(name="web_search", arguments='{"query":"alpha'),
            ),
            SimpleNamespace(
                id="tool_2",
                index=0,
                function=SimpleNamespace(name="read_file", arguments='{"path":"beta'),
            ),
            SimpleNamespace(
                id=None,
                index=0,
                function=SimpleNamespace(name=None, arguments='"}'),
            ),
            SimpleNamespace(
                id=None,
                index=0,
                function=SimpleNamespace(name=None, arguments='"}'),
            ),
        ]

        for chunk in chunks:
            last_idx = provider._merge_tool_call_chunk(acc, chunk, last_idx)

        assert len(acc) == 2
        parsed = [
            provider._parse_tool_call_arguments(slot["arguments"], slot["name"])
            for _, slot in sorted(acc.items())
        ]
        assert parsed == [
            {"query": "alpha"},
            {"path": "beta"},
        ]


class TestOpenRouterToolCompatFallback:
    """OpenRouter models should fall back when native tool use is unavailable."""


@@ -0,0 +1,985 @@
from __future__ import annotations

import json
import os
import time
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

import framework.tools.queen_lifecycle_tools as qlt
from framework.llm.provider import Tool
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import EventBus
from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools


def _write_worker_logs(
    storage_path: Path,
    session_id: str,
    *,
    session_status: str,
    steps: list[dict[str, object]],
) -> Path:
    session_dir = storage_path / "sessions" / session_id
    logs_dir = session_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    (session_dir / "state.json").write_text(
        json.dumps({"status": session_status}),
        encoding="utf-8",
    )
    log_path = logs_dir / "tool_logs.jsonl"
    log_path.write_text(
        "".join(json.dumps(step) + "\n" for step in steps),
        encoding="utf-8",
    )
    return log_path


def _register_fake_validator(registry: ToolRegistry, report: dict) -> None:
    registry.register(
        "validate_agent_package",
        Tool(
            name="validate_agent_package",
            description="fake validator",
            parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
        ),
        lambda _inputs: json.dumps(report),
    )


def test_parse_validation_report_handles_saved_footer() -> None:
    raw = (
        '{\n "valid": false,\n "steps": {"tool_validation": {"passed": false}}\n}\n\n'
        "[Saved to 'validate.txt']"
    )

    parsed = qlt._parse_validation_report(raw)

    assert parsed == {"valid": False, "steps": {"tool_validation": {"passed": False}}}


def test_validation_blocks_stage_or_run_ignores_non_blocking_warnings() -> None:
    report = {
        "steps": {
            "behavior_validation": {
                "passed": True,
                "warnings": ["placeholder prompt"],
                "output": "placeholder prompt",
            },
            "tests": {
                "passed": True,
                "warnings": ["1 failed"],
                "summary": "1 failed",
            },
        },
    }

    assert qlt._validation_blocks_stage_or_run(report) is False


def test_invalid_validation_report_blocks_stage_or_run() -> None:
    report = qlt._invalid_validation_report("validator returned garbage")

    assert report["valid"] is False
    assert qlt._validation_blocks_stage_or_run(report) is True


@pytest.mark.asyncio
async def test_get_worker_status_summary_flags_retry_and_judge_pressure() -> None:
    registry = ToolRegistry()
    bus = EventBus()

    await bus.emit_node_retry(
        stream_id="worker",
        node_id="scan",
        retry_count=1,
        max_retries=3,
        error="still missing required result",
    )
    for _ in range(4):
        await bus.emit_judge_verdict(
            stream_id="worker",
            node_id="scan",
            action="RETRY",
            feedback="missing structured output",
        )

    runtime = SimpleNamespace(
        graph_id="worker-graph",
        get_graph_registration=lambda _gid: SimpleNamespace(
            streams={
                "default": SimpleNamespace(
                    active_execution_ids=["exec-1"],
                    get_context=lambda _exec_id: SimpleNamespace(started_at=datetime.now()),
                    get_waiting_nodes=lambda: [],
                )
            }
        ),
    )
    session = SimpleNamespace(worker_runtime=runtime, event_bus=bus, worker_path=None, runner=None)

    register_queen_lifecycle_tools(registry, session=session, session_id="sess-status")

    summary = await registry._tools["get_worker_status"].executor({})

    assert "issue type(s) detected" in summary


@pytest.mark.asyncio
async def test_get_worker_status_issues_reports_judge_pressure() -> None:
    registry = ToolRegistry()
    bus = EventBus()

    for action in ("CONTINUE", "RETRY", "RETRY", "ESCALATE"):
        await bus.emit_judge_verdict(
            stream_id="worker",
            node_id="review",
            action=action,
            feedback="still not converging",
        )

    runtime = SimpleNamespace(
        graph_id="worker-graph",
        get_graph_registration=lambda _gid: SimpleNamespace(streams={}),
    )
    session = SimpleNamespace(worker_runtime=runtime, event_bus=bus, worker_path=None, runner=None)

    register_queen_lifecycle_tools(registry, session=session, session_id="sess-issues")

    issues = await registry._tools["get_worker_status"].executor({"focus": "issues"})

    assert "Judge pressure detected" in issues
    assert "consecutive non-ACCEPT judge verdict" in issues


@pytest.mark.asyncio
async def test_get_worker_status_summary_uses_health_snapshot_signals(tmp_path: Path) -> None:
    storage_path = tmp_path / "agent_store"
    storage_path.mkdir(parents=True, exist_ok=True)
    log_path = _write_worker_logs(
        storage_path,
        "sess-health",
        session_status="running",
        steps=[
            {"verdict": "CONTINUE", "llm_text": "thinking"},
            {"verdict": "RETRY", "llm_text": "retrying"},
            {"verdict": "RETRY", "llm_text": "still retrying"},
            {"verdict": "ESCALATE", "llm_text": "need help"},
        ],
    )
    three_minutes_ago = time.time() - 180
    os.utime(log_path, (three_minutes_ago, three_minutes_ago))

    registry = ToolRegistry()
    bus = EventBus()
    runtime = SimpleNamespace(
        graph_id="worker-graph",
        get_graph_registration=lambda _gid: SimpleNamespace(
            streams={
                "default": SimpleNamespace(
                    active_execution_ids=["exec-1"],
                    get_context=lambda _exec_id: SimpleNamespace(started_at=datetime.now()),
                    get_waiting_nodes=lambda: [],
                )
            }
        ),
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=bus,
        worker_path=storage_path,
        runner=None,
    )

    register_queen_lifecycle_tools(registry, session=session, session_id="sess-health")

    summary = await registry._tools["get_worker_status"].executor({})

    assert "issue signal(s) detected" in summary
    assert "judge_pressure" in summary
    assert "recent_non_accept_churn" in summary


@pytest.mark.asyncio
async def test_get_worker_status_issues_includes_health_snapshot_signals(tmp_path: Path) -> None:
    storage_path = tmp_path / "agent_store"
    storage_path.mkdir(parents=True, exist_ok=True)
    log_path = _write_worker_logs(
        storage_path,
        "sess-health",
        session_status="running",
        steps=[
            {"verdict": "CONTINUE", "llm_text": "thinking"},
            {"verdict": "RETRY", "llm_text": "retrying"},
            {"verdict": "RETRY", "llm_text": "still retrying"},
            {"verdict": "ESCALATE", "llm_text": "need help"},
        ],
    )
    three_minutes_ago = time.time() - 180
    os.utime(log_path, (three_minutes_ago, three_minutes_ago))

    registry = ToolRegistry()
    bus = EventBus()
    runtime = SimpleNamespace(
        graph_id="worker-graph",
        get_graph_registration=lambda _gid: SimpleNamespace(streams={}),
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=bus,
        worker_path=storage_path,
        runner=None,
    )

    register_queen_lifecycle_tools(registry, session=session, session_id="sess-health")

    issues = await registry._tools["get_worker_status"].executor({"focus": "issues"})

    assert "Health signals:" in issues
    assert "slow_progress" in issues
    assert "recent_non_accept_churn" in issues


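# The _build_worker_input_data tests below assume the helper parses
# "key: value" bullets and "key=value" tokens out of free-form task text,
# coerces numeric strings, resolves path-like values against the current
# directory, and keeps only keys the entry node declares.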
def test_build_worker_input_data_maps_bullet_task_fields_to_entry_inputs(
    monkeypatch, tmp_path: Path
) -> None:
    monkeypatch.chdir(tmp_path)
    (tmp_path / "docs").mkdir()

    runtime = SimpleNamespace(
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
        graph=SimpleNamespace(
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(
                    input_keys=[
                        "docs_dir",
                        "review_dir",
                        "word_threshold",
                        "style_rules",
                        "target_ratio",
                    ]
                )
                if node_id == "process"
                else None
            ),
        ),
    )

    payload = qlt._build_worker_input_data(
        runtime,
        (
            "Run md_condense_reviewer with the following runtime config:\n"
            "- docs_dir: docs/\n"
            "- review_dir: docs_reviews/\n"
            "- word_threshold: 800\n"
            "- target_ratio: 0.6 (default)\n"
            "- style_rules: Preserve headings and links.\n\n"
            "Execution requirements:\n"
            "1) Scan the docs directory.\n"
            "2) Write review copies."
        ),
    )

    assert payload == {
        "docs_dir": str((tmp_path / "docs").resolve()),
        "review_dir": str((tmp_path / "docs_reviews").resolve()),
        "word_threshold": 800,
        "style_rules": "Preserve headings and links.",
        "target_ratio": 0.6,
    }


def test_build_worker_input_data_maps_equals_style_runtime_fields(
    monkeypatch, tmp_path: Path
) -> None:
    monkeypatch.chdir(tmp_path)

    runtime = SimpleNamespace(
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
        graph=SimpleNamespace(
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["target_dir", "review_dir_mode", "word_threshold"])
                if node_id == "process"
                else None
            ),
        ),
    )

    payload = qlt._build_worker_input_data(
        runtime,
        ("Yes, rerun with target_dir=docs review_dir_mode=next_to_source word_threshold=800"),
    )

    assert payload == {
        "target_dir": str((tmp_path / "docs").resolve()),
        "review_dir_mode": "next_to_source",
        "word_threshold": 800,
    }


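# The backfill tests assume recent session state.json files are scanned
# newest-first, that malformed values (e.g. a number fused with trailing
# prose) are rejected, and that missing keys are filled from the newest
# state that survives this sanity check.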
def test_build_worker_input_data_backfills_missing_fields_from_recent_session(
    monkeypatch, tmp_path: Path
) -> None:
    monkeypatch.chdir(tmp_path)
    sessions_dir = tmp_path / "agent_store" / "sessions"
    sessions_dir.mkdir(parents=True)

    valid_prior_state = {
        "timestamps": {"updated_at": "2026-03-24T20:44:00"},
        "input_data": {
            "target_dir": "docs",
            "review_dir": "docs_reviews",
            "word_threshold": 800,
        },
    }
    malformed_recent_state = {
        "timestamps": {"updated_at": "2026-03-24T21:20:23"},
        "input_data": {
            "review_dir": "docs_reviews",
            "word_threshold": "800. Validate inputs and continue.",
        },
    }

    for session_name, state in {
        "session_20260324_204400_good": valid_prior_state,
        "session_20260324_212023_bad": malformed_recent_state,
    }.items():
        session_dir = sessions_dir / session_name
        session_dir.mkdir()
        (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")

    runtime = SimpleNamespace(
        _session_store=SimpleNamespace(sessions_dir=sessions_dir),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
        graph=SimpleNamespace(
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
                if node_id == "process"
                else None
            ),
        ),
    )

    payload = qlt._build_worker_input_data(
        runtime,
        ("review_dir: docs_reviews\nword_threshold: 800. Validate inputs and continue."),
    )

    assert payload == {
        "target_dir": str((tmp_path / "docs").resolve()),
        "review_dir": str((tmp_path / "docs_reviews").resolve()),
        "word_threshold": 800,
    }


def test_build_worker_input_data_reuses_recent_defaults_for_rerun_phrase(
    monkeypatch, tmp_path: Path
) -> None:
    monkeypatch.chdir(tmp_path)
    sessions_dir = tmp_path / "agent_store" / "sessions"
    sessions_dir.mkdir(parents=True)

    state = {
        "timestamps": {"updated_at": "2026-03-24T21:17:00"},
        "input_data": {
            "target_dir": "docs",
            "review_dir": "docs_reviews",
            "word_threshold": 800,
        },
    }
    session_dir = sessions_dir / "session_20260324_211700_prev"
    session_dir.mkdir()
    (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")

    runtime = SimpleNamespace(
        _session_store=SimpleNamespace(sessions_dir=sessions_dir),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
        graph=SimpleNamespace(
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
                if node_id == "process"
                else None
            ),
        ),
    )

    payload = qlt._build_worker_input_data(runtime, "Run again with same defaults")

    assert payload == {
        "target_dir": str((tmp_path / "docs").resolve()),
        "review_dir": str((tmp_path / "docs_reviews").resolve()),
        "word_threshold": 800,
    }


def test_build_worker_input_data_backfills_from_recent_result_output(
    monkeypatch, tmp_path: Path
) -> None:
    monkeypatch.chdir(tmp_path)
    sessions_dir = tmp_path / "agent_store" / "sessions"
    sessions_dir.mkdir(parents=True)

    state = {
        "timestamps": {"updated_at": "2026-03-24T23:35:19"},
        "input_data": {
            "review_dir": "docs_reviews",
            "word_threshold": 800,
        },
        "result": {
            "output": {
                "target_dir": "docs",
                "review_dir": "docs_reviews",
                "word_threshold": 800,
            }
        },
    }
    session_dir = sessions_dir / "session_20260324_233519_prev"
    session_dir.mkdir()
    (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")

    runtime = SimpleNamespace(
        _session_store=SimpleNamespace(sessions_dir=sessions_dir),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
        graph=SimpleNamespace(
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
                if node_id == "process"
                else None
            ),
        ),
    )

    payload = qlt._build_worker_input_data(
        runtime,
        ("review_dir: docs_reviews\nword_threshold: 600"),
    )

    assert payload == {
        "target_dir": str((tmp_path / "docs").resolve()),
        "review_dir": str((tmp_path / "docs_reviews").resolve()),
        "word_threshold": 600,
    }


@pytest.mark.asyncio
async def test_load_built_agent_blocks_invalid_package(monkeypatch, tmp_path: Path) -> None:
    registry = ToolRegistry()
    captured: dict[str, str] = {}
    registry.register(
        "validate_agent_package",
        Tool(
            name="validate_agent_package",
            description="fake validator",
            parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
        ),
        lambda inputs: (
            captured.setdefault("agent_name", inputs["agent_name"]),
            json.dumps(
                {
                    "valid": False,
                    "steps": {
                        "behavior_validation": {
                            "passed": False,
                            "output": (
                                "Node 'scan-markdown' has a blank or placeholder system_prompt"
                            ),
                        }
                    },
                }
            ),
        )[1],
    )

    session = SimpleNamespace(worker_runtime=None, event_bus=None, worker_path=None, runner=None)
    fake_manager = SimpleNamespace(
        get_session=lambda _sid: None,
        unload_worker=AsyncMock(),
        load_worker=AsyncMock(),
    )
    phase_state = QueenPhaseState(phase="building")
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_manager=fake_manager,
        manager_session_id="sess-1",
        phase_state=phase_state,
    )

    agent_dir = tmp_path / "broken_agent"
    agent_dir.mkdir()
    monkeypatch.setattr(qlt, "validate_agent_path", lambda _path: agent_dir)

    result_raw = await registry._tools["load_built_agent"].executor({"agent_path": str(agent_dir)})
    result = json.loads(result_raw)

    assert "Cannot load agent" in result["error"]
    assert "behavior_validation" in result["validation_failures"][0]
    assert captured["agent_name"] == str(agent_dir)
    fake_manager.load_worker.assert_not_called()
    fake_manager.unload_worker.assert_not_called()


@pytest.mark.asyncio
async def test_load_built_agent_keeps_current_worker_when_replacement_fails_validation(
    monkeypatch, tmp_path: Path
) -> None:
    registry = ToolRegistry()
    _register_fake_validator(
        registry,
        {
            "valid": False,
            "steps": {
                "behavior_validation": {
                    "passed": False,
                    "output": "Node 'scan' has a blank or placeholder system_prompt",
                }
            },
        },
    )
    session = SimpleNamespace(
        worker_runtime=SimpleNamespace(),
        event_bus=None,
        worker_path=Path("exports/existing_agent"),
        runner=None,
    )
    fake_manager = SimpleNamespace(
        get_session=lambda _sid: None,
        unload_worker=AsyncMock(),
        load_worker=AsyncMock(),
    )
    phase_state = QueenPhaseState(phase="building")
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_manager=fake_manager,
        manager_session_id="sess-1",
        phase_state=phase_state,
    )

    agent_dir = tmp_path / "broken_agent"
    agent_dir.mkdir()
    monkeypatch.setattr(qlt, "validate_agent_path", lambda _path: agent_dir)

    result_raw = await registry._tools["load_built_agent"].executor({"agent_path": str(agent_dir)})
    result = json.loads(result_raw)

    assert "Cannot load agent" in result["error"]
    fake_manager.unload_worker.assert_not_called()
    fake_manager.load_worker.assert_not_called()


@pytest.mark.asyncio
async def test_run_agent_with_input_blocks_loaded_invalid_worker() -> None:
    registry = ToolRegistry()
    _register_fake_validator(
        registry,
        {
            "valid": False,
            "steps": {
                "tool_validation": {
                    "passed": False,
                    "output": "Scan Markdown Files missing run_command",
                }
            },
        },
    )

    runtime = SimpleNamespace(
        resume_timers=MagicMock(),
        trigger=AsyncMock(),
        _get_primary_session_state=MagicMock(return_value={}),
        graph=SimpleNamespace(nodes=[]),
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/broken_agent"),
        runner=None,
    )
    phase_state = QueenPhaseState(phase="staging")
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-2",
        phase_state=phase_state,
    )

    result_raw = await registry._tools["run_agent_with_input"].executor({"task": "run it"})
    result = json.loads(result_raw)

    assert "Cannot run agent" in result["error"]
    assert "tool_validation" in result["validation_failures"][0]
    runtime.trigger.assert_not_called()


@pytest.mark.asyncio
async def test_run_agent_with_input_uses_structured_entry_inputs(
    monkeypatch, tmp_path: Path
) -> None:
    registry = ToolRegistry()
    _register_fake_validator(registry, {"valid": True, "steps": {}})

    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    runtime = SimpleNamespace(
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-1"),
        _get_primary_session_state=MagicMock(return_value={}),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
        graph=SimpleNamespace(
            nodes=[],
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(
                    input_keys=["docs_path", "review_path", "word_threshold", "style_rules"]
                )
                if node_id == "process"
                else None
            ),
        ),
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/markdown_condense_approver"),
        runner=None,
    )
    phase_state = QueenPhaseState(phase="staging")
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-3",
        phase_state=phase_state,
    )

    result_raw = await registry._tools["run_agent_with_input"].executor(
        {
            "task": (
                "docs_path: docs/ review_path: docs_reviews/ word_threshold: 800 "
                "style_rules: Preserve headings, keep links intact."
            )
        }
    )
    result = json.loads(result_raw)

    assert result["status"] == "started"
    runtime.trigger.assert_awaited_once()
    trigger_kwargs = runtime.trigger.await_args.kwargs
    assert trigger_kwargs["input_data"] == {
        "docs_path": str((tmp_path / "docs").resolve()),
        "review_path": str((tmp_path / "docs_reviews").resolve()),
        "word_threshold": 800,
        "style_rules": "Preserve headings, keep links intact.",
    }
    assert trigger_kwargs["session_state"] is None
    runtime._get_primary_session_state.assert_not_called()


@pytest.mark.asyncio
async def test_rerun_worker_with_last_input_reuses_complete_recent_defaults(
    monkeypatch, tmp_path: Path
) -> None:
    registry = ToolRegistry()
    _register_fake_validator(registry, {"valid": True, "steps": {}})

    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    sessions_dir = tmp_path / "agent_store" / "sessions"
    sessions_dir.mkdir(parents=True)

    valid_prior_state = {
        "timestamps": {"updated_at": "2026-03-24T20:44:00"},
        "input_data": {
            "target_dir": "docs",
            "review_dir": "docs_reviews",
            "word_threshold": 800,
            "feedback": "stale",
        },
    }
    malformed_recent_state = {
        "timestamps": {"updated_at": "2026-03-24T21:20:23"},
        "input_data": {
            "review_dir": "docs_reviews",
            "word_threshold": "800. Validate inputs and continue.",
        },
    }

    for session_name, state in {
        "session_20260324_204400_good": valid_prior_state,
        "session_20260324_212023_bad": malformed_recent_state,
    }.items():
        session_dir = sessions_dir / session_name
        session_dir.mkdir()
        (session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")

    runtime = SimpleNamespace(
        _session_store=SimpleNamespace(sessions_dir=sessions_dir),
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-rerun"),
        graph=SimpleNamespace(
            nodes=[],
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
                if node_id == "process"
                else None
            ),
        ),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/local_markdown_review_probe_2"),
        runner=None,
    )
    phase_state = QueenPhaseState(phase="staging")
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-rerun",
        phase_state=phase_state,
    )

    result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
    result = json.loads(result_raw)

    assert result["status"] == "started"
    runtime.trigger.assert_awaited_once()
    trigger_kwargs = runtime.trigger.await_args.kwargs
    assert trigger_kwargs["input_data"] == {
        "target_dir": str((tmp_path / "docs").resolve()),
        "review_dir": str((tmp_path / "docs_reviews").resolve()),
        "word_threshold": 800,
    }
    assert trigger_kwargs["session_state"] is None


@pytest.mark.asyncio
async def test_rerun_worker_with_last_input_preserves_legacy_task_payload(
    monkeypatch, tmp_path: Path
) -> None:
    registry = ToolRegistry()
    _register_fake_validator(registry, {"valid": True, "steps": {}})

    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    sessions_dir = tmp_path / "agent_store" / "sessions"
    sessions_dir.mkdir(parents=True)
    session_dir = sessions_dir / "session_20260324_204400_task"
    session_dir.mkdir()
    (session_dir / "state.json").write_text(
        json.dumps(
            {
                "timestamps": {"updated_at": "2026-03-24T20:44:00"},
                "input_data": {"task": "re-run the markdown review"},
            }
        ),
        encoding="utf-8",
    )

    runtime = SimpleNamespace(
        _session_store=SimpleNamespace(sessions_dir=sessions_dir),
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-rerun"),
        graph=SimpleNamespace(
            nodes=[],
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["task"]) if node_id == "process" else None
            ),
        ),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/legacy_worker"),
        runner=None,
    )
    phase_state = QueenPhaseState(phase="staging")
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-rerun",
        phase_state=phase_state,
    )

    result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
    result = json.loads(result_raw)

    assert result["status"] == "started"
    trigger_kwargs = runtime.trigger.await_args.kwargs
    assert trigger_kwargs["input_data"] == {"task": "re-run the markdown review"}


@pytest.mark.asyncio
async def test_rerun_worker_with_last_input_uses_current_session_defaults_only(
    monkeypatch, tmp_path: Path
) -> None:
    registry = ToolRegistry()
    _register_fake_validator(registry, {"valid": True, "steps": {}})

    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    sessions_dir = tmp_path / "agent_store" / "sessions"
    current_session_dir = sessions_dir / "sess-rerun-current"
    current_session_dir.mkdir(parents=True)
    (current_session_dir / "state.json").write_text(
        json.dumps(
            {
                "timestamps": {"updated_at": "2026-03-24T20:44:00"},
                "input_data": {"target_dir": str((tmp_path / "current").resolve())},
            }
        ),
        encoding="utf-8",
    )

    other_session_dir = sessions_dir / "session_20260325_204400_other"
    other_session_dir.mkdir(parents=True)
    (other_session_dir / "state.json").write_text(
        json.dumps(
            {
                "timestamps": {"updated_at": "2026-03-25T20:44:00"},
                "input_data": {
                    "target_dir": str((tmp_path / "other").resolve()),
                    "review_dir": str((tmp_path / "other_reviews").resolve()),
                },
            }
        ),
        encoding="utf-8",
    )

    runtime = SimpleNamespace(
        _session_store=SimpleNamespace(sessions_dir=sessions_dir),
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-rerun"),
        graph=SimpleNamespace(
            nodes=[],
            entry_node="process",
            get_node=lambda node_id: (
                SimpleNamespace(input_keys=["target_dir", "review_dir"])
                if node_id == "process"
                else None
            ),
        ),
        get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/docs_reviewer"),
        runner=None,
    )
    phase_state = QueenPhaseState(phase="staging")
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-rerun-current",
        phase_state=phase_state,
    )

    result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
    result = json.loads(result_raw)

    assert (
        result["error"]
        == "No complete previous worker input is available for a same-defaults rerun."
    )
    assert result["missing_inputs"] == ["review_dir"]


@pytest.mark.asyncio
async def test_run_agent_with_input_blocks_when_validator_output_is_undecodable(
    monkeypatch, tmp_path: Path
) -> None:
    registry = ToolRegistry()
    registry.register(
        "validate_agent_package",
        Tool(
            name="validate_agent_package",
            description="fake validator",
            parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
        ),
        lambda _inputs: "not-json",
    )

    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    runtime = SimpleNamespace(
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-2"),
        graph=SimpleNamespace(nodes=[]),
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/docs_sanitizer_agent"),
        runner=None,
    )
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-invalid-validator",
        phase_state=QueenPhaseState(phase="staging"),
    )

    result_raw = await registry._tools["run_agent_with_input"].executor({"task": "run it"})
    result = json.loads(result_raw)

    assert "validation is failing" in result["error"]
    assert result["validation_failures"] == [
        "validator_subprocess: validate_agent_package returned an invalid or undecodable report"
    ]
    runtime.trigger.assert_not_awaited()


@pytest.mark.asyncio
async def test_start_worker_starts_fresh_worker_session(monkeypatch, tmp_path: Path) -> None:
    registry = ToolRegistry()
    monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
    monkeypatch.chdir(tmp_path)

    runtime = SimpleNamespace(
        resume_timers=MagicMock(),
        trigger=AsyncMock(return_value="exec-2"),
        _get_primary_session_state=MagicMock(return_value={"resume_session_id": "old"}),
        graph=SimpleNamespace(nodes=[]),
    )
    session = SimpleNamespace(
        worker_runtime=runtime,
        event_bus=None,
        worker_path=Path("exports/docs_sanitizer_agent"),
        runner=None,
    )
    register_queen_lifecycle_tools(
        registry,
        session=session,
        session_id="sess-4",
        phase_state=QueenPhaseState(phase="staging"),
    )

    result_raw = await registry._tools["start_worker"].executor(
        {"task": "run with docs_path: docs/"}
    )
    result = json.loads(result_raw)

    assert result["status"] == "started"
    runtime.trigger.assert_awaited_once()
    trigger_kwargs = runtime.trigger.await_args.kwargs
    assert trigger_kwargs["session_state"] is None
    runtime._get_primary_session_state.assert_not_called()
@@ -1,11 +1,14 @@
from __future__ import annotations

import asyncio
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from framework.runtime.event_bus import EventBus
import framework.agents.worker_memory as worker_memory
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.session_manager import Session, SessionManager


@@ -123,3 +126,52 @@ async def test_stop_session_unsubscribes_worker_handoff() -> None:
        reason="after stop",
    )
    assert queen_node.inject_event.await_count == 1


@pytest.mark.asyncio
async def test_worker_digest_final_completion_does_not_overwrite_terminal_result(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    bus = EventBus()
    manager = SessionManager()
    session = _make_session(bus, session_id="session_digest_final")
    session.worker_path = Path("/tmp/log_triage_agent")

    queen_node = SimpleNamespace(inject_event=AsyncMock())
    session.queen_executor = _make_executor(queen_node)

    consolidate = AsyncMock()
    monkeypatch.setattr(worker_memory, "consolidate_worker_run", consolidate)

    manager._subscribe_worker_digest(session)

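    # get_history is stubbed so the digest subscriber sees an in-flight
    # execution for this run without needing a real event store behind it.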
    bus.get_history = lambda event_type=None, limit=None: [
        AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="default",
            execution_id="exec_digest",
            run_id="run_digest",
        )
    ]

    await bus.publish(
        AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="default",
            execution_id="exec_digest",
            run_id="run_digest",
        )
    )
    await bus.publish(
        AgentEvent(
            type=EventType.EXECUTION_COMPLETED,
            stream_id="default",
            execution_id="exec_digest",
            run_id="run_digest",
            data={"output": {"result": "final answer"}},
        )
    )
    await asyncio.sleep(0.05)

    consolidate.assert_awaited_once()
    assert queen_node.inject_event.await_count == 0

@@ -0,0 +1,113 @@
from __future__ import annotations

from pathlib import Path
from unittest.mock import MagicMock

import pytest

import framework.runner as runner_mod
from framework.server import session_manager as sm


@pytest.mark.asyncio
async def test_load_worker_blocks_invalid_package_before_runner_load(monkeypatch) -> None:
    manager = sm.SessionManager()
    session = sm.Session(
        id="sess-1",
        event_bus=MagicMock(),
        llm=MagicMock(),
        loaded_at=0.0,
    )
    manager._sessions[session.id] = session

    captured: dict[str, str] = {}

    def _fake_validation(agent_ref):
        captured["agent_ref"] = str(agent_ref)
        return {
            "valid": False,
            "steps": {
                "behavior_validation": {
                    "passed": False,
                    "errors": ["Node 'scan' has a blank or placeholder system_prompt"],
                }
            },
        }

    monkeypatch.setattr(
        sm,
        "_run_validation_report_sync",
        _fake_validation,
    )

    called = {"runner_load": False}

    class _FakeAgentRunner:
        @staticmethod
        def load(*args, **kwargs):
            called["runner_load"] = True
            raise AssertionError("AgentRunner.load should not run for invalid workers")

    monkeypatch.setattr(runner_mod, "AgentRunner", _FakeAgentRunner)

    with pytest.raises(sm.WorkerValidationError) as exc:
        await manager.load_worker(session.id, Path("/tmp/bad_worker"))

    assert "blank or placeholder system_prompt" in str(exc.value)
    assert Path(captured["agent_ref"]).as_posix() == "/tmp/bad_worker"
    assert called["runner_load"] is False


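# _run_validation_report_sync is assumed to shell out to the validator
# (e.g. via `uv run python -c <script> ... <agent_path>`); the index
# assertions below pick the inline script and the agent path out of the
# captured command line.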
def test_run_validation_report_sync_uses_internal_validator_impl(monkeypatch) -> None:
    captured: dict[str, object] = {}

    class _Proc:
        returncode = 0
        stdout = '{"valid": true, "steps": {}}'
        stderr = ""

    def _fake_run(cmd, **kwargs):
        captured["cmd"] = cmd
        return _Proc()

    monkeypatch.setattr(sm.subprocess, "run", _fake_run)

    report = sm._run_validation_report_sync("/tmp/demo_agent")

    assert report["valid"] is True
    script = captured["cmd"][4]
    assert "_validate_agent_package_impl" in script
    assert "validate_agent_package(agent_name)" not in script
    assert captured["cmd"][6] == "/tmp/demo_agent"


def test_validation_blocks_stage_or_run_ignores_non_blocking_warnings() -> None:
    report = {
        "steps": {
            "behavior_validation": {
                "passed": True,
                "warnings": ["placeholder prompt"],
                "output": "placeholder prompt",
            },
            "tests": {
                "passed": True,
                "warnings": ["1 failed"],
                "summary": "1 failed",
            },
        },
    }

    assert sm._validation_blocks_stage_or_run(report) is False


def test_run_validation_report_sync_handles_subprocess_launcher_errors(monkeypatch) -> None:
    def _boom(*args, **kwargs):
        raise FileNotFoundError("uv not found")

    monkeypatch.setattr(sm.subprocess, "run", _boom)

    report = sm._run_validation_report_sync("/tmp/demo_agent")

    assert report["valid"] is False
    assert report["steps"]["validator_subprocess"]["passed"] is False
    assert "uv not found" in report["steps"]["validator_subprocess"]["error"]
@@ -331,7 +331,7 @@ def _make_codex_provider():
    if not api_key or not api_base:
        return None
    return LiteLLMProvider(
        model="openai/gpt-5.3-codex",
        model="openai/gpt-5.4",
        api_key=api_key,
        api_base=api_base,
        **extra_kwargs,

@@ -123,6 +123,28 @@ class TestValidateAgentPathPositive:
        result = validate_agent_path(str(agent_dir))
        assert isinstance(result, Path)

    def test_repo_relative_path_resolves_from_repo_root_not_cwd(self, tmp_path, monkeypatch):
        import framework.server.app as app_module

        repo_root = tmp_path / "repo"
        examples_root = repo_root / "examples"
        agent_dir = examples_root / "some_agent"
        agent_dir.mkdir(parents=True)
        other_cwd = tmp_path / "elsewhere"
        other_cwd.mkdir()

        monkeypatch.setattr(app_module, "_REPO_ROOT", repo_root)
        app_module._ALLOWED_AGENT_ROOTS = (
            repo_root / "exports",
            examples_root,
            tmp_path / ".hive" / "agents",
        )
        monkeypatch.chdir(other_cwd)

        result = validate_agent_path("examples/some_agent")

        assert result == agent_dir.resolve()


# ---------------------------------------------------------------------------
# validate_agent_path: negative cases (should raise ValueError)

@@ -0,0 +1,129 @@
from __future__ import annotations

import json
import os
import time
from pathlib import Path

import pytest

from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import EventBus
from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools


def _write_session_logs(
    storage_path: Path,
    session_id: str,
    *,
    session_status: str,
    steps: list[dict],
) -> Path:
    session_dir = storage_path / "sessions" / session_id
    logs_dir = session_dir / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    (session_dir / "state.json").write_text(
        json.dumps({"status": session_status}),
        encoding="utf-8",
    )
    log_path = logs_dir / "tool_logs.jsonl"
    log_path.write_text(
        "".join(json.dumps(step) + "\n" for step in steps),
        encoding="utf-8",
    )
    return log_path


@pytest.mark.asyncio
async def test_worker_health_summary_marks_healthy_runs(tmp_path: Path) -> None:
    registry = ToolRegistry()
    event_bus = EventBus()
    storage_path = tmp_path / "agent_store"
    storage_path.mkdir(parents=True, exist_ok=True)

    _write_session_logs(
        storage_path,
        "session-healthy",
        session_status="running",
        steps=[
            {"verdict": "RETRY", "llm_text": "first pass"},
            {"verdict": "ACCEPT", "llm_text": "done"},
        ],
    )

    register_worker_monitoring_tools(
        registry,
        event_bus,
        storage_path,
        default_session_id="session-healthy",
    )

    raw = await registry._tools["get_worker_health_summary"].executor({})
    data = json.loads(raw)

    assert data["health_status"] == "healthy"
    assert data["issue_signals"] == []
    assert data["recent_verdicts"] == ["RETRY", "ACCEPT"]
    assert data["steps_since_last_accept"] == 0


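# The stall signal below keys off the log file's mtime: backdating it with
# os.utime is what trips the "stalled" classification.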
@pytest.mark.asyncio
async def test_worker_health_summary_flags_stall_and_non_accept_churn(tmp_path: Path) -> None:
    registry = ToolRegistry()
    event_bus = EventBus()
    storage_path = tmp_path / "agent_store"
    storage_path.mkdir(parents=True, exist_ok=True)

    log_path = _write_session_logs(
        storage_path,
        "session-stalled",
        session_status="running",
        steps=[
            {"verdict": "CONTINUE", "llm_text": "thinking"},
            {"verdict": "RETRY", "llm_text": "still working"},
            {"verdict": "RETRY", "llm_text": "trying again"},
            {"verdict": "ESCALATE", "llm_text": "blocked"},
        ],
    )
    ten_minutes_ago = time.time() - 600
    os.utime(log_path, (ten_minutes_ago, ten_minutes_ago))

    register_worker_monitoring_tools(
        registry,
        event_bus,
        storage_path,
        default_session_id="session-stalled",
    )

    raw = await registry._tools["get_worker_health_summary"].executor({})
    data = json.loads(raw)

    assert data["health_status"] == "critical"
    assert "stalled" in data["issue_signals"]
    assert "judge_pressure" in data["issue_signals"]
    assert "recent_non_accept_churn" in data["issue_signals"]
    assert data["steps_since_last_accept"] == 4
    assert data["stall_minutes"] is not None


@pytest.mark.asyncio
async def test_worker_health_summary_errors_when_default_session_has_no_state(
    tmp_path: Path,
) -> None:
    registry = ToolRegistry()
    event_bus = EventBus()
    storage_path = tmp_path / "agent_store"
    (storage_path / "sessions" / "session-stale").mkdir(parents=True, exist_ok=True)

    register_worker_monitoring_tools(
        registry,
        event_bus,
        storage_path,
        default_session_id="session-stale",
    )

    raw = await registry._tools["get_worker_health_summary"].executor({})
    data = json.loads(raw)

    assert "error" in data
    assert "session-stale" in data["error"]

+3 -3
@@ -1279,9 +1279,9 @@ switch ($num) {
    if ($CodexCredDetected) {
        $SubscriptionMode = "codex"
        $SelectedProviderId = "openai"
        $SelectedModel = "gpt-5.3-codex"
        $SelectedMaxTokens = 16384
        $SelectedMaxContextTokens = 120000
        $SelectedModel = "gpt-5.4"
        $SelectedMaxTokens = 128000
        $SelectedMaxContextTokens = 900000
        Write-Host ""
        Write-Ok "Using OpenAI Codex subscription"
    }

+3 -3
@@ -1306,9 +1306,9 @@ case $choice in
    if [ "$CODEX_CRED_DETECTED" = true ]; then
        SUBSCRIPTION_MODE="codex"
        SELECTED_PROVIDER_ID="openai"
        SELECTED_MODEL="gpt-5.3-codex"
        SELECTED_MAX_TOKENS=16384
        SELECTED_MAX_CONTEXT_TOKENS=120000 # GPT Codex — 128k context window
        SELECTED_MODEL="gpt-5.4"
        SELECTED_MAX_TOKENS=128000
        SELECTED_MAX_CONTEXT_TOKENS=900000 # GPT-5.4 — 1.05M context window
        echo ""
        echo -e "${GREEN}⬢${NC} Using OpenAI Codex subscription"
    fi

@@ -561,9 +561,9 @@ switch ($num) {
    if ($CodexCredDetected) {
        $SubscriptionMode = "codex"
        $SelectedProviderId = "openai"
        $SelectedModel = "gpt-5.3-codex"
        $SelectedMaxTokens = 16384
        $SelectedMaxContextTokens = 120000
        $SelectedModel = "gpt-5.4"
        $SelectedMaxTokens = 128000
        $SelectedMaxContextTokens = 900000
        Write-Host ""
        Write-Ok "Using OpenAI Codex subscription"
    }

@@ -870,9 +870,9 @@ case $choice in
    if [ "$CODEX_CRED_DETECTED" = true ]; then
        SUBSCRIPTION_MODE="codex"
        SELECTED_PROVIDER_ID="openai"
        SELECTED_MODEL="gpt-5.3-codex"
        SELECTED_MAX_TOKENS=16384
        SELECTED_MAX_CONTEXT_TOKENS=120000 # GPT Codex — 128k context window
        SELECTED_MODEL="gpt-5.4"
        SELECTED_MAX_TOKENS=128000
        SELECTED_MAX_CONTEXT_TOKENS=900000 # GPT-5.4 — 1.05M context window
        echo ""
        echo -e "${GREEN}⬢${NC} Using OpenAI Codex subscription"
    fi

+588 -83
@@ -68,6 +68,15 @@ mcp = FastMCP("coder-tools")

PROJECT_ROOT: str = ""
SNAPSHOT_DIR: str = ""
_PLACEHOLDER_MARKERS = (
    "TODO",
    "TODO:",
    "TODO ",
    "Add system prompt for this node",
    "Add identity prompt",
    "Define success criteria",
    "Describe what this node does",
)


# ── Path resolution ───────────────────────────────────────────────────────
@@ -138,6 +147,388 @@ def _resolve_path(path: str) -> str:
    return resolved


def _is_placeholder_text(value: str | None) -> bool:
    text = (value or "").strip()
    if not text:
        return True
    return any(marker in text for marker in _PLACEHOLDER_MARKERS)


_ENTRY_INTAKE_HINTS = (
    "parse the incoming task",
    "parse the task text",
    "structured runtime task",
    "accept structured runtime task",
    "read runtime task input",
    "validate runtime",
    "validate runtime path",
    "validate runtime paths",
    "intake & validate",
    "configuration values",
    "intake config",
    "infer values conservatively",
)
_ENTRY_DIRECT_WORK_HINTS = (
    "scan",
    "scanning",
    "discover",
    "discovery",
    "search",
    "fetch",
    "analyze",
    "analyse",
    "transform",
    "sanitize",
    "summarize",
    "summarise",
    "generate",
    "write",
    "apply",
    "candidate",
)
_TOOL_ALIAS_HINTS = {
    "run_command": "execute_command_tool",
}
_OUTPUT_DIRECTORY_INPUT_HINTS = {
    "review_dir",
    "output_dir",
    "destination_dir",
    "dest_dir",
    "target_dir",
    "review_path",
    "output_path",
    "destination_path",
    "dest_path",
    "target_path",
}
_SESSION_DATA_TOOLS = frozenset(
    {
        "save_data",
        "load_data",
        "list_data_files",
        "append_data",
        "edit_data",
        "serve_file_to_user",
    }
)
_WORKSPACE_PATH_HINTS = frozenset(
    {
        "review_dir",
        "review_root",
        "output_dir",
        "output_root",
        "output_path",
        "target_dir",
        "target_root",
        "target_path",
        "workspace",
        "project folder",
        "project folders",
    }
)
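# Matches a session-data tool and a workspace-path hint in either order, as
# long as both fall inside one sentence (no '.' or newline between them, at
# most 200 characters apart).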
_SESSION_DATA_TOOL_PATH_OP_RE = re.compile(
|
||||
r"(save_data|load_data|list_data_files|append_data|edit_data|serve_file_to_user)"
|
||||
r"[^.\n]{0,200}\b(?:to|into|in|inside|under|within|from|at|on)\s+(?:the\s+)?"
|
||||
r"(review_dir|review_root|output_dir|output_root|output_path|target_dir|target_root|"
|
||||
r"target_path|workspace|project folder|project folders)\b|"
|
||||
r"\b(?:to|into|in|inside|under|within|from|at|on)\s+(?:the\s+)?"
|
||||
r"(review_dir|review_root|output_dir|output_root|output_path|target_dir|target_root|"
|
||||
r"target_path|workspace|project folder|project folders)\b[^.\n]{0,200}"
|
||||
r"(save_data|load_data|list_data_files|append_data|edit_data|serve_file_to_user)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _contains_hint_word(text: str, hint: str) -> bool:
|
||||
"""Return True when *hint* appears as a word/phrase, not just a substring."""
|
||||
if " " in hint:
|
||||
return hint in text
|
||||
return re.search(rf"\b{re.escape(hint)}\b", text) is not None


def _default_intro_message(human_name: str, description: str) -> str:
    """Return a non-placeholder intro line for generated agents."""
    desc = (description or "").strip().rstrip(".")
    if desc:
        return f"{desc}."
    return f"Ready to run {human_name}."


def _default_success_metric(index: int) -> str:
    """Return a generic but non-placeholder success metric name."""
    return f"criterion_{index}_satisfied"


def _looks_like_agent_path(agent_ref: str) -> bool:
    """Return True when *agent_ref* should be treated as a filesystem path."""
    candidate = Path(agent_ref).expanduser()
    return candidate.is_absolute() or len(candidate.parts) > 1 or agent_ref.startswith((".", "~"))


def _resolve_agent_package_target(agent_ref: str) -> tuple[Path, str, str]:
    """Resolve a validator target to (agent_dir, package_name, display_ref).

    Bare names still target exports/<name> for build-time validation.
    Paths are resolved relative to the repository root and must pass the
    server allowlist so existing example agents can be staged safely.
    """
    ref = (agent_ref or "").strip()
    if not ref:
        raise ValueError("Agent reference is required")

    if not PROJECT_ROOT:
        raise ValueError("PROJECT_ROOT is not configured")

    if _looks_like_agent_path(ref):
        resolved = Path(_resolve_path(ref))
        try:
            from framework.server.app import validate_agent_path
        except ImportError as exc:
            raise ValueError("Cannot validate agent path: framework package not available") from exc

        resolved = validate_agent_path(str(resolved))
        return resolved, resolved.name, str(resolved)

    resolved = (Path(PROJECT_ROOT) / "exports" / ref).resolve()
    return resolved, ref, f"exports/{ref}"
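The two resolution modes in practice, as an illustrative sketch; `coder_tools` remains a stand-in name, and the path form additionally requires `framework.server.app` to be importable:

from coder_tools import _resolve_agent_package_target  # hypothetical name

# Bare name: always maps to exports/<name>, even before the directory exists.
agent_dir, package, ref = _resolve_agent_package_target("my_agent")
# -> (<PROJECT_ROOT>/exports/my_agent, "my_agent", "exports/my_agent")

# Path-like ref: resolved against the repo root, then allowlist-checked.
agent_dir, package, ref = _resolve_agent_package_target("examples/templates/my_agent")
# -> (allowlist-validated absolute Path, its directory name, str of that path)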


def _default_success_target() -> str:
    """Return a generic non-placeholder success target."""
    return "true"


def _node_can_progress_without_declared_tools(node) -> bool:
    """Return True when a node can legitimately work without MCP/local tools.

    Runtime supports two common cases that should not be blocked by static
    validation:
    - ``gcu`` nodes, whose browser tools are injected by the framework.
    - pure LLM work nodes that consume inputs and explicitly write outputs via
      ``set_output`` without needing external tools.
    """
    if getattr(node, "node_type", "") == "gcu":
        return True

    output_keys = list(getattr(node, "output_keys", []) or [])
    if not output_keys:
        return False

    prompt = (getattr(node, "system_prompt", "") or "").lower()
    text = " ".join(
        filter(
            None,
            [
                getattr(node, "name", "") or "",
                getattr(node, "description", "") or "",
                getattr(node, "system_prompt", "") or "",
            ],
        )
    ).lower()
    mentions_set_output = (
        "set_output(" in prompt
        or "call set_output" in prompt
        or "use set_output" in prompt
        or "set_output " in prompt
    )
    looks_like_real_work = any(_contains_hint_word(text, hint) for hint in _ENTRY_DIRECT_WORK_HINTS)
    return mentions_set_output and looks_like_real_work
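A stub check of the waiver; SimpleNamespace stands in for the framework's real node class, and the import name is again hypothetical:

from types import SimpleNamespace

from coder_tools import _node_can_progress_without_declared_tools

node = SimpleNamespace(
    node_type="llm",
    output_keys=["findings"],
    name="Scan repo",
    description="Scan the repository for issues",
    system_prompt="Scan each file, then call set_output('findings', ...).",
)
# Real work verb + explicit set_output + declared outputs -> waived.
assert _node_can_progress_without_declared_tools(node)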


def _behavior_validation_errors(agent_module) -> list[str]:
    """Return behavior-level validation errors for a generated agent package."""
    errors: list[str] = []
    nodes = list(getattr(agent_module, "nodes", []) or [])
    terminal_ids = set(getattr(agent_module, "terminal_nodes", []) or [])
    entry_node_id = getattr(agent_module, "entry_node", None) or ""
    identity_prompt = getattr(agent_module, "identity_prompt", "") or ""
    metadata = getattr(agent_module, "metadata", None)
    goal = getattr(agent_module, "goal", None)

    identity_prompt_text = identity_prompt.strip()
    if not identity_prompt_text:
        errors.append("identity_prompt is blank")
    elif any(marker in identity_prompt_text for marker in _PLACEHOLDER_MARKERS):
        errors.append("identity_prompt still contains TODO placeholders")

    if metadata is not None:
        if _is_placeholder_text(getattr(metadata, "description", "") or ""):
            errors.append("metadata.description is blank or still contains TODO placeholders")
        if _is_placeholder_text(getattr(metadata, "intro_message", "") or ""):
            errors.append("metadata.intro_message is blank or still contains TODO placeholders")

    if goal is not None:
        if _is_placeholder_text(getattr(goal, "description", "") or ""):
            errors.append("goal.description is blank or still contains TODO placeholders")
        for criterion in list(getattr(goal, "success_criteria", []) or []):
            cid = getattr(criterion, "id", "<unknown>")
            for attr in ("description", "metric", "target"):
                if _is_placeholder_text(getattr(criterion, attr, "") or ""):
                    errors.append(f"Success criterion '{cid}' has blank or placeholder {attr}")
        for constraint in list(getattr(goal, "constraints", []) or []):
            cid = getattr(constraint, "id", "<unknown>")
            if _is_placeholder_text(getattr(constraint, "description", "") or ""):
                errors.append(f"Constraint '{cid}' has blank or placeholder description")

    for node in nodes:
        node_id = getattr(node, "id", "<unknown>")
        node_desc = getattr(node, "description", "") or ""
        if _is_placeholder_text(node_desc):
            errors.append(f"Node '{node_id}' has a blank or placeholder description")

        prompt = getattr(node, "system_prompt", "") or ""
        prompt_lower = prompt.lower()
        if _is_placeholder_text(prompt):
            errors.append(f"Node '{node_id}' has a blank or placeholder system_prompt")
        else:
            tools = list(getattr(node, "tools", []) or [])
            for tool_name in tools:
                if isinstance(tool_name, str) and f"{tool_name}(" in prompt:
                    errors.append(
                        f"Node '{node_id}' system_prompt uses callable-style tool syntax for "
                        f"'{tool_name}'. Describe tool usage in prose instead of "
                        "Python-style calls."
                    )
            for alias, actual in _TOOL_ALIAS_HINTS.items():
                if alias in prompt and actual in tools and alias not in tools:
                    errors.append(
                        f"Node '{node_id}' system_prompt references unsupported tool alias "
                        f"'{alias}'. Use the actual registered tool name '{actual}'."
                    )
            data_tools_used = [tool for tool in tools if tool in _SESSION_DATA_TOOLS]
            if data_tools_used:
                workspace_path_hints = []
                if _SESSION_DATA_TOOL_PATH_OP_RE.search(prompt):
                    workspace_path_hints = [
                        hint
                        for hint in _WORKSPACE_PATH_HINTS
                        if _contains_hint_word(prompt_lower, hint)
                    ]
                if workspace_path_hints:
                    joined_tools = ", ".join(sorted(data_tools_used))
                    joined_hints = ", ".join(sorted(workspace_path_hints))
                    errors.append(
                        f"Node '{node_id}' uses session data tools ({joined_tools}) as if they "
                        f"can operate on workspace paths ({joined_hints}). Data tools use the "
                        "framework-managed session data directory; use execute_command_tool for "
                        "workspace/review/output directories."
                    )

        success_criteria = getattr(node, "success_criteria", "") or ""
        if _is_placeholder_text(success_criteria):
            errors.append(f"Node '{node_id}' has blank or placeholder success_criteria")

        tools = list(getattr(node, "tools", []) or [])
        sub_agents = list(getattr(node, "sub_agents", []) or [])
        client_facing = bool(getattr(node, "client_facing", False))
        if (
            node_id not in terminal_ids
            and not client_facing
            and not tools
            and not sub_agents
            and not _node_can_progress_without_declared_tools(node)
        ):
            errors.append(f"Autonomous node '{node_id}' has no tools or sub_agents")

        if node_id == entry_node_id:
            input_keys = list(getattr(node, "input_keys", []) or [])
            output_keys = list(getattr(node, "output_keys", []) or [])
            text = " ".join(
                filter(
                    None,
                    [
                        getattr(node, "name", "") or "",
                        node_desc,
                        prompt,
                    ],
                )
            ).lower()
            lowered_input_keys = {str(key).lower() for key in input_keys}
            lowered_output_keys = {str(key).lower() for key in output_keys}
            generic_task_only = len(input_keys) == 1 and input_keys[0] in {
                "task",
                "user_request",
                "raw",
                "input",
                "request",
                "message",
            }
            intake_like = any(hint in text for hint in _ENTRY_INTAKE_HINTS)
            direct_work_like = any(
                _contains_hint_word(text, hint) for hint in _ENTRY_DIRECT_WORK_HINTS
            )
            pass_through_inputs = bool(lowered_input_keys) and (
                lowered_input_keys <= lowered_output_keys
            )
            runtime_normalization_only = any(
                hint in text
                for hint in (
                    "validate",
                    "validation",
                    "normalize",
                    "normalise",
                    "config",
                    "configuration",
                    "runtime",
                    "path",
                    "paths",
                )
            )
            if (
                intake_like
                and not direct_work_like
                and (generic_task_only or pass_through_inputs or runtime_normalization_only)
            ):
                errors.append(
                    f"Entry node '{node_id}' appears to be an intake/config parser. "
                    "The queen handles intake. Make the first real work node consume "
                    "structured input_keys directly instead of reparsing a generic task string."
                )
            for input_key in input_keys:
                lowered_key = str(input_key).lower()
                if lowered_key not in _OUTPUT_DIRECTORY_INPUT_HINTS:
                    continue
                if (
                    lowered_key in text
                    and "exist" in text
                    and ("directory" in text or "directories" in text)
                ):
                    errors.append(
                        f"Entry node '{node_id}' requires output path '{input_key}' to pre-exist. "
                        "Output/review directories should be created if missing instead of "
                        "blocking the run during intake validation."
                    )
                    break

    return errors
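To see the whole check end to end, here is a minimal sketch that feeds a stubbed agent module through the validator; `coder_tools` is a stand-in import name and SimpleNamespace stands in for the framework's real spec classes:

from types import SimpleNamespace

from coder_tools import _behavior_validation_errors  # hypothetical name

node = SimpleNamespace(
    id="work",
    name="Work",
    description="TODO: Describe what this node does",
    system_prompt="Scan the inputs and call set_output with the findings.",
    success_criteria="All findings recorded",
    tools=[],
    sub_agents=[],
    client_facing=False,
    node_type="llm",
    input_keys=["task"],
    output_keys=["findings"],
)
agent = SimpleNamespace(
    nodes=[node],
    terminal_nodes=[],
    entry_node="work",
    identity_prompt="You are a worker.",
    metadata=None,
    goal=None,
)
print(_behavior_validation_errors(agent))
# ["Node 'work' has a blank or placeholder description"]
# The missing-tools check is waived here: the prompt names real work
# ("Scan ...") and explicitly uses set_output.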


def _classify_behavior_validation_errors(errors: list[str]) -> tuple[list[str], list[str]]:
    """Split behavior validation findings into blocking errors and warnings.

    Hard failures are reserved for issues that are likely to break runtime
    execution or violate framework contracts. Quality/style issues remain
    warnings so they can be surfaced without preventing staging/runs.
    """
    blocking_markers = (
        "identity_prompt still contains TODO placeholders",
        "blank or placeholder system_prompt",
        "uses session data tools",
        "Autonomous node ",
        "appears to be an intake/config parser",
        "requires output path",
    )

    blocking: list[str] = []
    warnings: list[str] = []
    for error in errors:
        if any(marker in error for marker in blocking_markers):
            blocking.append(error)
        else:
            warnings.append(error)
    return blocking, warnings
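Continuing the stub example above, the split behaves like this (same hypothetical import name):

from coder_tools import _classify_behavior_validation_errors

blocking, warnings = _classify_behavior_validation_errors([
    "Node 'work' has a blank or placeholder description",
    "Autonomous node 'fetch' has no tools or sub_agents",
])
assert blocking == ["Autonomous node 'fetch' has no tools or sub_agents"]
assert warnings == ["Node 'work' has a blank or placeholder description"]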


# ── Git snapshot system (ported from opencode's shadow git) ───────────────

@@ -752,9 +1143,14 @@ def list_agent_tools(


def _validate_agent_tools_impl(agent_path: str) -> dict:
    """Validate that all tools declared in an agent's nodes exist in its MCP servers.
    """Validate that all tools declared in an agent's nodes exist at runtime.

    Returns a dict with validation result: pass/fail, missing tools per node, available tools.
    Mirrors runtime tool discovery:
    1. MCP tools from ``mcp_servers.json`` (when present)
    2. Agent-local ``tools.py`` custom tools (when present)

    Returns a dict with validation result: pass/fail, missing tools per node,
    available tools, and discovery warnings.
    """
    try:
        resolved = _resolve_path(agent_path)

@@ -781,11 +1177,8 @@ def _validate_agent_tools_impl(agent_path: str) -> dict:

    agent_dir = resolved  # Keep path; 'resolved' is reused for MCP config in loop

    # --- Discover available tools from agent's MCP servers ---
    # --- Discover available tools from MCP + local tools.py ---
    mcp_config_path = os.path.join(agent_dir, "mcp_servers.json")
    if not os.path.isfile(mcp_config_path):
        return {"error": f"No mcp_servers.json found in {agent_path}"}

    try:
        from pathlib import Path

@@ -796,36 +1189,45 @@ def _validate_agent_tools_impl(agent_path: str) -> dict:

    available_tools: set[str] = set()
    discovery_errors = []
    config_dir = Path(mcp_config_path).parent

    try:
        with open(mcp_config_path, encoding="utf-8") as f:
            servers_config = json.load(f)
    except (json.JSONDecodeError, OSError) as e:
        return {"error": f"Failed to read mcp_servers.json: {e}"}

    for server_name, server_conf in servers_config.items():
        resolved = ToolRegistry.resolve_mcp_stdio_config(
            {"name": server_name, **server_conf}, config_dir
        )
    if os.path.isfile(mcp_config_path):
        config_dir = Path(mcp_config_path).parent
        try:
            config = MCPServerConfig(
                name=server_name,
                transport=resolved.get("transport", "stdio"),
                command=resolved.get("command"),
                args=resolved.get("args", []),
                env=resolved.get("env", {}),
                cwd=resolved.get("cwd"),
                url=resolved.get("url"),
                headers=resolved.get("headers", {}),
            with open(mcp_config_path, encoding="utf-8") as f:
                servers_config = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            return {"error": f"Failed to read mcp_servers.json: {e}"}

        for server_name, server_conf in servers_config.items():
            resolved = ToolRegistry.resolve_mcp_stdio_config(
                {"name": server_name, **server_conf}, config_dir
            )
            client = MCPClient(config)
            client.connect()
            for tool in client.list_tools():
                available_tools.add(tool.name)
            client.disconnect()
            try:
                config = MCPServerConfig(
                    name=server_name,
                    transport=resolved.get("transport", "stdio"),
                    command=resolved.get("command"),
                    args=resolved.get("args", []),
                    env=resolved.get("env", {}),
                    cwd=resolved.get("cwd"),
                    url=resolved.get("url"),
                    headers=resolved.get("headers", {}),
                )
                client = MCPClient(config)
                client.connect()
                for tool in client.list_tools():
                    available_tools.add(tool.name)
                client.disconnect()
            except Exception as e:
                discovery_errors.append({"server": server_name, "error": str(e)})

    local_tools_path = Path(agent_dir) / "tools.py"
    if local_tools_path.is_file():
        try:
            registry = ToolRegistry()
            registry.discover_from_module(local_tools_path)
            available_tools.update(registry.get_tools().keys())
        except Exception as e:
            discovery_errors.append({"server": server_name, "error": str(e)})
            discovery_errors.append({"server": "tools.py", "error": str(e)})

    # --- Load agent nodes and extract declared tools ---
    agent_py = os.path.join(agent_dir, "agent.py")
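Discovery failures are now collected instead of aborting the whole validation. The local-tools half mirrors the calls shown above; a rough sketch, where the ToolRegistry import path is an assumption not shown in this diff:

from pathlib import Path

from framework.tools.registry import ToolRegistry  # import path assumed

registry = ToolRegistry()
registry.discover_from_module(Path("exports/my_agent/tools.py"))  # path illustrative
print(sorted(registry.get_tools().keys()))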

@@ -1227,7 +1629,7 @@ def get_agent_checkpoint(


def _run_agent_tests_impl(
    agent_name: str,
    agent_ref: str,
    test_types: str = "all",
    fail_fast: bool = False,
) -> dict:

@@ -1235,22 +1637,35 @@ def _run_agent_tests_impl(

    Returns a dict with summary counts, per-test results, and failure details.
    """
    agent_path = Path(PROJECT_ROOT) / "exports" / agent_name
    try:
        agent_path, agent_name, display_ref = _resolve_agent_package_target(agent_ref)
    except ValueError as e:
        return {"error": str(e)}

    if not agent_path.is_dir():
        # Fall back to framework agents
        # Fall back to framework agents for bare framework package names.
        agent_path = Path(PROJECT_ROOT) / "core" / "framework" / "agents" / agent_name
    tests_dir = agent_path / "tests"

    if not agent_path.is_dir():
        return {
            "error": f"Agent not found: {agent_name}",
            "error": f"Agent not found: {agent_ref}",
            "hint": "Use list_agents() to see available agents.",
        }

    if not tests_dir.exists():
        return {
            "error": f"No tests directory: exports/{agent_name}/tests/",
            "hint": "Create test files in the tests/ directory first.",
            "agent_name": agent_name,
            "agent_path": str(agent_path),
            "summary": f"No tests directory: {tests_dir}",
            "passed": 0,
            "failed": 0,
            "skipped": 1,
            "errors": 0,
            "total": 0,
            "test_results": [],
            "failures": [],
            "skipped_all": True,
        }

    # Parse test types

@@ -1297,7 +1712,10 @@ def _run_agent_tests_impl(
    core_path = os.path.join(PROJECT_ROOT, "core")
    exports_path = os.path.join(PROJECT_ROOT, "exports")
    fw_agents_path = os.path.join(PROJECT_ROOT, "core", "framework", "agents")
    package_parent = str(agent_path.parent)
    path_parts = [core_path, exports_path, fw_agents_path, PROJECT_ROOT]
    if package_parent not in path_parts:
        path_parts.insert(1, package_parent)
    if pythonpath:
        path_parts.append(pythonpath)
    env["PYTHONPATH"] = os.pathsep.join(path_parts)
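A worked example of the resulting search path, assuming PROJECT_ROOT is /repo and an agent staged outside exports/; all paths here are illustrative:

import os

PROJECT_ROOT = "/repo"                       # illustrative
package_parent = "/repo/examples/templates"  # parent of the staged agent dir
path_parts = [
    os.path.join(PROJECT_ROOT, "core"),
    os.path.join(PROJECT_ROOT, "exports"),
    os.path.join(PROJECT_ROOT, "core", "framework", "agents"),
    PROJECT_ROOT,
]
if package_parent not in path_parts:
    path_parts.insert(1, package_parent)
print(os.pathsep.join(path_parts))
# On POSIX:
# /repo/core:/repo/examples/templates:/repo/exports:/repo/core/framework/agents:/repo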

@@ -1387,6 +1805,7 @@ def _run_agent_tests_impl(

    return {
        "agent_name": agent_name,
        "agent_path": str(agent_path),
        "summary": summary_text,
        "passed": passed,
        "failed": failed,

@@ -1425,28 +1844,50 @@ def run_agent_tests(
# ── Meta-agent: Unified agent validation ───────────────────────────────────


@mcp.tool()
def validate_agent_package(agent_name: str) -> str:
def _validate_agent_package_impl(agent_name: str) -> dict[str, object]:
    """Run structural validation checks on a built agent package in one call.

    Executes 5 steps and reports all results (does not stop on first failure):
    Executes multiple checks and reports all results (does not stop on first failure):
    1. Class validation — checks graph structure and entry_points contract
    2. Node completeness — every NodeSpec in nodes/ must be in the nodes list,
       and GCU nodes must be referenced in a parent's sub_agents
    3. Graph validation — loads the agent graph without credential checks
    4. Tool validation — checks declared tools exist in MCP servers
    5. Tests — runs the agent's pytest suite
    4. Behavior validation — rejects placeholder prompts and empty autonomous nodes
    5. Tool validation — checks declared tools exist in MCP servers
    6. Tests — runs the agent's pytest suite

    Note: Credential validation is intentionally skipped here (building phase).
    Credentials are validated at run time by run_agent_with_input() preflight.

    Args:
        agent_name: Agent package name (e.g. 'my_agent'). Must exist in exports/.
        agent_name: Agent package name (e.g. 'my_agent') or an allowed
            agent path such as examples/templates/my_agent.

    Returns:
        JSON with per-step results and overall pass/fail summary
        Dict with per-step results and overall pass/fail summary
    """
    agent_path = f"exports/{agent_name}"
    global PROJECT_ROOT, SNAPSHOT_DIR

    if not PROJECT_ROOT:
        PROJECT_ROOT = _find_project_root()
    if not SNAPSHOT_DIR and PROJECT_ROOT:
        SNAPSHOT_DIR = os.path.join(
            os.path.expanduser("~"),
            ".hive",
            "snapshots",
            os.path.basename(PROJECT_ROOT),
        )

    try:
        agent_dir, package_name, display_ref = _resolve_agent_package_target(agent_name)
    except ValueError as e:
        return {
            "valid": False,
            "agent_name": agent_name,
            "steps": {"target_resolution": {"passed": False, "error": str(e)}},
            "summary": "FAIL: 1 of 1 steps failed (target_resolution)",
        }

    steps: dict[str, dict] = {}

    # Set up env for subprocess calls

@@ -1454,8 +1895,11 @@ def validate_agent_package(agent_name: str) -> str:
    core_path = os.path.join(PROJECT_ROOT, "core")
    exports_path = os.path.join(PROJECT_ROOT, "exports")
    fw_agents_path = os.path.join(PROJECT_ROOT, "core", "framework", "agents")
    package_parent = str(agent_dir.parent)
    pythonpath = env.get("PYTHONPATH", "")
    path_parts = [core_path, exports_path, fw_agents_path, PROJECT_ROOT]
    if package_parent not in path_parts:
        path_parts.insert(1, package_parent)
    if pythonpath:
        path_parts.append(pythonpath)
    env["PYTHONPATH"] = os.pathsep.join(path_parts)

@@ -1464,22 +1908,22 @@ def validate_agent_package(agent_name: str) -> str:
    try:
        _contract_script = textwrap.dedent("""\
            import importlib, json
            mod = importlib.import_module('{agent_name}')
            mod = importlib.import_module('{package_name}')
            missing = [a for a in ('goal', 'nodes', 'edges') if getattr(mod, a, None) is None]
            if missing:
                print(json.dumps({{
                    'valid': False,
                    'error': (
                        "Module '{agent_name}' is missing module-level attributes: "
                        "Module '{package_name}' is missing module-level attributes: "
                        + ", ".join(missing) + ". "
                        "Fix: in {agent_name}/__init__.py, add "
                        "Fix: in {package_name}/__init__.py, add "
                        "'from .agent import " + ", ".join(missing) + "' "
                        "so that 'import {agent_name}' exposes them at package level."
                        "so that 'import {package_name}' exposes them at package level."
                    )
                }}))
            else:
                print(json.dumps({{'valid': True}}))
        """).format(agent_name=agent_name)
        """).format(package_name=package_name)
        proc = subprocess.run(
            ["uv", "run", "python", "-c", _contract_script],
            capture_output=True,

@@ -1499,8 +1943,8 @@ def validate_agent_package(agent_name: str) -> str:
            steps["module_contract"] = {
                "passed": False,
                "error": (
                    f"Failed to import '{agent_name}': {proc.stderr.strip()[:1000]}. "
                    f"Fix: ensure {agent_name}/__init__.py exists and can be imported "
                    f"Failed to import '{package_name}': {proc.stderr.strip()[:1000]}. "
                    f"Fix: ensure {package_name}/__init__.py exists and can be imported "
                    f"without errors (check syntax, missing dependencies, relative imports)."
                ),
            }

@@ -1515,7 +1959,7 @@ def validate_agent_package(agent_name: str) -> str:
                "run",
                "python",
                "-c",
                f"from {agent_name} import default_agent; print(default_agent.validate())",
                f"from {package_name} import default_agent; print(default_agent.validate())",
            ],
            capture_output=True,
            text=True,

@@ -1562,7 +2006,7 @@ def validate_agent_package(agent_name: str) -> str:
            )
            print(json.dumps({{'valid': len(errors) == 0, 'errors': errors}}))
        """)
        check_script = _check_template.format(agent_name=agent_name)
        check_script = _check_template.format(agent_name=package_name)
        proc = subprocess.run(
            ["uv", "run", "python", "-c", check_script],
            capture_output=True,

@@ -1603,7 +2047,7 @@ def validate_agent_package(agent_name: str) -> str:
                "python",
                "-c",
                f"from framework.runner.runner import AgentRunner; "
                f'r = AgentRunner.load("exports/{agent_name}", '
                f"r = AgentRunner.load({str(display_ref)!r}, "
                f"skip_credential_validation=True); "
                f'print("AgentRunner.load (graph-only): OK")',
            ],

@@ -1624,9 +2068,42 @@ def validate_agent_package(agent_name: str) -> str:
    except Exception as e:
        steps["graph_validation"] = {"passed": False, "error": str(e)}

    # Step B2: Behavior validation — reject placeholder prompts and empty work nodes
    try:
        import importlib

        if package_parent not in sys.path:
            sys.path.insert(0, package_parent)

        stale = [
            name
            for name in sys.modules
            if name == package_name or name.startswith(f"{package_name}.")
        ]
        for name in stale:
            del sys.modules[name]

        agent_mod = importlib.import_module(package_name)
        behavior_errors = _behavior_validation_errors(agent_mod)
        behavior_blockers, behavior_warnings = _classify_behavior_validation_errors(behavior_errors)
        steps["behavior_validation"] = {
            "passed": len(behavior_blockers) == 0,
            "output": (
                "No placeholder prompts or empty autonomous nodes detected"
                if not behavior_errors
                else "; ".join(behavior_blockers or behavior_warnings)
            ),
        }
        if behavior_blockers:
            steps["behavior_validation"]["errors"] = behavior_blockers
        if behavior_warnings:
            steps["behavior_validation"]["warnings"] = behavior_warnings
    except Exception as e:
        steps["behavior_validation"] = {"passed": False, "error": str(e)}

    # Step C: Tool validation (direct call)
    try:
        tool_result = _validate_agent_tools_impl(agent_path)
        tool_result = _validate_agent_tools_impl(str(agent_dir))
        if "error" in tool_result:
            steps["tool_validation"] = {"passed": False, "error": tool_result["error"]}
        else:

@@ -1641,17 +2118,33 @@ def validate_agent_package(agent_name: str) -> str:

    # Step D: Tests (direct call)
    try:
        test_result = _run_agent_tests_impl(agent_name)
        if "error" in test_result:
            steps["tests"] = {"passed": False, "error": test_result["error"]}
        test_result = _run_agent_tests_impl(str(agent_dir))
        if test_result.get("skipped_all"):
            steps["tests"] = {
                "passed": True,
                "skipped": True,
                "summary": test_result.get("summary", "No tests directory found; skipped"),
            }
        elif "error" in test_result:
            steps["tests"] = {
                "passed": False,
                "warning": test_result["error"],
                "warnings": [test_result["error"]],
            }
        else:
            all_passed = test_result.get("failed", 0) == 0 and test_result.get("errors", 0) == 0
            steps["tests"] = {
                "passed": all_passed,
                "summary": test_result.get("summary", "unknown"),
            }
            if not all_passed and test_result.get("failures"):
                steps["tests"]["failures"] = test_result["failures"]
            if not all_passed:
                warning_summary = (
                    f"Test suite not fully passing: {test_result.get('summary', 'unknown')}"
                )
                steps["tests"]["warning"] = warning_summary
                steps["tests"]["warnings"] = [warning_summary]
                if test_result.get("failures"):
                    steps["tests"]["failures"] = test_result["failures"]
    except Exception as e:
        steps["tests"] = {"passed": False, "error": str(e)}

@@ -1665,13 +2158,20 @@ def validate_agent_package(agent_name: str) -> str:
    else:
        summary = f"FAIL: {len(failed_steps)} of {total} steps failed ({', '.join(failed_steps)})"

    return {
        "valid": valid,
        "agent_name": package_name,
        "agent_path": str(agent_dir),
        "steps": steps,
        "summary": summary,
    }


@mcp.tool()
def validate_agent_package(agent_name: str) -> str:
    """Run structural validation checks on a built agent package in one call."""
    return json.dumps(
        {
            "valid": valid,
            "agent_name": agent_name,
            "steps": steps,
            "summary": summary,
        },
        _validate_agent_package_impl(agent_name),
        indent=2,
        default=str,
    )
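In-process callers can now consume the dict directly, while MCP clients keep getting JSON text. A usage sketch; the agent path and the example summary string are illustrative:

result = _validate_agent_package_impl("examples/templates/my_agent")
print(result["summary"])  # e.g. "FAIL: 1 of 6 steps failed (tests)"
for step, detail in result["steps"].items():
    print(step, "OK" if detail.get("passed") else detail.get("error", detail.get("warning")))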

@@ -1804,10 +2304,10 @@ default_config = RuntimeConfig()

@dataclass
class AgentMetadata:
    name: str = "{human_name}"
    name: str = {human_name!r}
    version: str = "1.0.0"
    description: str = "{_draft_desc or "TODO: Add agent description."}"
    intro_message: str = "TODO: Add intro message."
    description: str = {(_draft_desc or "TODO: Add agent description.")!r}
    intro_message: str = {_default_intro_message(human_name, _draft_desc)!r}


metadata = AgentMetadata()
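Why the template switched from quoted interpolation to !r: repr() always emits a valid Python string literal, even when the drafted value itself contains quotes. A quick illustration with an invented name:

human_name = 'Review "Deep-Dive" Agent'  # illustrative value with quotes inside
print(f"    name: str = {human_name!r}")
# Prints:     name: str = 'Review "Deep-Dive" Agent'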

@@ -1913,8 +2413,8 @@ __all__ = {node_var_names!r}
        SuccessCriterion(
            id="sc-{i + 1}",
            description="{sc}",
            metric="TODO",
            target="TODO",
            metric="{_default_success_metric(i + 1)}",
            target="{_default_success_target()}",
            weight=1.0,
        ),"""
        for i, sc in enumerate(_draft_sc)

@@ -1924,8 +2424,8 @@ __all__ = {node_var_names!r}
        SuccessCriterion(
            id="sc-1",
            description="TODO: Define success criterion.",
            metric="TODO",
            target="TODO",
            metric="criterion_1_satisfied",
            target="true",
            weight=1.0,
        ),"""

@@ -1997,7 +2497,10 @@ pause_nodes = []
terminal_nodes = []

conversation_mode = "continuous"
identity_prompt = "TODO: Add identity prompt."
identity_prompt = (
    "You are {human_name}, a focused Hive worker that follows the goal, "
    "constraints, and node instructions precisely."
)
loop_config = {{
    "max_iterations": 100,
    "max_tool_calls_per_turn": 30,

@@ -2063,7 +2566,7 @@ class {class_name}:
                name="Default",
                entry_node=self.entry_node,
                trigger_type="manual",
                isolation_level="shared",
                isolation_level="isolated",
            ),
        ],
        llm=llm,

@@ -2347,11 +2850,13 @@ def runner_loaded():
        "files": all_file_paths,
        "next_steps": [
            (
                "IMPORTANT: All generated files are structurally complete "
                "with correct imports, class definition, validate() method, "
                "and __init__.py exports. Use edit_file to customize TODO "
                "placeholders — do NOT use write_file to rewrite entire files, "
                "as this will break imports and structure."
                "IMPORTANT: The generated scaffold has correct imports, class "
                "definition, validate() method, and __init__.py exports, but "
                "it is NOT ready to load or run yet. Replace every TODO / "
                "placeholder prompt and make validation pass before staging. "
                "Use edit_file to customize placeholders — do NOT use "
                "write_file to rewrite entire files, as this will break "
                "imports and structure."
            ),
            (
                f"Use edit_file to customize system prompts, tools, "
File diff suppressed because it is too large