Compare commits

...

16 Commits

Author SHA1 Message Date
Richard Tang 9e252fc33e fix: missing MCPRegistry import 2026-03-31 18:50:08 -07:00
RichardTang-Aden bf57220c8f Merge branch 'main' into codex/pr3-codex-workspace-parity 2026-03-31 18:45:46 -07:00
Richard Tang cb1c07e60c fix: codex api wiring mode 2026-03-31 18:38:13 -07:00
Richard Tang 2aa38ad9bb feat: models and parameters 2026-03-31 18:17:43 -07:00
Richard Tang c0b8980447 fix: broken ci tests 2026-03-31 17:12:15 -07:00
Richard Tang 601a5d87e9 fix: orchestrator fallback 2026-03-31 15:49:08 -07:00
Richard Tang 29e85a13c7 Merge remote-tracking branch 'origin/main' into codex/pr3-codex-workspace-parity 2026-03-31 15:43:28 -07:00
Richard Tang 220beb5c64 revert: revert orchestrator primary-result heuristics per review 2026-03-31 15:32:09 -07:00
Vasu Bansal f56600f7af fix: tighten validation and rerun default guards 2026-03-28 12:14:01 +05:30
Vasu Bansal 148f61ac3e fix: address codex parity review findings 2026-03-28 09:14:07 +05:30
Vasu Bansal e28d989c92 fix: restore image handoff support in codex parity pr 2026-03-28 07:39:23 +05:30
Vasu Bansal 4681e52f86 fix: align codex parity PR with CI formatting and windows tests 2026-03-27 22:51:27 +05:30
Vasu Bansal 8351c808dc fix: resolve codex parity CI regressions 2026-03-27 22:43:00 +05:30
Vasu Bansal c959cab9c2 fix: improve codex workspace run and rerun parity 2026-03-27 21:02:11 +05:30
Vasu Bansal 9526570e0a fix: harden codex backend parity across runtime and validation 2026-03-27 20:51:58 +05:30
Vasu Bansal e87a40a7c3 fix: add adapter-first codex provider integration 2026-03-27 20:45:47 +05:30
43 changed files with 10489 additions and 1202 deletions
+7 -31
View File
@@ -13,6 +13,7 @@ from pathlib import Path
from typing import Any
from framework.graph.edge import DEFAULT_MAX_TOKENS
from framework.llm.codex_backend import CODEX_API_BASE, build_codex_litellm_kwargs
# ---------------------------------------------------------------------------
# Low-level config file access
@@ -125,7 +126,6 @@ def get_worker_api_key() -> str | None:
return token
except ImportError:
pass
api_key_env_var = worker_llm.get("api_key_env_var")
if api_key_env_var:
return os.environ.get(api_key_env_var)
@@ -141,7 +141,7 @@ def get_worker_api_base() -> str | None:
return get_api_base()
if worker_llm.get("use_codex_subscription"):
return "https://chatgpt.com/backend-api/codex"
return CODEX_API_BASE
if worker_llm.get("use_kimi_code_subscription"):
return "https://api.kimi.com/coding"
if worker_llm.get("use_antigravity_subscription"):
@@ -169,25 +169,14 @@ def get_worker_llm_extra_kwargs() -> dict[str, Any]:
if worker_llm.get("use_codex_subscription"):
api_key = get_worker_api_key()
if api_key:
headers: dict[str, str] = {
"Authorization": f"Bearer {api_key}",
"User-Agent": "CodexBar",
}
account_id = None
try:
from framework.runner.runner import get_codex_account_id
account_id = get_codex_account_id()
if account_id:
headers["ChatGPT-Account-Id"] = account_id
except ImportError:
pass
return {
"extra_headers": headers,
"store": False,
"allowed_openai_params": ["store"],
}
if worker_llm.get("provider") == "ollama":
return {"num_ctx": worker_llm.get("num_ctx", 16384)}
return build_codex_litellm_kwargs(api_key, account_id=account_id)
return {}
@@ -276,7 +265,6 @@ def get_api_key() -> str | None:
return token
except ImportError:
pass
# Standard env-var path (covers ZAI Code and all API-key providers)
api_key_env_var = llm.get("api_key_env_var")
if api_key_env_var:
@@ -382,7 +370,7 @@ def get_api_base() -> str | None:
llm = get_hive_config().get("llm", {})
if llm.get("use_codex_subscription"):
# Codex subscription routes through the ChatGPT backend, not api.openai.com.
return "https://chatgpt.com/backend-api/codex"
return CODEX_API_BASE
if llm.get("use_kimi_code_subscription"):
# Kimi Code uses an Anthropic-compatible endpoint (no /v1 suffix).
return "https://api.kimi.com/coding"
@@ -417,27 +405,15 @@ def get_llm_extra_kwargs() -> dict[str, Any]:
if llm.get("use_codex_subscription"):
api_key = get_api_key()
if api_key:
headers: dict[str, str] = {
"Authorization": f"Bearer {api_key}",
"User-Agent": "CodexBar",
}
account_id = None
try:
from framework.runner.runner import get_codex_account_id
account_id = get_codex_account_id()
if account_id:
headers["ChatGPT-Account-Id"] = account_id
except ImportError:
pass
return {
"extra_headers": headers,
"store": False,
"allowed_openai_params": ["store"],
}
return build_codex_litellm_kwargs(api_key, account_id=account_id)
if llm.get("provider") == "ollama":
# Pass num_ctx to Ollama so it doesn't silently truncate the ~9.5k Queen prompt.
# Ollama's default num_ctx is only 2048. We set it to 16384 here so LiteLLM
# passes it through as a provider-specific option.
return {"num_ctx": llm.get("num_ctx", 16384)}
return {}
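
For reviewers: both `get_worker_llm_extra_kwargs` and `get_llm_extra_kwargs` above now delegate to `build_codex_litellm_kwargs`, which is added later in this PR. A minimal sketch of the dict it produces, with invented token values, mirroring the literals in the new codex_backend module:

# Sketch: the kwargs shape returned for a Codex subscription.
def sketch_codex_kwargs(api_key: str, account_id: str | None = None) -> dict:
    headers = {"Authorization": f"Bearer {api_key}", "User-Agent": "CodexBar"}
    if account_id:
        headers["ChatGPT-Account-Id"] = account_id
    return {
        "extra_headers": headers,
        "store": False,                      # Codex backend must not store responses
        "allowed_openai_params": ["store"],  # let LiteLLM pass `store` through
    }

assert sketch_codex_kwargs("tok", "acct")["extra_headers"]["ChatGPT-Account-Id"] == "acct"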
+15 -5
View File
@@ -351,13 +351,15 @@ class NodeConversation:
def system_prompt(self) -> str:
return self._system_prompt
def update_system_prompt(self, new_prompt: str) -> None:
def update_system_prompt(self, new_prompt: str, output_keys: list[str] | None = None) -> None:
"""Update the system prompt.
Used in continuous conversation mode at phase transitions to swap
Layer 3 (focus) while preserving the conversation history.
"""
self._system_prompt = new_prompt
if output_keys is not None:
self._output_keys = output_keys
self._meta_persisted = False # re-persist with new prompt
def set_current_phase(self, phase_id: str) -> None:
@@ -771,7 +773,7 @@ class NodeConversation:
delete_before = recent_messages[0].seq if recent_messages else self._next_seq
await self._store.delete_parts_before(delete_before)
await self._store.write_part(summary_msg.seq, summary_msg.to_storage_dict())
await self._store.write_cursor({"next_seq": self._next_seq})
await self._write_cursor_update({"next_seq": self._next_seq})
self._messages = [summary_msg] + recent_messages
self._last_api_input_tokens = None # reset; next LLM call will recalibrate
@@ -975,7 +977,7 @@ class NodeConversation:
# Write kept structural messages (they may have been modified)
for msg in kept_structural:
await self._store.write_part(msg.seq, msg.to_storage_dict())
await self._store.write_cursor({"next_seq": self._next_seq})
await self._write_cursor_update({"next_seq": self._next_seq})
# Reassemble: reference + kept structural (in original order) + recent
self._messages = [ref_msg] + kept_structural + recent_messages
@@ -1012,7 +1014,7 @@ class NodeConversation:
"""Remove all messages, keep system prompt, preserve ``_next_seq``."""
if self._store:
await self._store.delete_parts_before(self._next_seq)
await self._store.write_cursor({"next_seq": self._next_seq})
await self._write_cursor_update({"next_seq": self._next_seq})
self._messages.clear()
self._last_api_input_tokens = None
@@ -1047,6 +1049,14 @@ class NodeConversation:
# --- Persistence internals ---------------------------------------------
async def _write_cursor_update(self, data: dict[str, Any]) -> None:
"""Merge cursor updates instead of clobbering existing crash-recovery state."""
if self._store is None:
return
cursor = await self._store.read_cursor() or {}
cursor.update(data)
await self._store.write_cursor(cursor)
async def _persist(self, message: Message) -> None:
"""Write-through a single message. No-op when store is None."""
if self._store is None:
@@ -1054,7 +1064,7 @@ class NodeConversation:
if not self._meta_persisted:
await self._persist_meta()
await self._store.write_part(message.seq, message.to_storage_dict())
await self._store.write_cursor({"next_seq": self._next_seq})
await self._write_cursor_update({"next_seq": self._next_seq})
async def _persist_meta(self) -> None:
"""Lazily write conversation metadata to the store (called once)."""
File diff suppressed because it is too large
+17 -5
View File
@@ -1480,7 +1480,22 @@ class GraphExecutor:
narrative=narrative,
accounts_prompt=_node_accounts,
)
continuous_conversation.update_system_prompt(new_system)
continuous_conversation.update_system_prompt(
new_system,
output_keys=list(next_spec.output_keys or []),
)
# Stamp the next phase before inserting the transition
# marker so the marker itself is preserved with the
# phase it introduces during compaction/restore.
continuous_conversation.set_current_phase(next_spec.id)
transition_tool_names = set(cumulative_tool_names)
transition_tool_names.update(next_spec.tools or [])
if next_spec.output_keys:
transition_tool_names.add("set_output")
if next_spec.client_facing:
transition_tool_names.update({"ask_user", "ask_user_multiple"})
# Insert transition marker into conversation
data_dir = str(self._storage_path / "data") if self._storage_path else None
@@ -1488,7 +1503,7 @@ class GraphExecutor:
previous_node=node_spec,
next_node=next_spec,
memory=memory,
cumulative_tool_names=sorted(cumulative_tool_names),
cumulative_tool_names=sorted(transition_tool_names),
data_dir=data_dir,
adapt_content=_adapt_text,
)
@@ -1497,9 +1512,6 @@ class GraphExecutor:
is_transition_marker=True,
)
# Set current phase for phase-aware compaction
continuous_conversation.set_current_phase(next_spec.id)
# Phase-boundary compaction (same flow as EventLoopNode._compact)
if continuous_conversation.usage_ratio() > 0.5:
await continuous_conversation.prune_old_tool_results(
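
The transition-marker hunk above widens the advertised tool set from the cumulative names alone to include the next phase's own tools plus the framework tools implied by its spec. A standalone recomputation with hypothetical spec values:

# Hypothetical current and next-phase values, for illustration only.
cumulative_tool_names = {"search", "load_data"}
next_spec_tools = ["write_report"]
next_spec_output_keys = ["report"]
next_spec_client_facing = True

transition_tool_names = set(cumulative_tool_names)
transition_tool_names.update(next_spec_tools)
if next_spec_output_keys:
    transition_tool_names.add("set_output")           # output keys imply set_output
if next_spec_client_facing:
    transition_tool_names.update({"ask_user", "ask_user_multiple"})

assert sorted(transition_tool_names) == [
    "ask_user", "ask_user_multiple", "load_data", "search", "set_output", "write_report",
]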
+2 -16
View File
@@ -152,8 +152,6 @@ def compose_system_prompt(
accounts_prompt: str | None = None,
skills_catalog_prompt: str | None = None,
protocols_prompt: str | None = None,
execution_preamble: str | None = None,
node_type_preamble: str | None = None,
) -> str:
"""Compose the multi-layer system prompt.
@@ -164,10 +162,6 @@ def compose_system_prompt(
accounts_prompt: Connected accounts block (sits between identity and narrative).
skills_catalog_prompt: Available skills catalog XML (Agent Skills standard).
protocols_prompt: Default skill operational protocols section.
execution_preamble: EXECUTION_SCOPE_PREAMBLE for worker nodes
(prepended before focus so the LLM knows its pipeline scope).
node_type_preamble: Node-type-specific preamble, e.g. GCU browser
best-practices prompt (prepended before focus).
Returns:
Composed system prompt with all layers present, plus current datetime.
@@ -194,15 +188,6 @@ def compose_system_prompt(
if narrative:
parts.append(f"\n--- Context (what has happened so far) ---\n{narrative}")
# Execution scope preamble (worker nodes — tells the LLM it is one
# step in a multi-node pipeline and should not overreach)
if execution_preamble:
parts.append(f"\n{execution_preamble}")
# Node-type preamble (e.g. GCU browser best-practices)
if node_type_preamble:
parts.append(f"\n{node_type_preamble}")
# Layer 3: Focus (current phase directive)
if focus_prompt:
parts.append(f"\n--- Current Focus ---\n{focus_prompt}")
@@ -320,7 +305,8 @@ def build_transition_marker(
file_size = (data_path / filename).stat().st_size
val_str = (
f"[Saved to '{filename}' ({file_size:,} bytes). "
f"Use load_data(filename='{filename}') to access.]"
f"Use load_data(filename='{filename}') to access from session data. "
"Do NOT open it as a workspace file or expect it in the current directory.]"
)
except Exception:
val_str = val_str[:300] + "..."
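
A quick render of the reworded saved-data note, with an invented filename and size:

filename, file_size = "results.json", 10489  # fabricated example values
val_str = (
    f"[Saved to '{filename}' ({file_size:,} bytes). "
    f"Use load_data(filename='{filename}') to access from session data. "
    "Do NOT open it as a workspace file or expect it in the current directory.]"
)
assert "Use load_data(filename='results.json')" in val_str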
+255
View File
@@ -0,0 +1,255 @@
"""Codex adapter for Hive's LiteLLM provider.
Codex CLI is tool-first and event-structured: tool invocations and tool results
are emitted as explicit response items, not as plain-text workflow narration.
This adapter keeps the ChatGPT Codex backend aligned with Hive's normal
provider contract by normalizing Codex request shaping and response recovery at
the provider boundary.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from framework.llm.codex_backend import (
build_codex_extra_headers,
is_codex_api_base,
merge_codex_allowed_openai_params,
normalize_codex_api_base,
)
from framework.llm.provider import Tool
if TYPE_CHECKING:
from collections.abc import Callable
from framework.llm.litellm import LiteLLMProvider
from framework.llm.stream_events import StreamEvent
logger = logging.getLogger(__name__)
_CODEX_CRITICAL_TOOL_NAMES = frozenset(
{
"ask_user",
"ask_user_multiple",
"set_output",
"escalate",
"save_agent_draft",
"confirm_and_build",
"initialize_and_build_agent",
}
)
_CODEX_SYSTEM_CHUNK_CHARS = 3500
_CODEX_SYSTEM_PREAMBLE = """# Codex Execution Contract
Follow the system sections below in order.
- Obey every CRITICAL, MUST, NEVER, and ONLY instruction exactly.
- When tools are available, emit structured tool calls instead of replying with plain-text promises.
- Do not skip required workflow boundaries or approval gates.
"""
class CodexResponsesAdapter:
"""Normalize the ChatGPT Codex backend to Hive's standard provider semantics."""
def __init__(self, provider: LiteLLMProvider):
self._provider = provider
@property
def enabled(self) -> bool:
"""Return True when the provider targets the ChatGPT Codex backend."""
return is_codex_api_base(self._provider.api_base)
def chunk_system_prompt(self, system: str) -> list[str]:
"""Break large system prompts into smaller Codex-friendly chunks."""
normalized = system.replace("\r\n", "\n").strip()
if not normalized:
return []
sections: list[str] = []
current: list[str] = []
for line in normalized.splitlines():
if line.startswith("#") and current:
sections.append("\n".join(current).strip())
current = [line]
else:
current.append(line)
if current:
sections.append("\n".join(current).strip())
chunks: list[str] = []
for section in sections:
if len(section) <= _CODEX_SYSTEM_CHUNK_CHARS:
chunks.append(section)
continue
paragraphs = [
paragraph.strip() for paragraph in section.split("\n\n") if paragraph.strip()
]
current_chunk = ""
for paragraph in paragraphs:
candidate = paragraph if not current_chunk else f"{current_chunk}\n\n{paragraph}"
if current_chunk and len(candidate) > _CODEX_SYSTEM_CHUNK_CHARS:
chunks.append(current_chunk)
current_chunk = paragraph
else:
current_chunk = candidate
if current_chunk:
chunks.append(current_chunk)
return chunks or [normalized]
def build_system_messages(
self,
system: str,
*,
json_mode: bool,
) -> list[dict[str, Any]]:
"""Build Codex system messages in the tool-first format Codex CLI expects."""
system_messages: list[dict[str, Any]] = []
if system:
chunks = self.chunk_system_prompt(system)
if len(chunks) > 1 or len(chunks[0]) > _CODEX_SYSTEM_CHUNK_CHARS:
system_messages.append({"role": "system", "content": _CODEX_SYSTEM_PREAMBLE})
for chunk in chunks:
system_messages.append({"role": "system", "content": chunk})
else:
system_messages.append({"role": "system", "content": "You are a helpful assistant."})
if json_mode:
system_messages.append(
{"role": "system", "content": "Please respond with a valid JSON object."}
)
return system_messages
def derive_tool_choice(
self,
messages: list[dict[str, Any]],
tools: list[Tool] | None,
) -> str | dict[str, Any] | None:
"""Force structured tool use when Codex sees critical framework tools."""
if not tools:
return None
tool_names = {tool.name for tool in tools}
if not (tool_names & _CODEX_CRITICAL_TOOL_NAMES):
return None
last_role = next(
(m.get("role") for m in reversed(messages) if m.get("role") != "system"),
None,
)
if last_role == "assistant":
return None
return "required"
def harden_request_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
"""Strip unsupported params and inject the Codex backend headers."""
cleaned = dict(kwargs)
cleaned["api_base"] = normalize_codex_api_base(
cleaned.get("api_base") or self._provider.api_base
)
cleaned["store"] = False
cleaned["allowed_openai_params"] = merge_codex_allowed_openai_params(
cleaned.get("allowed_openai_params")
)
cleaned.pop("max_tokens", None)
cleaned.pop("stream_options", None)
extra_headers = dict(cleaned.get("extra_headers") or {})
if "ChatGPT-Account-Id" not in extra_headers:
try:
from framework.runner.runner import get_codex_account_id
account_id = get_codex_account_id()
if account_id:
extra_headers["ChatGPT-Account-Id"] = account_id
except Exception:
logger.debug("Could not populate ChatGPT-Account-Id", exc_info=True)
cleaned["extra_headers"] = build_codex_extra_headers(
self._provider.api_key,
account_id=extra_headers.get("ChatGPT-Account-Id"),
extra_headers=extra_headers,
)
return cleaned
async def recover_empty_stream(
self,
kwargs: dict[str, Any],
*,
last_role: str | None,
acompletion: Callable[..., Any],
) -> list[StreamEvent] | None:
"""Try a non-stream completion when Codex returns an empty stream."""
fallback_kwargs = dict(kwargs)
fallback_kwargs.pop("stream", None)
fallback_kwargs.pop("stream_options", None)
fallback_kwargs = self._provider._sanitize_request_kwargs(fallback_kwargs, stream=False)
try:
response = await acompletion(**fallback_kwargs)
except Exception as exc:
logger.debug(
"[stream-recover] %s non-stream fallback after empty %s stream failed: %s",
self._provider.model,
last_role,
exc,
)
return None
events = self._provider._build_stream_events_from_nonstream_response(response)
if events:
logger.info(
"[stream-recover] %s recovered empty %s stream via non-stream completion",
self._provider.model,
last_role,
)
return events
return None
def merge_tool_call_chunk(
self,
tool_calls_acc: dict[int, dict[str, str]],
tc: Any,
last_tool_idx: int,
) -> int:
"""Merge a streamed tool-call chunk, compensating for broken bridge indexes."""
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
tc_id = getattr(tc, "id", None) or ""
func = getattr(tc, "function", None)
func_name = getattr(func, "name", "") if func is not None else ""
func_args = getattr(func, "arguments", "") if func is not None else ""
if tc_id:
existing_idx = next(
(key for key, value in tool_calls_acc.items() if value["id"] == tc_id),
None,
)
if existing_idx is not None:
idx = existing_idx
elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in ("", tc_id):
idx = max(tool_calls_acc.keys(), default=-1) + 1
last_tool_idx = idx
elif func_name:
if (
last_tool_idx in tool_calls_acc
and tool_calls_acc[last_tool_idx]["name"]
and tool_calls_acc[last_tool_idx]["name"] != func_name
and tool_calls_acc[last_tool_idx]["arguments"]
):
idx = max(tool_calls_acc.keys(), default=-1) + 1
last_tool_idx = idx
else:
idx = last_tool_idx if tool_calls_acc else idx
else:
idx = last_tool_idx if tool_calls_acc else idx
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
if tc_id:
tool_calls_acc[idx]["id"] = tc_id
if func_name:
tool_calls_acc[idx]["name"] = func_name
if func_args:
tool_calls_acc[idx]["arguments"] += func_args
return idx
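
The chunker above splits on markdown headings before packing paragraphs under the 3500-char limit. A standalone sketch of just the heading split, reimplemented inline so it runs without the framework package:

def split_on_headings(system: str) -> list[str]:
    # Start a new section at each "#" heading, keeping lines in order.
    sections, current = [], []
    for line in system.replace("\r\n", "\n").strip().splitlines():
        if line.startswith("#") and current:
            sections.append("\n".join(current).strip())
            current = [line]
        else:
            current.append(line)
    if current:
        sections.append("\n".join(current).strip())
    return sections

prompt = "# Identity\nYou are a worker.\n# Focus\nDo the task."
assert split_on_headings(prompt) == ["# Identity\nYou are a worker.", "# Focus\nDo the task."]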
+85
View File
@@ -0,0 +1,85 @@
"""Shared helpers for Codex's ChatGPT-backed transport.
Codex CLI talks to the ChatGPT Codex backend, which is not the standard
platform OpenAI API. Hive keeps its normal provider contract by centralizing
the transport-specific headers and request kwargs here.
"""
from __future__ import annotations
from typing import Any
from urllib.parse import urlparse, urlunparse
CODEX_API_BASE = "https://chatgpt.com/backend-api/codex"
CODEX_USER_AGENT = "CodexBar"
CODEX_ALLOWED_OPENAI_PARAMS = ("store",)
_CODEX_HOST = "chatgpt.com"
_CODEX_PATH = "/backend-api/codex"
def is_codex_api_base(api_base: str | None) -> bool:
"""Return True when *api_base* targets the ChatGPT Codex backend."""
if not api_base:
return False
parsed = urlparse(api_base)
path = parsed.path.rstrip("/")
return (
parsed.scheme in {"http", "https"}
and parsed.hostname == _CODEX_HOST
and (path == _CODEX_PATH or path == f"{_CODEX_PATH}/responses")
)
def normalize_codex_api_base(api_base: str | None) -> str | None:
"""Normalize ChatGPT Codex backend URLs to the stable base endpoint."""
if not api_base:
return api_base
parsed = urlparse(api_base)
path = parsed.path.rstrip("/")
if not is_codex_api_base(api_base):
return api_base.rstrip("/")
if path.endswith("/responses"):
path = path[: -len("/responses")]
normalized = parsed._replace(path=path, params="", query="", fragment="")
return urlunparse(normalized).rstrip("/")
def merge_codex_allowed_openai_params(params: list[str] | tuple[str, ...] | None) -> list[str]:
"""Ensure Codex-required pass-through params are always present."""
allowed = set(params or [])
allowed.update(CODEX_ALLOWED_OPENAI_PARAMS)
return sorted(allowed)
def build_codex_extra_headers(
api_key: str | None,
*,
account_id: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> dict[str, str]:
"""Build headers for the ChatGPT Codex backend."""
headers = dict(extra_headers or {})
if api_key:
headers.setdefault("Authorization", f"Bearer {api_key}")
headers.setdefault("User-Agent", CODEX_USER_AGENT)
if account_id:
headers.setdefault("ChatGPT-Account-Id", account_id)
return headers
def build_codex_litellm_kwargs(
api_key: str | None,
*,
account_id: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Return the LiteLLM kwargs required by the ChatGPT Codex backend."""
return {
"extra_headers": build_codex_extra_headers(
api_key,
account_id=account_id,
extra_headers=extra_headers,
),
"store": False,
"allowed_openai_params": list(CODEX_ALLOWED_OPENAI_PARAMS),
}
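
A doctest-style self-check of what `is_codex_api_base` accepts, reimplemented inline over the same urlparse logic so reviewers can run it standalone:

from urllib.parse import urlparse

def demo_is_codex(api_base: str) -> bool:
    # Accept the stable base endpoint and its /responses variant only.
    p = urlparse(api_base)
    path = p.path.rstrip("/")
    return (
        p.scheme in {"http", "https"}
        and p.hostname == "chatgpt.com"
        and path in ("/backend-api/codex", "/backend-api/codex/responses")
    )

assert demo_is_codex("https://chatgpt.com/backend-api/codex/responses")
assert demo_is_codex("https://chatgpt.com/backend-api/codex/")
assert not demo_is_codex("https://api.openai.com/v1")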
+404 -185
View File
@@ -28,6 +28,8 @@ except ImportError:
RateLimitError = Exception # type: ignore[assignment, misc]
from framework.config import HIVE_LLM_ENDPOINT as HIVE_API_BASE
from framework.llm.codex_adapter import CodexResponsesAdapter
from framework.llm.codex_backend import normalize_codex_api_base
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import StreamEvent
@@ -534,7 +536,9 @@ class LiteLLMProvider(LLMProvider):
api_base = api_base.rstrip("/")[:-3]
self.model = model
self.api_key = api_key
self.api_base = api_base or self._default_api_base_for_model(_original_model)
self.api_base = normalize_codex_api_base(
api_base or self._default_api_base_for_model(_original_model)
)
self.extra_kwargs = kwargs
# Detect Claude Code OAuth subscription by checking the api_key prefix.
self._claude_code_oauth = bool(api_key and api_key.startswith("sk-ant-oat"))
@@ -542,24 +546,28 @@ class LiteLLMProvider(LLMProvider):
# Anthropic requires a specific User-Agent for OAuth requests.
eh = self.extra_kwargs.setdefault("extra_headers", {})
eh.setdefault("user-agent", CLAUDE_CODE_USER_AGENT)
# The Codex ChatGPT backend (chatgpt.com/backend-api/codex) rejects
# several standard OpenAI params: max_output_tokens, stream_options.
self._codex_backend = bool(
self.api_base and "chatgpt.com/backend-api/codex" in self.api_base
)
# Antigravity routes through a local OpenAI-compatible proxy — no patches needed.
self._antigravity = bool(self.api_base and "localhost:8069" in self.api_base)
self._codex_adapter = CodexResponsesAdapter(self)
# Backward-compatible alias for existing tests/callers.
self._codex_backend = self._codex_adapter.enabled
if litellm is None:
raise ImportError(
"LiteLLM is not installed. Please install it with: uv pip install litellm"
)
# Note: The Codex ChatGPT backend is a Responses API endpoint at
# The Codex ChatGPT backend is a Responses API endpoint at
# chatgpt.com/backend-api/codex/responses. LiteLLM's model registry
# correctly marks codex models with mode="responses", so we do NOT
# override the mode. The responses_api_bridge in litellm handles
# converting Chat Completions requests to Responses API format.
# marks legacy codex models (gpt-5.3-codex) with mode="responses",
# but newer models like gpt-5.4 default to mode="chat". Force
# mode="responses" so litellm routes through the responses_api_bridge.
if self._codex_backend and litellm is not None:
_strip = self.model.removeprefix("openai/")
_entry = litellm.model_cost.get(_strip, {})
if _entry.get("mode") != "responses":
litellm.model_cost.setdefault(_strip, {})
litellm.model_cost[_strip]["mode"] = "responses"
@staticmethod
def _default_api_base_for_model(model: str) -> str | None:
@@ -575,6 +583,134 @@ class LiteLLMProvider(LLMProvider):
return HIVE_API_BASE
return None
@staticmethod
def _normalize_codex_api_base(api_base: str | None) -> str | None:
"""Normalize ChatGPT Codex backend URLs to the stable base endpoint."""
return normalize_codex_api_base(api_base)
def _chunk_codex_system_prompt(self, system: str) -> list[str]:
"""Break large system prompts into smaller Codex-friendly chunks."""
return self._codex_adapter.chunk_system_prompt(system)
def _build_request_messages(
self,
messages: list[dict[str, Any]],
system: str,
*,
json_mode: bool,
) -> list[dict[str, Any]]:
"""Build request messages, including Codex-specific prompt chunking."""
full_messages: list[dict[str, Any]] = []
if self._claude_code_oauth:
billing = _claude_code_billing_header(messages)
full_messages.append({"role": "system", "content": billing})
system_messages: list[dict[str, Any]] = []
if system:
if self._codex_backend:
system_messages.extend(
self._codex_adapter.build_system_messages(system, json_mode=json_mode)
)
else:
sys_msg: dict[str, Any] = {"role": "system", "content": system}
if _model_supports_cache_control(self.model):
sys_msg["cache_control"] = {"type": "ephemeral"}
system_messages.append(sys_msg)
elif self._codex_backend:
system_messages.extend(
self._codex_adapter.build_system_messages("", json_mode=json_mode)
)
if json_mode and not self._codex_backend:
json_instruction = "Please respond with a valid JSON object."
if system_messages:
system_messages[0] = {
**system_messages[0],
"content": f"{system_messages[0]['content']}\n\n{json_instruction}",
}
else:
system_messages.append({"role": "system", "content": json_instruction})
full_messages.extend(system_messages)
full_messages.extend(messages)
return [
m
for m in full_messages
if not (
m.get("role") == "assistant" and not m.get("content") and not m.get("tool_calls")
)
]
def _derive_codex_tool_choice(
self,
messages: list[dict[str, Any]],
tools: list[Tool] | None,
) -> str | dict[str, Any] | None:
"""Force tool use for Codex when critical framework tools are available."""
if not self._codex_backend:
return None
return self._codex_adapter.derive_tool_choice(messages, tools)
def _sanitize_request_kwargs(
self,
kwargs: dict[str, Any],
*,
stream: bool,
) -> dict[str, Any]:
"""Normalize provider kwargs, with extra hardening for Codex."""
cleaned = dict(kwargs)
if cleaned.get("metadata") is None:
cleaned.pop("metadata", None)
if self._codex_backend:
cleaned = self._codex_adapter.harden_request_kwargs(cleaned)
if stream:
cleaned["stream"] = True
return cleaned
def _build_completion_kwargs(
self,
messages: list[dict[str, Any]],
system: str,
*,
tools: list[Tool] | None,
max_tokens: int,
response_format: dict[str, Any] | None,
json_mode: bool,
stream: bool,
) -> dict[str, Any]:
"""Build request kwargs for completion/stream calls."""
full_messages = self._build_request_messages(messages, system, json_mode=json_mode)
kwargs: dict[str, Any] = {
"model": self.model,
"messages": full_messages,
**self.extra_kwargs,
}
if not stream:
kwargs["max_tokens"] = max_tokens
else:
kwargs["max_tokens"] = max_tokens
if not self._is_anthropic_model():
kwargs["stream_options"] = {"include_usage": True}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
kwargs["api_base"] = self.api_base
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
tool_choice = self._derive_codex_tool_choice(full_messages, tools)
if tool_choice is not None:
kwargs["tool_choice"] = tool_choice
elif _is_ollama_model(self.model):
kwargs.setdefault("tool_choice", "auto")
if response_format:
kwargs["response_format"] = response_format
return self._sanitize_request_kwargs(kwargs, stream=stream)
def _completion_with_rate_limit_retry(
self, max_retries: int | None = None, **kwargs: Any
) -> Any:
@@ -713,46 +849,15 @@ class LiteLLMProvider(LLMProvider):
)
)
# Prepare messages with system prompt
full_messages = []
if system:
full_messages.append({"role": "system", "content": system})
full_messages.extend(messages)
# Add JSON mode via prompt engineering (works across all providers)
if json_mode:
json_instruction = "\n\nPlease respond with a valid JSON object."
# Append to system message if present, otherwise add as system message
if full_messages and full_messages[0]["role"] == "system":
full_messages[0]["content"] += json_instruction
else:
full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
# Build kwargs
kwargs: dict[str, Any] = {
"model": self.model,
"messages": full_messages,
"max_tokens": max_tokens,
**self.extra_kwargs,
}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
kwargs["api_base"] = self.api_base
# Add tools if provided
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
if _is_ollama_model(self.model):
# Ollama requires explicit tool_choice=auto for function calling
# so future readers don't have to guess.
kwargs.setdefault("tool_choice", "auto")
# Add response_format for structured output
# LiteLLM passes this through to the underlying provider
if response_format:
kwargs["response_format"] = response_format
kwargs = self._build_completion_kwargs(
messages,
system,
tools=tools,
max_tokens=max_tokens,
response_format=response_format,
json_mode=json_mode,
stream=False,
)
# Make the call
response = self._completion_with_rate_limit_retry(max_retries=max_retries, **kwargs)
@@ -913,44 +1018,15 @@ class LiteLLMProvider(LLMProvider):
json_mode=json_mode,
)
return await self._collect_stream_to_response(stream_iter)
full_messages: list[dict[str, Any]] = []
if self._claude_code_oauth:
billing = _claude_code_billing_header(messages)
full_messages.append({"role": "system", "content": billing})
if system:
sys_msg: dict[str, Any] = {"role": "system", "content": system}
if _model_supports_cache_control(self.model):
sys_msg["cache_control"] = {"type": "ephemeral"}
full_messages.append(sys_msg)
full_messages.extend(messages)
if json_mode:
json_instruction = "\n\nPlease respond with a valid JSON object."
if full_messages and full_messages[0]["role"] == "system":
full_messages[0]["content"] += json_instruction
else:
full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
kwargs: dict[str, Any] = {
"model": self.model,
"messages": full_messages,
"max_tokens": max_tokens,
**self.extra_kwargs,
}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
kwargs["api_base"] = self.api_base
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
if _is_ollama_model(self.model):
# Ollama requires explicit tool_choice=auto for function calling
# so future readers don't have to guess.
kwargs.setdefault("tool_choice", "auto")
if response_format:
kwargs["response_format"] = response_format
kwargs = self._build_completion_kwargs(
messages,
system,
tools=tools,
max_tokens=max_tokens,
response_format=response_format,
json_mode=json_mode,
stream=False,
)
response = await self._acompletion_with_rate_limit_retry(max_retries=max_retries, **kwargs)
@@ -1200,17 +1276,92 @@ class LiteLLMProvider(LLMProvider):
return parsed
return None
@staticmethod
def _normalize_pythonish_tool_arguments(raw_arguments: str) -> str:
"""Convert common JSON-like literals into a form ast.literal_eval can parse."""
replacements = {
"true": "True",
"false": "False",
"null": "None",
}
out: list[str] = []
token: list[str] = []
in_string = False
string_quote = ""
escaped = False
def flush_token() -> None:
if not token:
return
word = "".join(token)
out.append(replacements.get(word, word))
token.clear()
for char in raw_arguments:
if in_string:
out.append(char)
if escaped:
escaped = False
elif char == "\\":
escaped = True
elif char == string_quote:
in_string = False
continue
if char in {'"', "'"}:
flush_token()
in_string = True
string_quote = char
out.append(char)
continue
if char.isalpha():
token.append(char)
continue
flush_token()
out.append(char)
flush_token()
return "".join(out)
@staticmethod
def _strip_tool_argument_fence(raw_arguments: str) -> str:
"""Remove surrounding fenced-code markers from streamed tool arguments."""
stripped = raw_arguments.strip()
if not stripped.startswith("```") or not stripped.endswith("```"):
return stripped
lines = stripped.splitlines()
if len(lines) >= 2:
return "\n".join(lines[1:-1]).strip()
return stripped.strip("`").strip()
def _parse_pythonish_tool_arguments(self, raw_arguments: str) -> dict[str, Any] | None:
"""Parse single-quoted / trailing-comma argument payloads safely."""
stripped = self._strip_tool_argument_fence(raw_arguments)
if not stripped or stripped[0] != "{":
return None
candidate = self._close_truncated_json_fragment(stripped)
candidate = self._normalize_pythonish_tool_arguments(candidate)
try:
parsed = ast.literal_eval(candidate)
except (SyntaxError, ValueError):
return None
return parsed if isinstance(parsed, dict) else None
def _parse_tool_call_arguments(self, raw_arguments: str, tool_name: str) -> dict[str, Any]:
"""Parse streamed tool arguments, repairing truncation when possible."""
stripped = self._strip_tool_argument_fence(raw_arguments)
try:
parsed = json.loads(raw_arguments) if raw_arguments else {}
parsed = json.loads(stripped) if stripped else {}
except json.JSONDecodeError:
parsed = None
if isinstance(parsed, dict):
return parsed
repaired = self._repair_truncated_tool_arguments(raw_arguments)
repaired = self._repair_truncated_tool_arguments(stripped)
if repaired is not None:
logger.warning(
"[tool-args] Recovered truncated arguments for %s on %s",
@@ -1219,6 +1370,15 @@ class LiteLLMProvider(LLMProvider):
)
return repaired
pythonish = self._parse_pythonish_tool_arguments(stripped)
if pythonish is not None:
logger.warning(
"[tool-args] Recovered malformed arguments for %s on %s",
tool_name,
self.model,
)
return pythonish
raise ValueError(
f"Failed to parse tool call arguments for '{tool_name}' (likely truncated JSON)."
)
@@ -1546,6 +1706,139 @@ class LiteLLMProvider(LLMProvider):
model=response.model,
)
def _build_stream_events_from_nonstream_response(
self,
response: Any,
) -> list[StreamEvent]:
"""Convert a non-stream completion response into stream events."""
from framework.llm.stream_events import (
FinishEvent,
TextDeltaEvent,
TextEndEvent,
ToolCallEvent,
)
choices = getattr(response, "choices", None) or []
if not choices:
output_text = getattr(response, "output_text", "") or ""
if not output_text:
return []
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, TextEndEvent
usage = getattr(response, "usage", None)
return [
TextDeltaEvent(content=output_text, snapshot=output_text),
TextEndEvent(full_text=output_text),
FinishEvent(
stop_reason="stop",
input_tokens=getattr(usage, "prompt_tokens", 0) or 0 if usage else 0,
output_tokens=getattr(usage, "completion_tokens", 0) or 0 if usage else 0,
model=getattr(response, "model", None) or self.model,
),
]
choice = choices[0]
message = getattr(choice, "message", None)
content = self._extract_message_text(message)
tool_calls = getattr(message, "tool_calls", None) or []
events: list[StreamEvent] = []
for tc in tool_calls:
parsed_args = self._coerce_tool_input(
tc.function.arguments if tc.function else {},
tc.function.name if tc.function else "",
)
events.append(
ToolCallEvent(
tool_use_id=getattr(tc, "id", ""),
tool_name=tc.function.name if tc.function else "",
tool_input=parsed_args,
)
)
if content:
events.append(TextDeltaEvent(content=content, snapshot=content))
events.append(TextEndEvent(full_text=content))
usage = getattr(response, "usage", None)
input_tokens = getattr(usage, "prompt_tokens", 0) or 0 if usage else 0
output_tokens = getattr(usage, "completion_tokens", 0) or 0 if usage else 0
cached_tokens = 0
if usage:
details = getattr(usage, "prompt_tokens_details", None)
cached_tokens = (
getattr(details, "cached_tokens", 0) or 0
if details is not None
else getattr(usage, "cache_read_input_tokens", 0) or 0
)
events.append(
FinishEvent(
stop_reason=getattr(choice, "finish_reason", None)
or ("tool_calls" if tool_calls else "stop"),
input_tokens=input_tokens,
output_tokens=output_tokens,
cached_tokens=cached_tokens,
model=getattr(response, "model", None) or self.model,
)
)
return events
@staticmethod
def _extract_message_text(message: Any) -> str:
"""Extract text from a provider message object across response shapes."""
if message is None:
return ""
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict):
text = block.get("text") or block.get("content") or ""
if isinstance(text, str):
parts.append(text)
else:
text = getattr(block, "text", "") or getattr(block, "content", "")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content or "")
def _coerce_tool_input(self, raw_arguments: Any, tool_name: str) -> dict[str, Any]:
"""Normalize raw tool-call arguments from either string or object forms."""
if isinstance(raw_arguments, dict):
return raw_arguments
if raw_arguments in (None, ""):
return {}
return self._parse_tool_call_arguments(str(raw_arguments), tool_name)
async def _recover_empty_codex_stream(
self,
kwargs: dict[str, Any],
last_role: str | None,
) -> list[StreamEvent] | None:
"""Try a non-stream completion when Codex returns an empty stream."""
if not self._codex_backend:
return None
return await self._codex_adapter.recover_empty_stream(
kwargs,
last_role=last_role,
acompletion=litellm.acompletion, # type: ignore[union-attr]
)
def _merge_tool_call_chunk(
self,
tool_calls_acc: dict[int, dict[str, str]],
tc: Any,
last_tool_idx: int,
) -> int:
"""Merge a streamed tool-call chunk, compensating for broken Codex indexes."""
return self._codex_adapter.merge_tool_call_chunk(tool_calls_acc, tc, last_tool_idx)
async def stream(
self,
messages: list[dict[str, Any]],
@@ -1597,69 +1890,16 @@ class LiteLLMProvider(LLMProvider):
yield event
return
full_messages: list[dict[str, Any]] = []
if self._claude_code_oauth:
billing = _claude_code_billing_header(messages)
full_messages.append({"role": "system", "content": billing})
if system:
sys_msg: dict[str, Any] = {"role": "system", "content": system}
if _model_supports_cache_control(self.model):
sys_msg["cache_control"] = {"type": "ephemeral"}
full_messages.append(sys_msg)
full_messages.extend(messages)
# Codex Responses API requires an `instructions` field (system prompt).
# Inject a minimal one when callers don't provide a system message.
if self._codex_backend and not any(m["role"] == "system" for m in full_messages):
full_messages.insert(0, {"role": "system", "content": "You are a helpful assistant."})
# Add JSON mode via prompt engineering (works across all providers)
if json_mode:
json_instruction = "\n\nPlease respond with a valid JSON object."
if full_messages and full_messages[0]["role"] == "system":
full_messages[0]["content"] += json_instruction
else:
full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})
# Remove ghost empty assistant messages (content="" and no tool_calls).
# These arise when a model returns an empty stream after a tool result
# (an "expected" no-op turn). Keeping them in history confuses some
# models (notably Codex/gpt-5.3) and causes cascading empty streams.
full_messages = [
m
for m in full_messages
if not (
m.get("role") == "assistant" and not m.get("content") and not m.get("tool_calls")
)
]
kwargs: dict[str, Any] = {
"model": self.model,
"messages": full_messages,
"max_tokens": max_tokens,
"stream": True,
**self.extra_kwargs,
}
# stream_options is OpenAI-specific; Anthropic rejects it with 400.
# Only include it for providers that support it.
if not self._is_anthropic_model():
kwargs["stream_options"] = {"include_usage": True}
if self.api_key:
kwargs["api_key"] = self.api_key
if self.api_base:
kwargs["api_base"] = self.api_base
if tools:
kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
if _is_ollama_model(self.model):
# Ollama requires explicit tool_choice=auto for function calling
# so future readers don't have to guess.
kwargs.setdefault("tool_choice", "auto")
if response_format:
kwargs["response_format"] = response_format
# The Codex ChatGPT backend (Responses API) rejects several params.
if self._codex_backend:
kwargs.pop("max_tokens", None)
kwargs.pop("stream_options", None)
kwargs = self._build_completion_kwargs(
messages,
system,
tools=tools,
max_tokens=max_tokens,
response_format=response_format,
json_mode=json_mode,
stream=True,
)
full_messages = kwargs["messages"]
for attempt in range(RATE_LIMIT_MAX_RETRIES + 1):
# Post-stream events (ToolCall, TextEnd, Finish) are buffered
@@ -1717,43 +1957,17 @@ class LiteLLMProvider(LLMProvider):
# argument deltas that arrive with id=None.
if delta and delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0
if tc.id:
# New tool call announced (or done event re-sent).
# Check if this id already has a slot.
existing_idx = next(
(k for k, v in tool_calls_acc.items() if v["id"] == tc.id),
None,
)
if existing_idx is not None:
idx = existing_idx
elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in (
"",
tc.id,
):
# Slot taken by a different call — assign new index
idx = max(tool_calls_acc.keys()) + 1
_last_tool_idx = idx
else:
# Argument delta with no id — route to last opened slot
idx = _last_tool_idx
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
if tc.id:
tool_calls_acc[idx]["id"] = tc.id
if tc.function:
if tc.function.name:
tool_calls_acc[idx]["name"] = tc.function.name
if tc.function.arguments:
tool_calls_acc[idx]["arguments"] += tc.function.arguments
_last_tool_idx = self._merge_tool_call_chunk(
tool_calls_acc,
tc,
_last_tool_idx,
)
# --- Finish ---
if choice.finish_reason:
stream_finish_reason = choice.finish_reason
for _idx, tc_data in sorted(tool_calls_acc.items()):
parsed_args = self._parse_tool_call_arguments(
parsed_args = self._coerce_tool_input(
tc_data.get("arguments", ""),
tc_data.get("name", ""),
)
@@ -1886,6 +2100,11 @@ class LiteLLMProvider(LLMProvider):
(m["role"] for m in reversed(full_messages) if m.get("role") != "system"),
None,
)
recovered_events = await self._recover_empty_codex_stream(kwargs, last_role)
if recovered_events:
for event in recovered_events:
yield event
return
if attempt < EMPTY_STREAM_MAX_RETRIES:
token_count, token_method = _estimate_tokens(
self.model,
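
The new `_parse_pythonish_tool_arguments` path recovers single-quoted, Python-literal payloads that `json.loads` rejects. A toy demonstration; note the real code tokenizes so `true`/`false`/`null` inside string values are left alone, while this sketch uses a crude replace:

import ast

# Invented payload in the malformed shape some models stream.
raw = "{'query': 'status', 'dry_run': false, 'limit': null}"
candidate = raw.replace("false", "False").replace("null", "None")  # crude stand-in
parsed = ast.literal_eval(candidate)
assert parsed == {"query": "status", "dry_run": False, "limit": None}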
+80 -33
View File
@@ -22,6 +22,7 @@ from framework.graph.edge import (
)
from framework.graph.executor import ExecutionResult
from framework.graph.node import NodeSpec
from framework.llm.codex_backend import CODEX_API_BASE, build_codex_litellm_kwargs
from framework.llm.provider import LLMProvider, Tool
from framework.runner.preload_validation import run_preload_validation
from framework.runner.tool_registry import ToolRegistry
@@ -327,17 +328,68 @@ def _read_codex_auth_file() -> dict | None:
return None
def _get_jwt_claims(token: str) -> dict | None:
"""Decode JWT claims without verification for local expiry/account inspection."""
import base64
try:
parts = token.split(".")
if len(parts) != 3:
return None
payload = parts[1]
padding = 4 - len(payload) % 4
if padding != 4:
payload += "=" * padding
decoded = base64.urlsafe_b64decode(payload)
claims = json.loads(decoded)
return claims if isinstance(claims, dict) else None
except Exception:
return None
def _get_codex_token_expiry(auth_data: dict) -> float | None:
"""Return the best-known expiry timestamp for a Codex access token."""
from datetime import datetime
tokens = auth_data.get("tokens", {})
access_token = tokens.get("access_token")
explicit = (
auth_data.get("expires_at")
or auth_data.get("expiresAt")
or tokens.get("expires_at")
or tokens.get("expiresAt")
)
if isinstance(explicit, (int, float)):
return float(explicit)
if isinstance(explicit, str):
try:
return datetime.fromisoformat(explicit.replace("Z", "+00:00")).timestamp()
except (ValueError, TypeError):
pass
if isinstance(access_token, str):
claims = _get_jwt_claims(access_token) or {}
exp = claims.get("exp")
if isinstance(exp, (int, float)):
return float(exp)
return None
def _is_codex_token_expired(auth_data: dict) -> bool:
"""Check whether the Codex token is expired or close to expiry.
The Codex auth.json has no explicit ``expiresAt`` field, so we infer
expiry as ``last_refresh + _CODEX_TOKEN_LIFETIME_SECS``. Falls back
to the file mtime when ``last_refresh`` is absent.
to JWT ``exp`` or file age heuristics when no explicit timestamp exists.
"""
import time
from datetime import datetime
now = time.time()
explicit_expiry = _get_codex_token_expiry(auth_data)
if explicit_expiry is not None:
return now >= (explicit_expiry - _TOKEN_REFRESH_BUFFER_SECS)
last_refresh = auth_data.get("last_refresh")
if last_refresh is None:
@@ -431,6 +483,8 @@ def get_codex_token() -> str | None:
Returns:
The access token if available, None otherwise.
"""
import time
# Try Keychain first, then file
auth_data = _read_codex_keychain() or _read_codex_auth_file()
if not auth_data:
@@ -441,15 +495,20 @@ def get_codex_token() -> str | None:
if not access_token:
return None
explicit_expiry = _get_codex_token_expiry(auth_data)
is_expired = _is_codex_token_expired(auth_data)
# Check if token is still valid
if not _is_codex_token_expired(auth_data):
if not is_expired:
return access_token
# Token is expired or near expiry — attempt refresh
refresh_token = tokens.get("refresh_token")
if not refresh_token:
logger.warning("Codex token expired and no refresh token available")
return access_token # Return expired token; it may still work briefly
if explicit_expiry is not None and time.time() >= explicit_expiry:
return None
return access_token
logger.info("Codex token expired or near expiry, refreshing...")
token_data = _refresh_codex_token(refresh_token)
@@ -460,6 +519,8 @@ def get_codex_token() -> str | None:
# Refresh failed — return the existing token and warn
logger.warning("Codex token refresh failed. Run 'codex' to re-authenticate.")
if explicit_expiry is not None and time.time() >= explicit_expiry:
return None
return access_token
@@ -471,26 +532,12 @@ def _get_account_id_from_jwt(access_token: str) -> str | None:
This is used as a fallback when the auth.json doesn't store the
account_id explicitly.
"""
import base64
try:
parts = access_token.split(".")
if len(parts) != 3:
return None
payload = parts[1]
# Add base64 padding
padding = 4 - len(payload) % 4
if padding != 4:
payload += "=" * padding
decoded = base64.urlsafe_b64decode(payload)
claims = json.loads(decoded)
auth = claims.get("https://api.openai.com/auth")
if isinstance(auth, dict):
account_id = auth.get("chatgpt_account_id")
if isinstance(account_id, str) and account_id:
return account_id
except Exception:
pass
claims = _get_jwt_claims(access_token) or {}
auth = claims.get("https://api.openai.com/auth")
if isinstance(auth, dict):
account_id = auth.get("chatgpt_account_id")
if isinstance(account_id, str) and account_id:
return account_id
return None
@@ -1569,20 +1616,20 @@ class AgentRunner:
# OpenAI Codex subscription routes through the ChatGPT backend
# (chatgpt.com/backend-api/codex/responses), NOT the standard
# OpenAI API. The consumer OAuth token lacks platform API scopes.
extra_headers: dict[str, str] = {
"Authorization": f"Bearer {api_key}",
"User-Agent": "CodexBar",
}
account_id = get_codex_account_id()
if account_id:
extra_headers["ChatGPT-Account-Id"] = account_id
self._llm = LiteLLMProvider(
model=self.model,
api_key=api_key,
api_base="https://chatgpt.com/backend-api/codex",
extra_headers=extra_headers,
store=False,
allowed_openai_params=["store"],
api_base=CODEX_API_BASE,
**build_codex_litellm_kwargs(api_key, account_id=account_id),
)
elif api_key and use_kimi_code:
# Kimi Code subscription uses the Kimi coding API (OpenAI-compatible).
# The api_base is set automatically by LiteLLMProvider for kimi/ models.
self._llm = LiteLLMProvider(
model=self.model,
api_key=api_key,
api_base=api_base,
)
elif api_key and use_kimi_code:
# Kimi Code subscription uses the Kimi coding API (OpenAI-compatible).
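
The shared `_get_jwt_claims` helper above decodes the JWT payload without signature verification. A minimal reproduction with a fabricated token, for anyone validating the padding math:

import base64, json, time

def jwt_claims(token: str) -> dict | None:
    # Decode only the middle (payload) segment; no signature check.
    parts = token.split(".")
    if len(parts) != 3:
        return None
    payload = parts[1]
    padding = 4 - len(payload) % 4  # re-add base64 padding stripped by JWT
    if padding != 4:
        payload += "=" * padding
    claims = json.loads(base64.urlsafe_b64decode(payload))
    return claims if isinstance(claims, dict) else None

# Fabricated token: header and signature are dummies, payload carries exp.
body = base64.urlsafe_b64encode(
    json.dumps({"exp": int(time.time()) + 3600}).encode()
).decode().rstrip("=")
assert jwt_claims(f"h.{body}.s")["exp"] > time.time()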
+11 -2
View File
@@ -535,8 +535,8 @@ class EventBus:
async with self._semaphore:
try:
await handler(event)
except Exception:
logger.exception(f"Handler error for {event.type}")
except Exception as e:
logger.error(f"Handler error for {event.type}: {e}")
# Run all handlers concurrently
await asyncio.gather(*[run_handler(h) for h in handlers], return_exceptions=True)
@@ -901,6 +901,9 @@ class EventBus:
execution_id: str | None = None,
options: list[str] | None = None,
questions: list[dict] | None = None,
auto_blocked: bool = False,
assistant_text_present: bool = False,
assistant_text_requires_input: bool = False,
) -> None:
"""Emit client input requested event (client_facing=True nodes).
@@ -917,6 +920,12 @@ class EventBus:
data["options"] = options
if questions:
data["questions"] = questions
if auto_blocked:
data["auto_blocked"] = True
if assistant_text_present:
data["assistant_text_present"] = True
if assistant_text_requires_input:
data["assistant_text_requires_input"] = True
await self.publish(
AgentEvent(
type=EventType.CLIENT_INPUT_REQUESTED,
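
The three new keyword flags are written into the event payload only when truthy, so existing consumers see byte-identical payloads when the flags are unset. A toy version:

def build_data(prompt: str, *, auto_blocked=False,
               assistant_text_present=False,
               assistant_text_requires_input=False) -> dict:
    data = {"prompt": prompt}
    if auto_blocked:
        data["auto_blocked"] = True
    if assistant_text_present:
        data["assistant_text_present"] = True
    if assistant_text_requires_input:
        data["assistant_text_requires_input"] = True
    return data

assert build_data("") == {"prompt": ""}                       # unchanged legacy shape
assert build_data("", auto_blocked=True) == {"prompt": "", "auto_blocked": True}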
+21 -3
View File
@@ -452,7 +452,9 @@ class ExecutionStream:
node = executor.node_registry.get(node_id)
if node is not None and hasattr(node, "inject_event"):
await node.inject_event(
content, is_client_input=is_client_input, image_content=image_content
content,
is_client_input=is_client_input,
image_content=image_content,
)
return True
return False
@@ -1030,6 +1032,22 @@ class ExecutionStream:
else:
status = SessionStatus.ACTIVE
persisted_input_data = dict(ctx.input_data or {})
entry_node_id = getattr(self.entry_spec, "entry_node", None) or getattr(
self.graph, "entry_node", None
)
entry_input_keys: list[str] = []
if entry_node_id and hasattr(self.graph, "get_node"):
entry_node = self.graph.get_node(entry_node_id)
entry_input_keys = list(getattr(entry_node, "input_keys", []) or [])
if result and isinstance(result.output, dict):
for key in entry_input_keys:
if persisted_input_data.get(key) in (None, ""):
value = result.output.get(key)
if value not in (None, ""):
persisted_input_data[key] = value
# Create SessionState
if result:
# Create from execution result
@@ -1040,7 +1058,7 @@ class ExecutionStream:
stream_id=self.stream_id,
correlation_id=ctx.correlation_id,
started_at=ctx.started_at.isoformat(),
input_data=ctx.input_data,
input_data=persisted_input_data,
agent_id=self.graph.id,
entry_point=self.entry_spec.id,
)
@@ -1075,7 +1093,7 @@ class ExecutionStream:
),
progress=progress,
memory=ss.get("memory", {}),
input_data=ctx.input_data,
input_data=persisted_input_data,
)
# Handle error case
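
A toy run of the entry-input backfill above: blank or missing entry keys in the persisted input are filled from the execution result's output, while populated keys are left alone (values invented):

persisted_input_data = {"topic": "", "depth": 2}
entry_input_keys = ["topic", "depth"]
result_output = {"topic": "codex parity", "depth": 5}

for key in entry_input_keys:
    if persisted_input_data.get(key) in (None, ""):
        value = result_output.get(key)
        if value not in (None, ""):
            persisted_input_data[key] = value

assert persisted_input_data == {"topic": "codex parity", "depth": 2}  # depth kept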
+16 -1
View File
@@ -48,7 +48,20 @@ def validate_agent_path(agent_path: str | Path) -> Path:
Raises:
ValueError: If the path is outside all allowed roots.
"""
resolved = Path(agent_path).expanduser().resolve()
raw_path = str(agent_path).strip()
if not raw_path:
raise ValueError(
"agent_path must be inside an allowed directory "
"(exports/, examples/, or ~/.hive/agents/)"
)
candidate = Path(agent_path).expanduser()
if not candidate.is_absolute():
# Resolve relative paths from the repository root so server-side
# validation is independent of the process working directory.
candidate = _REPO_ROOT / candidate
resolved = candidate.resolve()
for root in _get_allowed_agent_roots():
if resolved.is_relative_to(root) and resolved != root:
return resolved
@@ -281,6 +294,8 @@ def _setup_static_serving(app: web.Application) -> None:
async def handle_spa(request: web.Request) -> web.FileResponse:
"""Serve static files with SPA fallback to index.html."""
rel_path = request.match_info.get("path", "")
if rel_path == "api" or rel_path.startswith("api/"):
raise web.HTTPNotFound()
file_path = (dist_dir / rel_path).resolve()
if file_path.is_file() and file_path.is_relative_to(dist_dir):
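
The relative-path branch anchors candidates at the repository root instead of the process working directory. A sketch with a hypothetical repo root standing in for `_REPO_ROOT`:

from pathlib import Path

REPO_ROOT = Path("/srv/hive")  # hypothetical; the real code uses _REPO_ROOT

def resolve_agent_path(agent_path: str) -> Path:
    candidate = Path(agent_path).expanduser()
    if not candidate.is_absolute():
        candidate = REPO_ROOT / candidate  # anchor at repo root, not CWD
    return candidate.resolve()

assert resolve_agent_path("exports/demo").parts[-2:] == ("exports", "demo")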
+22 -15
View File
@@ -17,6 +17,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _client_input_counts_as_planning_ask(event: Any) -> bool:
"""Return True when a queen input-request should satisfy planning ask rounds.
Explicit ask_user / ask_user_multiple calls always count. We also count
queen auto-blocks that followed assistant text which clearly invited a
reply; this covers Codex-style plain-text planning questions that never
called ask_user. Empty/status-only auto-blocks do not count.
"""
data = getattr(event, "data", None) or {}
if data.get("prompt") or data.get("questions") or data.get("options"):
return True
if not data.get("auto_blocked"):
return False
requires_input = data.get("assistant_text_requires_input", False)
return bool(requires_input)
async def create_queen(
session: Session,
session_manager: Any,
@@ -62,7 +79,6 @@ async def create_queen(
from framework.agents.queen.nodes.thinking_hook import select_expert_persona
from framework.graph.event_loop_node import HookContext, HookResult
from framework.graph.executor import GraphExecutor
from framework.runner.mcp_registry import MCPRegistry
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.core import Runtime
from framework.runtime.event_bus import AgentEvent, EventType
@@ -88,6 +104,8 @@ async def create_queen(
logger.warning("Queen: MCP config failed to load", exc_info=True)
try:
from framework.runner.mcp_registry import MCPRegistry
registry = MCPRegistry()
registry.initialize()
if (queen_pkg_dir / "mcp_registry.json").is_file():
@@ -115,14 +133,7 @@ async def create_queen(
async def _track_planning_asks(event: AgentEvent) -> None:
if phase_state.phase != "planning":
return
# Only count explicit ask_user / ask_user_multiple calls, not
# auto-block (text-only turns emit CLIENT_INPUT_REQUESTED with
# an empty prompt and no options/questions).
data = event.data or {}
has_prompt = bool(data.get("prompt"))
has_questions = bool(data.get("questions"))
has_options = bool(data.get("options"))
if has_prompt or has_questions or has_options:
if _client_input_counts_as_planning_ask(event):
phase_state.planning_ask_rounds += 1
session.event_bus.subscribe(
@@ -240,15 +251,11 @@ async def create_queen(
# ---- Default skill protocols -------------------------------------
try:
from framework.skills.manager import SkillsManager, SkillsManagerConfig
from framework.skills.manager import SkillsManager
# Pass project_root so user-scope skills (~/.hive/skills/, ~/.agents/skills/)
# are discovered. Queen has no agent-specific project root, so we use its
# own directory — the value just needs to be non-None to enable user-scope scanning.
_queen_skills_mgr = SkillsManager(SkillsManagerConfig(project_root=Path(__file__).parent))
_queen_skills_mgr = SkillsManager()
_queen_skills_mgr.load()
phase_state.protocols_prompt = _queen_skills_mgr.protocols_prompt
phase_state.skills_catalog_prompt = _queen_skills_mgr.skills_catalog_prompt
except Exception:
logger.debug("Queen skill loading failed (non-fatal)", exc_info=True)
+135 -23
View File
@@ -8,11 +8,113 @@ from typing import Any
from aiohttp import web
from framework.credentials.validation import validate_agent_credentials
from framework.runtime.event_bus import AgentEvent, EventType
from framework.server.app import resolve_session, safe_path_segment, sessions_dir
from framework.server.routes_sessions import _credential_error_response
from framework.server.session_manager import (
_run_validation_report_sync,
_validation_blocks_stage_or_run,
_validation_failures,
)
logger = logging.getLogger(__name__)
_TERMINAL_STOP_MARKERS = (
"done for now",
"stop here",
"stop for now",
"end session",
"finish and close session",
"finish and close",
)
def _normalize_choice_text(text: str) -> str:
lowered = str(text or "").strip().lower()
return " ".join(lowered.replace("_", " ").split())
def _looks_like_terminal_stop_reply(text: str) -> bool:
normalized = _normalize_choice_text(text)
return any(marker in normalized for marker in _TERMINAL_STOP_MARKERS)
def _queen_is_waiting_on_terminal_followup(session: Any) -> bool:
"""Return True when the latest queen question offered a terminal stop option."""
bus = getattr(session, "event_bus", None)
if bus is None or not hasattr(bus, "get_history"):
return False
events = bus.get_history(
event_type=EventType.CLIENT_INPUT_REQUESTED,
stream_id="queen",
limit=5,
)
for event in events:
data = getattr(event, "data", None) or {}
options = [str(opt) for opt in (data.get("options") or []) if opt]
for question in data.get("questions") or []:
options.extend(str(opt) for opt in (question.get("options") or []) if opt)
if options:
return any(_looks_like_terminal_stop_reply(opt) for opt in options)
return False
async def _acknowledge_terminal_queen_choice(session: Any, message: str) -> None:
"""Emit a final acknowledgment when the user chooses to stop."""
ack = "Okay, stopping here. Ill wait for your next message."
bus = getattr(session, "event_bus", None)
if bus is None:
return
if hasattr(bus, "publish"):
await bus.publish(
AgentEvent(
type=EventType.CLIENT_INPUT_RECEIVED,
stream_id="queen",
node_id="queen",
execution_id=session.id,
data={"content": message},
)
)
if hasattr(bus, "emit_client_output_delta"):
await bus.emit_client_output_delta(
"queen",
"queen",
ack,
ack,
execution_id=session.id,
)
async def _worker_validation_error(session) -> web.Response | None:
"""Return a 409 response when the loaded worker is invalid."""
report = getattr(session, "worker_validation_report", None)
if report is None and getattr(session, "worker_path", None):
loop = asyncio.get_running_loop()
report = await loop.run_in_executor(
None, lambda: _run_validation_report_sync(str(session.worker_path))
)
session.worker_validation_report = report
session.worker_validation_failures = _validation_failures(report)
if _validation_blocks_stage_or_run(report):
failures = getattr(session, "worker_validation_failures", None) or _validation_failures(
report
)
worker_name = getattr(getattr(session, "worker_path", None), "name", "") or "current worker"
return web.json_response(
{
"error": (
f"Worker '{worker_name}' failed validation and cannot be executed. "
"Fix the package and reload it before running or resuming."
),
"validation_failures": failures,
},
status=409,
)
return None
async def handle_trigger(request: web.Request) -> web.Response:
"""POST /api/sessions/{session_id}/trigger — start an execution.
@@ -26,6 +128,10 @@ async def handle_trigger(request: web.Request) -> web.Response:
if not session.worker_runtime:
return web.json_response({"error": "No worker loaded in this session"}, status=503)
validation_err = await _worker_validation_error(session)
if validation_err is not None:
return validation_err
# Validate credentials before running — deferred from load time to avoid
# showing the modal before the user clicks Run. Runs in executor because
# validate_agent_credentials makes blocking HTTP health-check calls.
@@ -53,11 +159,7 @@ async def handle_trigger(request: web.Request) -> web.Response:
body = await request.json()
entry_point_id = body.get("entry_point_id", "default")
input_data = body.get("input_data", {})
session_state = body.get("session_state") or {}
# Scope the worker execution to the live session ID
if "resume_session_id" not in session_state:
session_state["resume_session_id"] = session.id
session_state = body.get("session_state") or None
execution_id = await session.worker_runtime.trigger(
entry_point_id,
@@ -108,10 +210,7 @@ async def handle_chat(request: web.Request) -> web.Response:
The input box is permanently connected to the queen agent.
Worker input is handled separately via /worker-input.
Body: {"message": "hello", "images": [{"type": "image_url", "image_url": {"url": "data:..."}}]}
The optional ``images`` field accepts a list of OpenAI-format image_url
content blocks. The frontend encodes images as base64 data URIs.
Body: {"message": "hello"}
"""
session, err = resolve_session(request)
if err:
@@ -119,29 +218,34 @@ async def handle_chat(request: web.Request) -> web.Response:
body = await request.json()
message = body.get("message", "")
image_content = body.get("images") or None # list[dict] | None
if not message and not image_content:
if not message:
return web.json_response({"error": "message is required"}, status=400)
manager: Any = request.app["manager"]
if _looks_like_terminal_stop_reply(message) and _queen_is_waiting_on_terminal_followup(session):
await _acknowledge_terminal_queen_choice(session, message)
await manager.suspend_queen(session)
return web.json_response(
{
"status": "queen",
"delivered": True,
}
)
queen_executor = session.queen_executor
if queen_executor is not None:
node = queen_executor.node_registry.get("queen")
if node is not None and hasattr(node, "inject_event"):
await node.inject_event(message, is_client_input=True, image_content=image_content)
# Publish to EventBus so the session event log captures user messages
from framework.runtime.event_bus import AgentEvent, EventType
await node.inject_event(message, is_client_input=True)
await session.event_bus.publish(
AgentEvent(
type=EventType.CLIENT_INPUT_RECEIVED,
stream_id="queen",
node_id="queen",
execution_id=session.id,
data={
"content": message,
"image_count": len(image_content) if image_content else 0,
},
data={"content": message},
)
)
return web.json_response(
@@ -152,7 +256,6 @@ async def handle_chat(request: web.Request) -> web.Response:
)
# Queen is dead — try to revive her
manager: Any = request.app["manager"]
try:
await manager.revive_queen(session, initial_prompt=message)
return web.json_response(
@@ -274,6 +377,10 @@ async def handle_resume(request: web.Request) -> web.Response:
if not session.worker_runtime:
return web.json_response({"error": "No worker loaded in this session"}, status=503)
validation_err = await _worker_validation_error(session)
if validation_err is not None:
return validation_err
body = await request.json()
worker_session_id = body.get("session_id")
checkpoint_id = body.get("checkpoint_id")
@@ -419,9 +526,14 @@ async def handle_stop(request: web.Request) -> web.Response:
if hasattr(node, "cancel_current_turn"):
node.cancel_current_turn()
cancelled = await stream.cancel_execution(
execution_id, reason="Execution stopped by user"
)
try:
cancelled = await stream.cancel_execution(
execution_id, reason="Execution stopped by user"
)
except TypeError:
# Backward compatibility for older stream/test doubles that
# still expose cancel_execution(execution_id) only.
cancelled = await stream.cancel_execution(execution_id)
if cancelled:
# Cancel queen's in-progress LLM turn
if session.queen_executor:
+45 -31
View File
@@ -28,8 +28,6 @@ import contextlib
import json
import logging
import shutil
import subprocess
import sys
import time
from pathlib import Path
@@ -42,22 +40,54 @@ from framework.server.app import (
sessions_dir,
validate_agent_path,
)
from framework.server.session_manager import SessionManager
from framework.server.session_manager import (
SessionManager,
WorkerValidationError,
_run_validation_report_sync,
_validation_blocks_stage_or_run,
_validation_failures,
)
logger = logging.getLogger(__name__)
async def _worker_validation_error(session) -> web.Response | None:
"""Return a 409 response when the loaded worker is invalid."""
report = getattr(session, "worker_validation_report", None)
if report is None and getattr(session, "worker_path", None):
loop = asyncio.get_running_loop()
report = await loop.run_in_executor(
None, lambda: _run_validation_report_sync(str(session.worker_path))
)
session.worker_validation_report = report
session.worker_validation_failures = _validation_failures(report)
if _validation_blocks_stage_or_run(report):
failures = getattr(session, "worker_validation_failures", None) or _validation_failures(
report
)
worker_name = getattr(getattr(session, "worker_path", None), "name", "") or "current worker"
return web.json_response(
{
"error": (
f"Worker '{worker_name}' failed validation and cannot be executed. "
"Fix the package and reload it before running or restoring."
),
"validation_failures": failures,
},
status=409,
)
return None
def _get_manager(request: web.Request) -> SessionManager:
return request.app["manager"]
def _session_to_live_dict(session) -> dict:
"""Serialize a live Session to the session-primary JSON shape."""
from framework.llm.capabilities import supports_image_tool_results
info = session.worker_info
phase_state = getattr(session, "phase_state", None)
queen_model: str = getattr(getattr(session, "runner", None), "model", "") or ""
return {
"session_id": session.id,
"worker_id": session.worker_id,
@@ -73,7 +103,6 @@ def _session_to_live_dict(session) -> dict:
"queen_phase": phase_state.phase
if phase_state
else ("staging" if session.worker_runtime else "planning"),
"queen_supports_images": supports_image_tool_results(queen_model) if queen_model else True,
}
@@ -311,6 +340,11 @@ async def handle_load_worker(request: web.Request) -> web.Response:
model=model,
)
except ValueError as e:
if isinstance(e, WorkerValidationError):
return web.json_response(
{"error": str(e), "validation_failures": e.failures},
status=409,
)
return web.json_response({"error": str(e)}, status=409)
except FileNotFoundError:
return web.json_response({"error": f"Agent not found: {agent_path}"}, status=404)
@@ -729,6 +763,10 @@ async def handle_restore_checkpoint(request: web.Request) -> web.Response:
if not session.worker_runtime:
return web.json_response({"error": "No worker loaded in this session"}, status=503)
validation_err = await _worker_validation_error(session)
if validation_err is not None:
return validation_err
ws_id = request.match_info.get("ws_id") or request.match_info.get("session_id", "")
ws_id = safe_path_segment(ws_id)
checkpoint_id = safe_path_segment(request.match_info["checkpoint_id"])
@@ -984,29 +1022,6 @@ async def handle_discover(request: web.Request) -> web.Response:
return web.json_response(result)
async def handle_reveal_session_folder(request: web.Request) -> web.Response:
"""POST /api/sessions/{session_id}/reveal — open session data folder in the OS file manager."""
manager: SessionManager = request.app["manager"]
session_id = request.match_info["session_id"]
session = manager.get_session(session_id)
storage_session_id = (session.queen_resume_from or session.id) if session else session_id
folder = Path.home() / ".hive" / "queen" / "session" / storage_session_id
folder.mkdir(parents=True, exist_ok=True)
try:
if sys.platform == "darwin":
subprocess.Popen(["open", str(folder)])
elif sys.platform == "win32":
subprocess.Popen(["explorer", str(folder)])
else:
subprocess.Popen(["xdg-open", str(folder)])
except Exception as exc:
return web.json_response({"error": str(exc)}, status=500)
return web.json_response({"path": str(folder)})
# ------------------------------------------------------------------
# Route registration
# ------------------------------------------------------------------
@@ -1031,7 +1046,6 @@ def register_routes(app: web.Application) -> None:
app.router.add_delete("/api/sessions/{session_id}/worker", handle_unload_worker)
# Session info
app.router.add_post("/api/sessions/{session_id}/reveal", handle_reveal_session_folder)
app.router.add_get("/api/sessions/{session_id}/stats", handle_session_stats)
app.router.add_get("/api/sessions/{session_id}/entry-points", handle_session_entry_points)
app.router.add_patch(
+220 -47
View File
@@ -12,6 +12,8 @@ Architecture:
import asyncio
import json
import logging
import subprocess
import textwrap
import time
import uuid
from dataclasses import dataclass, field
@@ -22,6 +24,134 @@ from typing import Any
from framework.runtime.triggers import TriggerDefinition
logger = logging.getLogger(__name__)
REPO_ROOT = Path(__file__).resolve().parents[3]
CODER_TOOLS_SERVER = REPO_ROOT / "tools" / "coder_tools_server.py"
def _parse_validation_report(raw: str | dict[str, Any] | None) -> dict[str, Any]:
"""Best-effort parse of validate_agent_package output."""
if isinstance(raw, dict):
return raw
if not isinstance(raw, str):
return {}
cleaned = raw.strip()
if "\n\n[Saved to " in cleaned:
cleaned = cleaned.split("\n\n[Saved to ", 1)[0].strip()
if not cleaned:
return {}
try:
return json.loads(cleaned)
except json.JSONDecodeError:
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
try:
return json.loads(cleaned[start : end + 1])
except json.JSONDecodeError:
return {}
return {}
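# Sketch of the fallback parsing above (sample strings assumed): trailing
# "[Saved to ...]" suffixes are stripped first, and non-JSON wrappers fall
# back to the outermost {...} span before giving up with {}.
assert _parse_validation_report('log noise {"valid": true, "steps": {}} end') == {
    "valid": True,
    "steps": {},
}
assert _parse_validation_report("no json here") == {}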
def _validation_failures(report: dict[str, Any] | None) -> list[str]:
"""Extract readable failure summaries from a validation report."""
if not isinstance(report, dict):
return []
steps = report.get("steps") or {}
failures: list[str] = []
for step_name, step in steps.items():
if not isinstance(step, dict) or step.get("passed", False):
continue
if step.get("errors"):
errors = step["errors"]
if isinstance(errors, list):
failures.extend(f"{step_name}: {err}" for err in errors)
continue
if step.get("missing_tools"):
missing = step["missing_tools"]
if isinstance(missing, list):
failures.extend(f"{step_name}: missing tool {tool}" for tool in missing)
continue
detail = step.get("error") or step.get("output") or "validation failed"
failures.append(f"{step_name}: {detail}")
if not failures and report.get("summary"):
failures.append(str(report["summary"]))
return failures
def _validation_blocks_stage_or_run(report: dict[str, Any] | None) -> bool:
"""Return True when a validation report contains any failed step."""
if not isinstance(report, dict):
return False
steps = report.get("steps")
if not isinstance(steps, dict):
return bool(report.get("valid") is False)
return any(isinstance(step, dict) and not step.get("passed", False) for step in steps.values())
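# A minimal failing report exercising both helpers above (step names are
# hypothetical): any step with passed=False blocks staging/run, and the
# failure summary names the step plus its missing tools or errors.
_demo_report = {
    "valid": False,
    "steps": {
        "import_check": {"passed": True},
        "tool_validation": {"passed": False, "missing_tools": ["execute_command_tool"]},
    },
}
assert _validation_blocks_stage_or_run(_demo_report)
assert _validation_failures(_demo_report) == [
    "tool_validation: missing tool execute_command_tool"
]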
def _run_validation_report_sync(agent_ref: str | Path) -> dict[str, Any]:
"""Run validate_agent_package in an isolated subprocess.
Accepts either a built-agent package name (for exports/) or a full
allowed agent path such as examples/templates/<agent>.
"""
if not agent_ref:
return {}
agent_ref_str = str(agent_ref)
script = textwrap.dedent(
"""
import importlib.util
import sys
server_path = sys.argv[1]
agent_ref = sys.argv[2]
spec = importlib.util.spec_from_file_location("coder_tools_server", server_path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
import json
print(json.dumps(module._validate_agent_package_impl(agent_ref), default=str))
"""
)
try:
proc = subprocess.run(
["uv", "run", "python", "-c", script, str(CODER_TOOLS_SERVER), agent_ref_str],
capture_output=True,
text=True,
timeout=120,
cwd=REPO_ROOT,
stdin=subprocess.DEVNULL,
)
except (OSError, subprocess.SubprocessError) as exc:
return {
"valid": False,
"summary": f"validate_agent_package failed for '{agent_ref_str}'",
"steps": {"validator_subprocess": {"passed": False, "error": str(exc)[:2000]}},
}
if proc.returncode != 0:
detail = proc.stderr.strip() or proc.stdout.strip() or "validation subprocess failed"
return {
"valid": False,
"summary": f"validate_agent_package failed for '{agent_ref_str}'",
"steps": {"validator_subprocess": {"passed": False, "error": detail[:2000]}},
}
return _parse_validation_report(proc.stdout)
class WorkerValidationError(ValueError):
"""Raised when a worker package fails validation before load/run."""
def __init__(self, agent_name: str, report: dict[str, Any]):
self.agent_name = agent_name
self.report = report
self.failures = _validation_failures(report)
super().__init__(
f"Worker '{agent_name}' failed validation: "
+ ("; ".join(self.failures) if self.failures else "validation failed")
)
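# Assumed usage: WorkerValidationError subclasses ValueError, so existing
# `except ValueError` paths keep working while HTTP handlers can branch on
# the subclass to attach structured `validation_failures` to 409 responses.
try:
    raise WorkerValidationError(
        "demo_agent",
        {"steps": {"behavior_validation": {"passed": False, "errors": ["placeholder prompt"]}}},
    )
except ValueError as exc:
    assert isinstance(exc, WorkerValidationError)
    assert exc.failures == ["behavior_validation: placeholder prompt"]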
@dataclass
@@ -41,6 +171,8 @@ class Session:
runner: Any | None = None # AgentRunner
worker_runtime: Any | None = None # AgentRuntime
worker_info: Any | None = None # AgentInfo
worker_validation_report: dict[str, Any] | None = None
worker_validation_failures: list[str] = field(default_factory=list)
# Queen phase state (building/staging/running)
phase_state: Any = None # QueenPhaseState
# Worker handoff subscription
@@ -83,6 +215,47 @@ class SessionManager:
self._credential_store = credential_store
self._lock = asyncio.Lock()
async def suspend_queen(self, session: Session) -> None:
"""Park the queen until the user sends a fresh message.
This is lighter than stopping the full session: it tears down the
queen executor and its subscriptions, but preserves the live session,
loaded worker, and persisted history. The next `/chat` call will
revive the queen via the normal code path.
"""
if session.worker_handoff_sub is not None:
try:
session.event_bus.unsubscribe(session.worker_handoff_sub)
except Exception:
pass
session.worker_handoff_sub = None
if session.memory_consolidation_sub is not None:
try:
session.event_bus.unsubscribe(session.memory_consolidation_sub)
except Exception:
pass
session.memory_consolidation_sub = None
executor = session.queen_executor
if executor is not None:
node = executor.node_registry.get("queen")
if node is not None and hasattr(node, "signal_shutdown"):
node.signal_shutdown()
if session.queen_task is not None:
task = session.queen_task
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except Exception:
logger.debug("Queen task exited with error during suspend", exc_info=True)
session.queen_task = None
session.queen_executor = None
# ------------------------------------------------------------------
# Session lifecycle
# ------------------------------------------------------------------
@@ -96,7 +269,8 @@ class SessionManager:
Internal helper: use create_session() or create_session_with_worker().
"""
from framework.config import RuntimeConfig, get_hive_config
from framework.config import RuntimeConfig
from framework.llm.litellm import LiteLLMProvider
from framework.runtime.event_bus import EventBus
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -110,20 +284,12 @@ class SessionManager:
rc = RuntimeConfig(model=model or self._model or RuntimeConfig().model)
# Session owns these — shared with queen and worker
llm_config = get_hive_config().get("llm", {})
if llm_config.get("use_antigravity_subscription"):
from framework.llm.antigravity import AntigravityProvider
llm = AntigravityProvider(model=rc.model)
else:
from framework.llm.litellm import LiteLLMProvider
llm = LiteLLMProvider(
model=rc.model,
api_key=rc.api_key,
api_base=rc.api_base,
**rc.extra_kwargs,
)
llm = LiteLLMProvider(
model=rc.model,
api_key=rc.api_key,
api_base=rc.api_base,
**rc.extra_kwargs,
)
event_bus = EventBus()
session = Session(
@@ -294,6 +460,11 @@ class SessionManager:
try:
# Blocking I/O — load in executor
loop = asyncio.get_running_loop()
validation_report = await loop.run_in_executor(
None, lambda: _run_validation_report_sync(str(agent_path))
)
if _validation_blocks_stage_or_run(validation_report):
raise WorkerValidationError(agent_path.name, validation_report)
# Prioritize: explicit model arg > worker-specific model > session default
from framework.config import (
@@ -320,25 +491,17 @@ class SessionManager:
# with the correct worker credentials so _setup() doesn't fall back
# to the queen's llm config (which may be a different provider).
if worker_model and not model:
from framework.config import get_hive_config
from framework.llm.litellm import LiteLLMProvider
worker_llm_cfg = get_hive_config().get("worker_llm", {})
if worker_llm_cfg.get("use_antigravity_subscription"):
from framework.llm.antigravity import AntigravityProvider
runner._llm = AntigravityProvider(model=resolved_model)
else:
from framework.llm.litellm import LiteLLMProvider
worker_api_key = get_worker_api_key()
worker_api_base = get_worker_api_base()
worker_extra = get_worker_llm_extra_kwargs()
runner._llm = LiteLLMProvider(
model=resolved_model,
api_key=worker_api_key,
api_base=worker_api_base,
**worker_extra,
)
worker_api_key = get_worker_api_key()
worker_api_base = get_worker_api_base()
worker_extra = get_worker_llm_extra_kwargs()
runner._llm = LiteLLMProvider(
model=resolved_model,
api_key=worker_api_key,
api_base=worker_api_base,
**worker_extra,
)
# Setup with session's event bus
if runner._agent_runtime is None:
@@ -383,6 +546,8 @@ class SessionManager:
session.runner = runner
session.worker_runtime = runtime
session.worker_info = info
session.worker_validation_report = validation_report
session.worker_validation_failures = _validation_failures(validation_report)
# Subscribe to execution completion for per-run digest generation
self._subscribe_worker_digest(session)
@@ -637,6 +802,8 @@ class SessionManager:
session.runner = None
session.worker_runtime = None
session.worker_info = None
session.worker_validation_report = None
session.worker_validation_failures = []
# Notify queen
await self._notify_queen_worker_unloaded(session)
@@ -820,12 +987,25 @@ class SessionManager:
return
await node.inject_event(f"[WORKER_DIGEST]\n{content}")
async def _consolidate_and_notify(run_id: str, outcome_event: Any) -> None:
"""Write the digest then push it to the queen."""
async def _consolidate_and_notify(
run_id: str,
outcome_event: Any,
*,
inject_to_queen: bool,
) -> None:
"""Write the digest and optionally push it into the queen.
Final worker completion/failure already emits a richer
[WORKER_TERMINAL] handoff with the real primary result. Injecting the
final digest as a second queen event causes Codex to replace that
result with a bland generic follow-up prompt. Keep writing digests to
disk for memory/history, but only inject mid-run snapshots.
"""
from framework.agents.worker_memory import consolidate_worker_run
await consolidate_worker_run(_agent_name, run_id, outcome_event, _bus, _llm)
await _inject_digest_to_queen(run_id)
if inject_to_queen:
await _inject_digest_to_queen(run_id)
async def _on_worker_event(event: Any) -> None:
if event.stream_id == "queen":
@@ -851,7 +1031,7 @@ class SessionManager:
run_id = getattr(event, "run_id", None) or _resolve_run_id(exec_id)
if run_id:
asyncio.create_task(
_consolidate_and_notify(run_id, event),
_consolidate_and_notify(run_id, event, inject_to_queen=False),
name=f"worker-digest-final-{run_id}",
)
@@ -872,7 +1052,7 @@ class SessionManager:
if run_id:
_last_digest[exec_id] = now
asyncio.create_task(
_consolidate_and_notify(run_id, None),
_consolidate_and_notify(run_id, None, inject_to_queen=True),
name=f"worker-digest-{run_id}",
)
@@ -1047,17 +1227,10 @@ class SessionManager:
_consolidation_session_dir = queen_dir
async def _on_compaction(_event) -> None:
# Only consolidate on queen compactions — worker and subagent
# compactions are frequent and don't warrant a memory update.
if getattr(_event, "stream_id", None) != "queen":
return
from framework.agents.queen.queen_memory import consolidate_queen_memory
asyncio.create_task(
consolidate_queen_memory(
session.id, _consolidation_session_dir, _consolidation_llm
),
name=f"queen-memory-consolidation-{session.id}",
await consolidate_queen_memory(
session.id, _consolidation_session_dir, _consolidation_llm
)
from framework.runtime.event_bus import EventType as _ET
+145 -4
View File
@@ -14,6 +14,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from aiohttp.test_utils import TestClient, TestServer
from framework.runtime.event_bus import AgentEvent, EventType
from framework.runtime.triggers import TriggerDefinition
from framework.server.app import create_app
from framework.server.session_manager import Session
@@ -190,6 +191,8 @@ def _make_session(
runner=runner,
worker_runtime=rt,
worker_info=MockAgentInfo(),
worker_validation_report={"valid": True, "steps": {}},
worker_validation_failures=[],
)
@@ -556,6 +559,7 @@ class TestExecution:
@pytest.mark.asyncio
async def test_trigger(self):
session = _make_session()
session.worker_runtime.trigger = AsyncMock(return_value="exec_test_123")
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
@@ -565,6 +569,11 @@ class TestExecution:
assert resp.status == 200
data = await resp.json()
assert data["execution_id"] == "exec_test_123"
session.worker_runtime.trigger.assert_awaited_once_with(
"default",
{"msg": "hi"},
session_state=None,
)
@pytest.mark.asyncio
async def test_trigger_not_found(self):
@@ -576,6 +585,25 @@ class TestExecution:
)
assert resp.status == 404
@pytest.mark.asyncio
async def test_trigger_blocks_invalid_loaded_worker(self):
session = _make_session()
session.worker_validation_report = {
"valid": False,
"steps": {"behavior_validation": {"passed": False}},
}
session.worker_validation_failures = ["behavior_validation: placeholder prompt"]
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
"/api/sessions/test_agent/trigger",
json={"entry_point_id": "default", "input_data": {"msg": "hi"}},
)
assert resp.status == 409
data = await resp.json()
assert "failed validation" in data["error"]
assert data["validation_failures"] == ["behavior_validation: placeholder prompt"]
@pytest.mark.asyncio
async def test_inject(self):
session = _make_session()
@@ -616,8 +644,8 @@ class TestExecution:
assert data["delivered"] is True
@pytest.mark.asyncio
async def test_chat_injects_when_node_waiting(self):
"""When a node is awaiting input, /chat should inject instead of trigger."""
async def test_chat_still_goes_to_queen_when_node_waiting(self):
"""The main chat channel stays wired to Queen even if a worker is waiting."""
session = _make_session()
session.worker_runtime.find_awaiting_node = lambda: ("chat_node", "primary")
app = _make_app_with_session(session)
@@ -628,6 +656,83 @@ class TestExecution:
)
assert resp.status == 200
data = await resp.json()
assert data["status"] == "queen"
assert data["delivered"] is True
@pytest.mark.asyncio
async def test_chat_done_for_now_parks_queen_without_new_followup(self):
"""Terminal stop choices should acknowledge once and park the queen."""
session = _make_session()
session.event_bus.get_history.return_value = [
AgentEvent(
type=EventType.CLIENT_INPUT_REQUESTED,
stream_id="queen",
node_id="queen",
execution_id=session.id,
data={"options": ["Run again with same input", "Done for now"]},
)
]
session.event_bus.emit_client_output_delta = AsyncMock()
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
"/api/sessions/test_agent/chat",
json={"message": "No, stop here"},
)
assert resp.status == 200
data = await resp.json()
assert data["status"] == "queen"
assert data["delivered"] is True
queen_node = session.queen_executor
assert queen_node is None
session.event_bus.emit_client_output_delta.assert_awaited_once()
@pytest.mark.asyncio
async def test_chat_non_terminal_choice_still_goes_to_queen(self):
"""Non-terminal follow-up choices should still be injected into the queen."""
session = _make_session()
session.event_bus.get_history.return_value = [
AgentEvent(
type=EventType.CLIENT_INPUT_REQUESTED,
stream_id="queen",
node_id="queen",
execution_id=session.id,
data={"options": ["Run again with same input", "Done for now"]},
)
]
session.event_bus.emit_client_output_delta = AsyncMock()
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
"/api/sessions/test_agent/chat",
json={"message": "Run again with same input"},
)
assert resp.status == 200
data = await resp.json()
assert data["status"] == "queen"
assert data["delivered"] is True
queen_node = session.queen_executor.node_registry["queen"]
queen_node.inject_event.assert_awaited_once_with(
"Run again with same input",
is_client_input=True,
)
session.event_bus.emit_client_output_delta.assert_not_called()
@pytest.mark.asyncio
async def test_worker_input_injects_when_node_waiting(self):
session = _make_session()
session.worker_runtime.find_awaiting_node = lambda: ("chat_node", "primary")
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
"/api/sessions/test_agent/worker-input",
json={"message": "user reply"},
)
assert resp.status == 200
data = await resp.json()
assert data["status"] == "injected"
assert data["node_id"] == "chat_node"
assert data["delivered"] is True
@@ -715,8 +820,6 @@ class TestResume:
assert resp.status == 200
data = await resp.json()
assert data["execution_id"] == "exec_test_123"
assert data["resumed_from"] == session_id
assert data["checkpoint_id"] is None
@pytest.mark.asyncio
async def test_resume_with_checkpoint(self, sample_session, tmp_agent_dir):
@@ -761,6 +864,31 @@ class TestResume:
)
assert resp.status == 404
@pytest.mark.asyncio
async def test_resume_blocks_invalid_loaded_worker(self, sample_session, tmp_agent_dir):
session_id, session_dir, state = sample_session
tmp_path, agent_name, base = tmp_agent_dir
session = _make_session(tmp_dir=tmp_path / ".hive" / "agents" / agent_name)
session.worker_validation_report = {
"valid": False,
"steps": {"tool_validation": {"passed": False}},
}
session.worker_validation_failures = ["tool_validation: missing tool execute_command_tool"]
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
"/api/sessions/test_agent/resume",
json={"session_id": session_id},
)
assert resp.status == 409
data = await resp.json()
assert "failed validation" in data["error"]
assert data["validation_failures"] == [
"tool_validation: missing tool execute_command_tool"
]
class TestStop:
@pytest.mark.asyncio
@@ -800,6 +928,19 @@ class TestStop:
)
assert resp.status == 400
@pytest.mark.asyncio
async def test_stop_ignores_worker_validation_failure(self):
session = _make_session()
session.worker_validation_failures = ["behavior_validation: broken"]
session.worker_runtime._mock_streams["default"]._execution_tasks["exec_abc"] = MagicMock()
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
"/api/sessions/test_agent/stop",
json={"execution_id": "exec_abc"},
)
assert resp.status == 200
class TestReplay:
@pytest.mark.asyncio
+717 -122
View File
@@ -36,6 +36,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
@@ -53,6 +54,7 @@ from framework.tools.flowchart_utils import (
save_flowchart_file,
synthesize_draft_from_runtime,
)
from framework.tools.worker_monitoring_tools import read_worker_health_snapshot
if TYPE_CHECKING:
from framework.runner.tool_registry import ToolRegistry
@@ -61,6 +63,16 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_NON_ACCEPT_JUDGE_ACTIONS = frozenset({"RETRY", "CONTINUE", "ESCALATE"})
_HEALTH_SIGNAL_DESCRIPTIONS: dict[str, str] = {
"failed_session": "worker session is marked failed",
"stalled": "worker appears stalled with no meaningful progress for 5+ minutes",
"slow_progress": "worker progress has slowed for 2+ minutes without completing",
"long_non_accept_streak": "worker has a sustained non-ACCEPT judge streak",
"judge_pressure": "worker is under repeated non-ACCEPT judge pressure",
"recent_non_accept_churn": "recent judge verdicts are all non-ACCEPT, indicating churn",
}
@dataclass
class WorkerSessionAdapter:
@@ -118,8 +130,6 @@ class QueenPhaseState:
# Default skill operational protocols — appended to every phase prompt
protocols_prompt: str = ""
# Community skills catalog (XML) — appended after protocols
skills_catalog_prompt: str = ""
def get_current_tools(self) -> list:
"""Return tools for the current phase."""
@@ -146,8 +156,6 @@ class QueenPhaseState:
memory = format_for_injection()
parts = [base]
if self.skills_catalog_prompt:
parts.append(self.skills_catalog_prompt)
if self.protocols_prompt:
parts.append(self.protocols_prompt)
if memory:
@@ -750,6 +758,438 @@ def _update_meta_json(session_manager, manager_session_id, updates: dict) -> Non
pass
def _parse_validation_report(raw: Any) -> dict | None:
"""Best-effort parse of validate_agent_package output."""
if isinstance(raw, dict):
return raw
if hasattr(raw, "content"):
raw = raw.content
if raw is None:
return None
text = str(raw).strip()
if not text:
return None
candidates = [text]
if "\n\n[Saved to" in text:
candidates.append(text.split("\n\n[Saved to", 1)[0].strip())
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
candidates.append(text[start : end + 1])
try:
for candidate in candidates:
try:
return json.loads(candidate)
except (TypeError, json.JSONDecodeError):
continue
except TypeError:
pass
return None
def _validation_failures(report: dict | None) -> list[str]:
"""Flatten failed validation steps into readable messages."""
if not report:
return []
failures: list[str] = []
for step_name, step in (report.get("steps") or {}).items():
if step.get("passed"):
continue
detail = step.get("output") or step.get("error") or "failed"
failures.append(f"{step_name}: {detail}")
return failures
def _validation_blocks_stage_or_run(report: dict | None) -> bool:
"""Return True when validation results should block staging or execution."""
if not report:
return False
return any(
isinstance(step, dict) and not step.get("passed", False)
for step in (report.get("steps") or {}).values()
)
def _invalid_validation_report(reason: str) -> dict:
"""Build a structured validation failure when validator output is unusable."""
return {
"valid": False,
"summary": reason,
"steps": {
"validator_subprocess": {
"passed": False,
"error": reason,
}
},
}
_STRUCTURED_TASK_PAIR_RE = re.compile(
r"\[?(?P<key>[A-Za-z_][A-Za-z0-9_]*)\]?\s*(?::|=)\s*(?P<value>.*?)(?=(?:\s+\[?[A-Za-z_][A-Za-z0-9_]*\]?\s*(?::|=))|$)"
)
_STRUCTURED_TASK_LINE_RE = re.compile(
r"^\s*(?:[-*]\s*)?\[?(?P<key>[A-Za-z_][A-Za-z0-9_]*)\]?\s*(?::|=)\s*(?P<value>.*?)\s*$"
)
_NUMERIC_WITH_SUFFIX_RE = re.compile(r"^(?P<number>-?\d+(?:\.\d+)?)\s*\([^)]*\)\s*$")
_LEADING_NUMERIC_RE = re.compile(r"^\s*(?P<number>-?\d+(?:\.\d+)?)\b")
_RERUN_WITH_DEFAULTS_RE = re.compile(
r"\b(?:run\s+again|rerun|continue)\b.*\b(?:same\s+(?:default|defaults|settings|inputs)|"
r"defaults?)\b",
re.IGNORECASE,
)
_PATH_INPUT_KEY_HINTS = (
"_dir",
"_path",
"_folder",
"_root",
)
_NUMERIC_INPUT_KEY_HINTS = (
"_threshold",
"_count",
"_limit",
"_max",
"_min",
"_ratio",
"_size",
)
def _coerce_task_value(raw: str) -> Any:
"""Best-effort coerce simple structured task values from text."""
text = raw.strip().rstrip(",")
if not text:
return ""
numeric_match = _NUMERIC_WITH_SUFFIX_RE.match(text)
if numeric_match:
text = numeric_match.group("number")
try:
return json.loads(text)
except (TypeError, json.JSONDecodeError):
return text
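# A few assumed coercions showing the intent of _coerce_task_value:
assert _coerce_task_value("0.8 (recommended)") == 0.8  # numeric suffix stripped
assert _coerce_task_value("42,") == 42                 # trailing comma dropped
assert _coerce_task_value("true") is True              # JSON literal
assert _coerce_task_value("~/reports") == "~/reports"  # plain text passes through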
def _parse_structured_task_payload(task: str) -> dict[str, Any]:
"""Extract ``key: value`` pairs or JSON objects from a task string."""
text = (task or "").strip()
if not text:
return {}
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except (TypeError, json.JSONDecodeError):
pass
payload: dict[str, Any] = {}
current_key: str | None = None
for raw_line in text.splitlines():
if not raw_line.strip():
current_key = None
continue
inline_matches = list(_STRUCTURED_TASK_PAIR_RE.finditer(raw_line.strip()))
if len(inline_matches) > 1:
for match in inline_matches:
payload[match.group("key")] = _coerce_task_value(match.group("value"))
current_key = inline_matches[-1].group("key")
continue
line_match = _STRUCTURED_TASK_LINE_RE.match(raw_line)
if line_match:
key = line_match.group("key")
value_text = line_match.group("value")
payload[key] = _coerce_task_value(value_text)
current_key = key
continue
if current_key and (raw_line.startswith(" ") or raw_line.startswith("\t")):
existing = payload.get(current_key, "")
continuation = raw_line.strip()
if isinstance(existing, str):
payload[current_key] = f"{existing}\n{continuation}".strip()
continue
current_key = None
if payload:
return payload
matches = list(_STRUCTURED_TASK_PAIR_RE.finditer(text))
for match in matches:
key = match.group("key")
value_text = match.group("value")
if not value_text.strip():
continue
payload[key] = _coerce_task_value(value_text)
return payload
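# Example parse (task text assumed): "key: value" lines become structured
# fields while free-form trailing prose is ignored.
assert _parse_structured_task_payload(
    "input_dir: ~/reports\nmax_count: 5\nPlease rerun with those settings."
) == {"input_dir": "~/reports", "max_count": 5}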
def _looks_like_path_input_key(key: str) -> bool:
lowered = key.lower()
return lowered.endswith(_PATH_INPUT_KEY_HINTS) or lowered in {
"path",
"dir",
"folder",
"root",
}
def _looks_like_numeric_input_key(key: str) -> bool:
lowered = key.lower()
return lowered.endswith(_NUMERIC_INPUT_KEY_HINTS) or lowered in {
"threshold",
"count",
"limit",
"max",
"min",
"ratio",
"size",
}
def _normalize_worker_input_value(key: str, value: Any) -> Any:
"""Normalize structured worker inputs before handing them to the runtime."""
if not isinstance(value, str):
return value
text = value.strip()
if not text:
return ""
if _looks_like_numeric_input_key(key):
numeric_match = _LEADING_NUMERIC_RE.match(text)
if numeric_match:
number = numeric_match.group("number")
return float(number) if "." in number else int(number)
if _looks_like_path_input_key(key):
candidate = Path(text).expanduser()
if not candidate.is_absolute():
candidate = (Path.cwd() / candidate).resolve()
return str(candidate)
return text
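# Assumed normalization examples: the key-name hints above choose the
# coercion, so numeric-looking keys parse leading numbers and path-looking
# keys are expanded and anchored to the current working directory.
assert _normalize_worker_input_value("retry_count", "3 retries") == 3
assert Path(_normalize_worker_input_value("input_dir", "~/reports")).is_absolute()
assert _normalize_worker_input_value("feedback", "  looks good ") == "looks good"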
def _should_backfill_from_recent_input(key: str, value: Any) -> bool:
"""Return True when a recent session input should replace the current value."""
if value is None:
return True
if isinstance(value, str):
text = value.strip()
if not text:
return True
if _looks_like_numeric_input_key(key):
return _LEADING_NUMERIC_RE.fullmatch(text) is None
return False
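# Backfill applies when a value is missing, blank, or a numeric key holds
# non-numeric text (examples assumed):
assert _should_backfill_from_recent_input("max_count", "")
assert _should_backfill_from_recent_input("max_count", "a few")
assert not _should_backfill_from_recent_input("max_count", "5")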
def _load_recent_worker_input_defaults(
runtime: Any,
input_keys: list[str],
session_id: str | None = None,
) -> dict[str, Any]:
"""Load the best recent worker input payload from unified session state.
When a current session is known, only that session is considered so reruns
cannot inherit structured defaults from an unrelated historical session.
Otherwise we fall back to the latest available session as a best-effort
compatibility path for legacy callers.
"""
store = getattr(runtime, "_session_store", None)
sessions_dir = Path(getattr(store, "sessions_dir", "")) if store is not None else None
if sessions_dir is None or not sessions_dir.exists():
return {}
allowed_key_set = set(input_keys)
candidate_state_paths: list[Path]
if session_id:
state_path = sessions_dir / session_id / "state.json"
candidate_state_paths = [state_path] if state_path.exists() else []
else:
candidate_state_paths = []
if not candidate_state_paths:
candidate_state_paths = sorted(
sessions_dir.glob("session_*/state.json"),
key=lambda path: path.stat().st_mtime if path.exists() else 0,
reverse=True,
)
work_keys = [key for key in input_keys if key not in {"user_request", "task", "feedback"}]
if not work_keys:
best_payload: dict[str, Any] = {}
best_updated_at = ""
for state_path in candidate_state_paths:
try:
raw_state = json.loads(state_path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
continue
input_data = raw_state.get("input_data") or {}
if not isinstance(input_data, dict) or not input_data:
continue
input_data = {key: value for key, value in input_data.items() if key in allowed_key_set}
if not input_data:
continue
updated_at = str((raw_state.get("timestamps") or {}).get("updated_at") or "")
if updated_at >= best_updated_at:
best_payload = dict(input_data)
best_updated_at = updated_at
return best_payload
best_payload: dict[str, Any] = {}
best_score = -1
best_updated_at = ""
for state_path in candidate_state_paths:
try:
raw_state = json.loads(state_path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
continue
input_data = raw_state.get("input_data") or {}
if not isinstance(input_data, dict):
continue
merged_input = {key: value for key, value in input_data.items() if key in allowed_key_set}
result_output = (raw_state.get("result") or {}).get("output") or {}
if isinstance(result_output, dict):
for key in work_keys:
if merged_input.get(key) in (None, "") and result_output.get(key) not in (None, ""):
merged_input[key] = result_output[key]
score = sum(1 for key in work_keys if merged_input.get(key) not in (None, ""))
if score <= 0:
continue
updated_at = str((raw_state.get("timestamps") or {}).get("updated_at") or "")
if score > best_score or (score == best_score and updated_at > best_updated_at):
best_payload = merged_input
best_score = score
best_updated_at = updated_at
return best_payload
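# Sketch of the persisted layout this reader scans (shape assumed):
#   <sessions_dir>/session_<id>/state.json
#   {"input_data": {"input_dir": "/data/reports", "max_count": 5},
#    "result": {"output": {"max_count": 7}},
#    "timestamps": {"updated_at": "2026-03-27T21:00:00Z"}}
# Gaps in work keys are backfilled from result.output, and the candidate
# with the most populated work keys wins, ties broken by updated_at.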
async def _preflight_worker_run(session: Any, runtime: Any, timeout_seconds: int) -> None:
"""Validate credentials and refresh MCP servers before a worker run."""
loop = asyncio.get_running_loop()
async def _preflight():
cred_error: CredentialError | None = None
try:
await loop.run_in_executor(
None,
lambda: validate_credentials(
runtime.graph.nodes,
interactive=False,
skip=False,
),
)
except CredentialError as e:
cred_error = e
runner = getattr(session, "runner", None)
if runner:
try:
await loop.run_in_executor(
None,
lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
)
except Exception as e:
logger.warning("MCP resync failed: %s", e)
if cred_error is not None:
raise cred_error
try:
await asyncio.wait_for(_preflight(), timeout=timeout_seconds)
except TimeoutError:
logger.warning(
"worker run preflight timed out after %ds — proceeding",
timeout_seconds,
)
def _get_default_entry_input_keys(runtime: Any) -> list[str]:
"""Return the loaded worker's default entry node input keys, if available."""
try:
entry_points = runtime.get_entry_points()
except Exception:
return []
if not entry_points:
return []
graph = getattr(runtime, "graph", None)
if graph is None or not hasattr(graph, "get_node"):
return []
entry_spec = entry_points[0]
entry_node_id = getattr(entry_spec, "entry_node", None) or getattr(graph, "entry_node", None)
if not entry_node_id:
return []
node = graph.get_node(entry_node_id)
return list(getattr(node, "input_keys", []) or []) if node is not None else []
def _build_worker_input_data(
runtime: Any,
task: str,
session_id: str | None = None,
) -> dict[str, Any]:
"""Shape queen task text into the loaded worker's expected entry inputs."""
structured = _parse_structured_task_payload(task)
allowed_keys = _get_default_entry_input_keys(runtime)
# Backwards compatibility for older workers that still expect a single
# free-form task string, while allowing newer workers to receive
# structured fields directly.
if not allowed_keys:
payload = {"user_request": task, "task": task}
payload.update(structured)
return payload
shaped: dict[str, Any] = {}
if "user_request" in allowed_keys:
shaped["user_request"] = task
if "task" in allowed_keys:
shaped["task"] = task
work_keys = [key for key in allowed_keys if key not in {"user_request", "task", "feedback"}]
structured_work_keys = {key for key in work_keys if key in structured}
should_merge_recent_defaults = bool(structured_work_keys) or bool(
_RERUN_WITH_DEFAULTS_RE.search(task or "")
)
recent_defaults = (
_load_recent_worker_input_defaults(runtime, allowed_keys, session_id=session_id)
if should_merge_recent_defaults
else {}
)
for key in allowed_keys:
if key in {"user_request", "task", "feedback"}:
continue
if key in structured:
shaped[key] = _normalize_worker_input_value(key, structured[key])
if (
key in recent_defaults
and _should_backfill_from_recent_input(key, shaped[key])
and recent_defaults.get(key) not in (None, "")
):
shaped[key] = _normalize_worker_input_value(key, recent_defaults[key])
elif recent_defaults.get(key) not in (None, ""):
shaped[key] = _normalize_worker_input_value(key, recent_defaults[key])
work_keys = [key for key in allowed_keys if key != "feedback"]
if not any(key in shaped for key in work_keys):
if len(work_keys) == 1:
shaped[work_keys[0]] = task
elif "user_request" in allowed_keys and "user_request" not in shaped:
shaped["user_request"] = task
elif "task" in allowed_keys and "task" not in shaped:
shaped["task"] = task
return shaped
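# End-to-end sketch for a hypothetical worker whose default entry declares
# input_keys ["user_request", "input_dir", "max_count"]:
#   _build_worker_input_data(runtime, "input_dir: ~/reports\nmax_count: 5")
# would yield {"user_request": "<full task text>", "input_dir": "<abs path>",
# "max_count": 5}, with recent-session defaults merged in only when the task
# supplies structured work keys or reads like a same-defaults rerun.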
def register_queen_lifecycle_tools(
registry: ToolRegistry,
session: Any = None,
@@ -803,6 +1243,23 @@ def register_queen_lifecycle_tools(
"""Get current worker runtime from session (late-binding)."""
return getattr(session, "worker_runtime", None)
async def _run_package_validation(agent_ref: str) -> dict | None:
"""Run validate_agent_package if available in the registry."""
validator = registry._tools.get("validate_agent_package")
if validator is None or not agent_ref:
return None
# The validator accepts either a built-agent package name or a
# fully resolved allowed agent path.
result = validator.executor({"agent_name": agent_ref})
if asyncio.iscoroutine(result) or asyncio.isfuture(result):
result = await result
parsed = _parse_validation_report(result)
if parsed is None:
return _invalid_validation_report(
"validate_agent_package returned an invalid or undecodable report"
)
return parsed
# --- start_worker ---------------------------------------------------------
# How long to wait for credential validation + MCP resync before
@@ -821,66 +1278,17 @@ def register_queen_lifecycle_tools(
return json.dumps({"error": "No worker loaded in this session."})
try:
# Pre-flight: validate credentials and resync MCP servers.
# Both are blocking I/O (HTTP health-checks, subprocess spawns)
# so they run in a thread-pool executor. We cap the total
# preflight time so the queen never hangs waiting.
loop = asyncio.get_running_loop()
async def _preflight():
cred_error: CredentialError | None = None
try:
await loop.run_in_executor(
None,
lambda: validate_credentials(
runtime.graph.nodes,
interactive=False,
skip=False,
),
)
except CredentialError as e:
cred_error = e
runner = getattr(session, "runner", None)
if runner:
try:
await loop.run_in_executor(
None,
lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
)
except Exception as e:
logger.warning("MCP resync failed: %s", e)
# Re-raise CredentialError after MCP resync so both steps
# get a chance to run before we bail.
if cred_error is not None:
raise cred_error
try:
await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT)
except TimeoutError:
logger.warning(
"start_worker preflight timed out after %ds — proceeding with trigger",
_START_PREFLIGHT_TIMEOUT,
)
except CredentialError:
raise # handled below
await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT)
# Resume timers in case they were paused by a previous stop_worker
runtime.resume_timers()
# Get session state from any prior execution for memory continuity
session_state = runtime._get_primary_session_state("default") or {}
# Use the shared session ID so queen, judge, and worker all
# scope their conversations to the same session.
if session_id:
session_state["resume_session_id"] = session_id
exec_id = await runtime.trigger(
entry_point_id="default",
input_data={"user_request": task},
session_state=session_state,
input_data=_build_worker_input_data(runtime, task, session_id=session_id),
# Worker runs should start from the explicit input payload for
# this run, not inherit another execution's shared session.
session_state=None,
)
return json.dumps(
{
@@ -2547,19 +2955,76 @@ def register_queen_lifecycle_tools(
return preamble
def _detect_red_flags(bus: EventBus) -> int:
def _get_worker_health_snapshot() -> dict[str, Any] | None:
worker_path = getattr(session, "worker_path", None)
if not worker_path:
return None
try:
snapshot = read_worker_health_snapshot(
Path(worker_path),
session_id=session_id,
default_session_id=session_id,
)
except Exception:
logger.exception("Failed to read worker health snapshot for queen status")
return None
if snapshot.get("error"):
return None
return snapshot
def _detect_red_flags(bus: EventBus, health_snapshot: dict[str, Any] | None = None) -> int:
"""Count issue categories with cheap limit=1 queries."""
if health_snapshot:
issue_signals = health_snapshot.get("issue_signals", [])
if isinstance(issue_signals, list) and issue_signals:
return len(issue_signals)
count = 0
for evt_type in (
EventType.NODE_RETRY,
EventType.NODE_STALLED,
EventType.NODE_TOOL_DOOM_LOOP,
EventType.CONSTRAINT_VIOLATION,
):
if bus.get_history(event_type=evt_type, limit=1):
count += 1
if _get_recent_judge_pressure(bus)[0]:
count += 1
return count
def _format_summary(preamble: dict[str, Any], red_flags: int) -> str:
def _get_recent_judge_pressure(bus: EventBus, streak_threshold: int = 4) -> tuple[bool, str]:
"""Detect sustained judge churn even when no hard stall event exists yet."""
verdict_events = bus.get_history(event_type=EventType.JUDGE_VERDICT, limit=8)
if len(verdict_events) < streak_threshold:
return False, ""
streak: list[str] = []
for evt in verdict_events:
action = str(evt.data.get("action", "")).upper()
if action == "ACCEPT":
break
if action in _NON_ACCEPT_JUDGE_ACTIONS:
streak.append(action)
continue
break
if len(streak) < streak_threshold:
return False, ""
compressed: list[str] = []
for action in streak:
if not compressed or compressed[-1] != action:
compressed.append(action)
return (
True,
f"{len(streak)} consecutive non-ACCEPT judge verdict(s): {' -> '.join(compressed)}",
)
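# Assumed example with newest-first history: verdicts
# ["RETRY", "RETRY", "CONTINUE", "RETRY"] (and nothing older) yield
#   (True, "4 consecutive non-ACCEPT judge verdict(s): RETRY -> CONTINUE -> RETRY")
# because adjacent duplicates are compressed for display while the count stays exact.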
def _format_summary(
preamble: dict[str, Any],
red_flags: int,
health_snapshot: dict[str, Any] | None = None,
) -> str:
"""Generate a 1-2 sentence prose summary from the preamble."""
status = preamble["status"]
@@ -2586,10 +3051,17 @@ def register_queen_lifecycle_tools(
node_part += f", iteration {iteration}"
parts.append(node_part)
health_signals = health_snapshot.get("issue_signals", []) if health_snapshot else []
if red_flags:
parts.append(f"{red_flags} issue type(s) detected — use focus='issues' for details")
if isinstance(health_signals, list) and health_signals:
parts.append(
f"{red_flags} issue signal(s) detected "
f"({', '.join(health_signals)}) — use focus='issues' for details"
)
else:
parts.append(f"{red_flags} issue type(s) detected — use focus='issues' for details")
else:
parts.append("No issues detected")
parts.append("No issue signals detected")
# Latest subagent progress (if any delegation is in flight)
bus = _get_event_bus()
@@ -2737,7 +3209,7 @@ def register_queen_lifecycle_tools(
return "\n".join(lines)
def _format_issues(bus: EventBus) -> str:
def _format_issues(bus: EventBus, health_snapshot: dict[str, Any] | None = None) -> str:
"""Format retries, stalls, doom loops, and constraint violations."""
lines = []
total = 0
@@ -2787,8 +3259,42 @@ def register_queen_lifecycle_tools(
ago = _format_time_ago(evt.timestamp)
lines.append(f" {cid} ({ago}): {desc}")
has_judge_pressure, judge_pressure_desc = _get_recent_judge_pressure(bus)
if has_judge_pressure:
total += 1
lines.append("Judge pressure detected:")
lines.append(f" {judge_pressure_desc}")
if health_snapshot:
issue_signals = health_snapshot.get("issue_signals", [])
if isinstance(issue_signals, list) and issue_signals:
total += len(issue_signals)
lines.append("Health signals:")
for signal in issue_signals:
desc = _HEALTH_SIGNAL_DESCRIPTIONS.get(signal, signal.replace("_", " "))
if (
signal in {"stalled", "slow_progress"}
and health_snapshot.get("stall_minutes") is not None
):
desc += f" ({health_snapshot['stall_minutes']} min since last step)"
elif (
signal in {"long_non_accept_streak", "judge_pressure"}
and health_snapshot.get("steps_since_last_accept") is not None
):
desc += (
" ("
f"{health_snapshot['steps_since_last_accept']} non-ACCEPT step(s)"
" since last ACCEPT)"
)
elif signal == "recent_non_accept_churn" and health_snapshot.get(
"recent_verdicts"
):
verdicts = ", ".join(health_snapshot["recent_verdicts"][-4:])
desc += f" ({verdicts})"
lines.append(f" {signal}: {desc}")
if total == 0:
return "No issues detected. No retries, stalls, or constraint violations."
return "No issues detected. No runtime issue signals were found."
header = f"{total} issue(s) detected."
return header + "\n\n" + "\n".join(lines)
@@ -3086,8 +3592,9 @@ def register_queen_lifecycle_tools(
try:
if focus is None:
# Default: brief prose summary
red_flags = _detect_red_flags(bus) if bus else 0
return _format_summary(preamble, red_flags)
health_snapshot = _get_worker_health_snapshot()
red_flags = _detect_red_flags(bus, health_snapshot) if bus else 0
return _format_summary(preamble, red_flags, health_snapshot)
if bus is None:
return (
@@ -3102,7 +3609,7 @@ def register_queen_lifecycle_tools(
elif focus == "tools":
return _format_tools(bus, last_n)
elif focus == "issues":
return _format_issues(bus)
return _format_issues(bus, _get_worker_health_snapshot())
elif focus == "progress":
return await _format_progress(runtime, bus)
elif focus == "full":
@@ -3399,14 +3906,6 @@ def register_queen_lifecycle_tools(
available immediately. The user will see the agent's graph and
can interact with it without opening a new tab.
"""
runtime = _get_runtime()
if runtime is not None:
try:
await session_manager.unload_worker(manager_session_id)
except Exception as e:
logger.error("Failed to unload existing worker: %s", e, exc_info=True)
return json.dumps({"error": f"Failed to unload existing worker: {e}"})
try:
resolved_path = validate_agent_path(agent_path)
except ValueError as e:
@@ -3414,6 +3913,19 @@ def register_queen_lifecycle_tools(
if not resolved_path.exists():
return json.dumps({"error": f"Agent path does not exist: {agent_path}"})
validation_report = await _run_package_validation(str(resolved_path))
if _validation_blocks_stage_or_run(validation_report):
failures = _validation_failures(validation_report)
return json.dumps(
{
"error": (
f"Cannot load agent '{resolved_path.name}' because validation failed. "
"Fix the package and re-run validate_agent_package() before loading."
),
"validation_failures": failures,
}
)
# Pre-check: verify the module exports goal/nodes/edges before
# attempting the full load. This gives the queen an actionable
# error message instead of a cryptic ImportError or TypeError.
@@ -3459,6 +3971,14 @@ def register_queen_lifecycle_tools(
}
)
runtime = _get_runtime()
if runtime is not None:
try:
await session_manager.unload_worker(manager_session_id)
except Exception as e:
logger.error("Failed to unload existing worker: %s", e, exc_info=True)
return json.dumps({"error": f"Failed to unload existing worker: {e}"})
try:
updated_session = await session_manager.load_worker(
manager_session_id,
@@ -3607,60 +4127,35 @@ def register_queen_lifecycle_tools(
if runtime is None:
return json.dumps({"error": "No worker loaded in this session."})
worker_path = getattr(session, "worker_path", None)
worker_name = Path(worker_path).name if worker_path else ""
validation_report = await _run_package_validation(
str(worker_path) if worker_path else worker_name
)
if _validation_blocks_stage_or_run(validation_report):
failures = _validation_failures(validation_report)
return json.dumps(
{
"error": (
f"Cannot run agent '{worker_name or 'current worker'}' because validation "
"is failing. Fix the package and reload it before running."
),
"validation_failures": failures,
}
)
try:
# Pre-flight: validate credentials and resync MCP servers.
loop = asyncio.get_running_loop()
async def _preflight():
cred_error: CredentialError | None = None
try:
await loop.run_in_executor(
None,
lambda: validate_credentials(
runtime.graph.nodes,
interactive=False,
skip=False,
),
)
except CredentialError as e:
cred_error = e
runner = getattr(session, "runner", None)
if runner:
try:
await loop.run_in_executor(
None,
lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
)
except Exception as e:
logger.warning("MCP resync failed: %s", e)
if cred_error is not None:
raise cred_error
try:
await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT)
except TimeoutError:
logger.warning(
"run_agent_with_input preflight timed out after %ds — proceeding",
_START_PREFLIGHT_TIMEOUT,
)
except CredentialError:
raise # handled below
await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT)
# Resume timers in case they were paused by a previous stop
runtime.resume_timers()
# Get session state from any prior execution for memory continuity
session_state = runtime._get_primary_session_state("default") or {}
if session_id:
session_state["resume_session_id"] = session_id
exec_id = await runtime.trigger(
entry_point_id="default",
input_data={"user_request": task},
session_state=session_state,
input_data=_build_worker_input_data(runtime, task, session_id=session_id),
# Fresh manual worker runs avoid stale state leaking from a
# previous execution into Codex's next tool/planning turn.
session_state=None,
)
# Switch to running phase
@@ -3693,6 +4188,91 @@ def register_queen_lifecycle_tools(
except Exception as e:
return json.dumps({"error": f"Failed to start worker: {e}"})
async def rerun_worker_with_last_input() -> str:
"""Rerun the loaded worker using the last complete structured input payload."""
runtime = _get_runtime()
if runtime is None:
return json.dumps({"error": "No worker loaded in this session."})
worker_path = getattr(session, "worker_path", None)
worker_name = Path(worker_path).name if worker_path else ""
validation_report = await _run_package_validation(
str(worker_path) if worker_path else worker_name
)
if _validation_blocks_stage_or_run(validation_report):
failures = _validation_failures(validation_report)
return json.dumps(
{
"error": (
f"Cannot rerun agent '{worker_name or 'current worker'}' "
"because validation "
"is failing. Fix the package and reload it before running."
),
"validation_failures": failures,
}
)
allowed_keys = _get_default_entry_input_keys(runtime)
input_data = {
key: _normalize_worker_input_value(key, value)
for key, value in _load_recent_worker_input_defaults(
runtime,
allowed_keys,
session_id=session_id,
).items()
}
work_keys = [key for key in allowed_keys if key not in {"user_request", "task", "feedback"}]
if work_keys:
missing = [key for key in work_keys if input_data.get(key) in (None, "")]
if missing:
return json.dumps(
{
"error": "No complete previous worker input is available for a "
"same-defaults rerun.",
"missing_inputs": missing,
}
)
try:
await _preflight_worker_run(session, runtime, _START_PREFLIGHT_TIMEOUT)
runtime.resume_timers()
exec_id = await runtime.trigger(
entry_point_id="default",
input_data=input_data,
session_state=None,
)
if phase_state is not None:
await phase_state.switch_to_running()
_update_meta_json(session_manager, manager_session_id, {"phase": "running"})
return json.dumps(
{
"status": "started",
"phase": "running",
"execution_id": exec_id,
"input_data": input_data,
}
)
except CredentialError as e:
error_payload = credential_errors_to_json(e)
error_payload["agent_path"] = str(getattr(session, "worker_path", "") or "")
bus = getattr(session, "event_bus", None)
if bus is not None:
await bus.publish(
AgentEvent(
type=EventType.CREDENTIALS_REQUIRED,
stream_id="queen",
data=error_payload,
)
)
return json.dumps(error_payload)
except Exception as e:
return json.dumps({"error": f"Failed to rerun worker: {e}"})
_run_input_tool = Tool(
name="run_agent_with_input",
description=(
@@ -3716,6 +4296,21 @@ def register_queen_lifecycle_tools(
)
tools_registered += 1
_rerun_tool = Tool(
name="rerun_worker_with_last_input",
description=(
"Rerun the loaded worker using the most recent complete structured input payload. "
"Use this when the user asks to run again with the same defaults or same input."
),
parameters={"type": "object", "properties": {}},
)
registry.register(
"rerun_worker_with_last_input",
_rerun_tool,
lambda _inputs: rerun_worker_with_last_input(),
)
tools_registered += 1
# --- set_trigger -----------------------------------------------------------
async def set_trigger(
+177 -107
View File
@@ -36,6 +36,172 @@ logger = logging.getLogger(__name__)
# How many tool_log steps to include in the health summary
_DEFAULT_LAST_N_STEPS = 40
_NON_ACCEPT_VERDICTS = frozenset({"RETRY", "CONTINUE", "ESCALATE"})
def classify_worker_health(
*,
session_status: str,
recent_verdicts: list[str],
steps_since_last_accept: int,
stall_minutes: float | None,
) -> tuple[str, list[str]]:
"""Classify worker health from persisted run evidence.
Keeping this logic at module scope lets Queen-facing status views reuse the
exact same health signals as the monitoring tool instead of drifting into a
separate, weaker interpretation of worker state.
"""
issue_signals: list[str] = []
if session_status == "failed":
issue_signals.append("failed_session")
if stall_minutes is not None:
if stall_minutes >= 5:
issue_signals.append("stalled")
elif stall_minutes >= 2:
issue_signals.append("slow_progress")
if steps_since_last_accept >= 6:
issue_signals.append("long_non_accept_streak")
elif steps_since_last_accept >= 4:
issue_signals.append("judge_pressure")
if len(recent_verdicts) >= 4 and all(v in _NON_ACCEPT_VERDICTS for v in recent_verdicts[-4:]):
issue_signals.append("recent_non_accept_churn")
issue_signals = list(dict.fromkeys(issue_signals))
if any(sig in issue_signals for sig in ("failed_session", "stalled", "long_non_accept_streak")):
return "critical", issue_signals
if issue_signals:
return "warning", issue_signals
return "healthy", issue_signals
def read_worker_health_snapshot(
storage_path: Path,
*,
session_id: str | None = None,
last_n_steps: int = _DEFAULT_LAST_N_STEPS,
default_session_id: str | None = None,
worker_agent_id: str | None = None,
worker_graph_id: str | None = None,
) -> dict[str, object]:
"""Read persisted worker logs and return the structured health snapshot.
This is the shared source of truth for worker-health reporting. The
monitoring tool returns it as JSON, while Queen's user-facing summaries can
consume the same dict directly to avoid underreporting issue signals.
"""
storage_path = Path(storage_path)
resolved_worker_agent_id = worker_agent_id or storage_path.name
resolved_worker_graph_id = worker_graph_id or storage_path.name
# Auto-discover the most recent session if not specified.
if not session_id or session_id == "auto":
sessions_dir = storage_path / "sessions"
if not sessions_dir.exists():
return {"error": "No sessions found — worker has not started yet"}
if default_session_id and (sessions_dir / default_session_id).is_dir():
session_id = default_session_id
else:
candidates = [
d for d in sessions_dir.iterdir() if d.is_dir() and (d / "state.json").exists()
]
if not candidates:
return {"error": "No sessions found — worker has not started yet"}
def _sort_key(d: Path):
try:
state = json.loads((d / "state.json").read_text(encoding="utf-8"))
priority = 0 if state.get("status", "") in ("in_progress", "running") else 1
return (priority, -d.stat().st_mtime)
except Exception:
return (2, 0)
candidates.sort(key=_sort_key)
session_id = candidates[0].name
session_dir = storage_path / "sessions" / str(session_id)
tool_logs_path = session_dir / "logs" / "tool_logs.jsonl"
state_path = session_dir / "state.json"
if not session_dir.exists() or not state_path.exists():
return {"error": f"No persisted worker state found for session '{session_id}'"}
session_status = "unknown"
if state_path.exists():
try:
state = json.loads(state_path.read_text(encoding="utf-8"))
session_status = state.get("status", "unknown")
except Exception:
pass
steps: list[dict] = []
if tool_logs_path.exists():
try:
with open(tool_logs_path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
try:
steps.append(json.loads(line))
except json.JSONDecodeError:
continue
except OSError as e:
return {"error": f"Could not read tool logs: {e}"}
total_steps = len(steps)
recent = steps[-last_n_steps:] if len(steps) > last_n_steps else steps
recent_verdicts = [s.get("verdict", "") for s in recent if s.get("verdict")]
steps_since_last_accept = 0
for verdict in reversed(recent_verdicts):
if verdict == "ACCEPT":
break
steps_since_last_accept += 1
last_step_time_iso: str | None = None
stall_minutes: float | None = None
if steps and tool_logs_path.exists():
try:
mtime = tool_logs_path.stat().st_mtime
last_step_time_iso = datetime.fromtimestamp(mtime, UTC).isoformat()
elapsed = (datetime.now(UTC).timestamp() - mtime) / 60
stall_minutes = round(elapsed, 1) if elapsed >= 1.0 else None
except OSError:
pass
evidence_snippet = ""
for step in reversed(recent):
text = step.get("llm_text", "")
if text:
evidence_snippet = text[:500]
break
health_status, issue_signals = classify_worker_health(
session_status=session_status,
recent_verdicts=recent_verdicts,
steps_since_last_accept=steps_since_last_accept,
stall_minutes=stall_minutes,
)
return {
"worker_agent_id": resolved_worker_agent_id,
"worker_graph_id": resolved_worker_graph_id,
"session_id": session_id,
"session_status": session_status,
"health_status": health_status,
"issue_signals": issue_signals,
"total_steps": total_steps,
"recent_verdicts": recent_verdicts,
"steps_since_last_accept": steps_since_last_accept,
"last_step_time_iso": last_step_time_iso,
"stall_minutes": stall_minutes,
"evidence_snippet": evidence_snippet,
}
def register_worker_monitoring_tools(
@@ -91,6 +257,8 @@ def register_worker_monitoring_tools(
Returns a JSON object with:
- session_id: the session inspected (useful when auto-discovered)
- session_status: "running"|"completed"|"failed"|"in_progress"|"unknown"
- health_status: "healthy"|"warning"|"critical"
- issue_signals: list of detected warning/attention categories
- total_steps: total number of log steps recorded so far
- recent_verdicts: list of last N verdict strings (ACCEPT/RETRY/CONTINUE/ESCALATE)
- steps_since_last_accept: consecutive non-ACCEPT steps from the end
@@ -98,120 +266,22 @@ def register_worker_monitoring_tools(
- stall_minutes: wall-clock minutes since last step (null if < 1 min)
- evidence_snippet: last LLM text from the most recent step (truncated)
"""
snapshot = read_worker_health_snapshot(
storage_path,
session_id=session_id,
last_n_steps=last_n_steps,
default_session_id=default_session_id,
worker_agent_id=_worker_agent_id,
worker_graph_id=_worker_graph_id,
)
return json.dumps(snapshot, ensure_ascii=False)
_health_summary_tool = Tool(
name="get_worker_health_summary",
description=(
"Read the worker agent's execution logs and return a compact health snapshot. "
"Returns worker_agent_id and worker_graph_id (use these for ticket identity fields), "
"recent verdicts, step count, time since last step, and "
"health_status, issue_signals, recent verdicts, step count, time since last step, and "
"a snippet of the most recent LLM output. "
"session_id is optional — omit it to auto-discover the most recent active session."
),
+135
View File
@@ -0,0 +1,135 @@
import { describe, expect, it } from "vitest";
import type { NodeSpec } from "@/api/types";
import type { GraphNode } from "@/components/graph-types";
import {
buildStructuredRunQuestions,
canShowRunButton,
getStructuredRunInputKeys,
hasAllStructuredRunInputs,
trimStructuredRunInputs,
} from "./run-inputs";
function makeNodeSpec(overrides: Partial<NodeSpec>): NodeSpec {
return {
id: "node-1",
name: "Node 1",
description: "",
node_type: "event_loop",
input_keys: [],
output_keys: [],
nullable_output_keys: [],
tools: [],
routes: {},
max_retries: 0,
max_node_visits: 0,
client_facing: false,
success_criteria: null,
system_prompt: "",
sub_agents: [],
...overrides,
};
}
function makeGraphNode(overrides: Partial<GraphNode>): GraphNode {
return {
id: "node-1",
label: "Node 1",
status: "pending",
...overrides,
};
}
describe("getStructuredRunInputKeys", () => {
it("returns structured input keys from the first non-trigger graph node", () => {
const nodeSpecs = [
makeNodeSpec({
id: "receive-runtime-inputs",
input_keys: ["target_dir", "review_dir", "word_threshold"],
}),
];
const graphNodes = [
makeGraphNode({ id: "__trigger_default", nodeType: "trigger" }),
makeGraphNode({ id: "receive-runtime-inputs", nodeType: "execution" }),
];
expect(getStructuredRunInputKeys(nodeSpecs, graphNodes)).toEqual([
"target_dir",
"review_dir",
"word_threshold",
]);
});
it("filters out generic task-style entry keys", () => {
const nodeSpecs = [
makeNodeSpec({
id: "entry",
input_keys: ["user_request", "task", "feedback", "target_dir"],
}),
];
expect(getStructuredRunInputKeys(nodeSpecs, [])).toEqual(["target_dir"]);
});
});
describe("hasAllStructuredRunInputs", () => {
it("requires every structured key to be present and non-blank", () => {
expect(
hasAllStructuredRunInputs(["target_dir", "word_threshold"], {
target_dir: "/tmp/project",
word_threshold: "800",
}),
).toBe(true);
expect(
hasAllStructuredRunInputs(["target_dir", "word_threshold"], {
target_dir: " ",
word_threshold: "800",
}),
).toBe(false);
expect(
hasAllStructuredRunInputs(["target_dir", "word_threshold"], {
target_dir: "/tmp/project",
}),
).toBe(false);
});
});
describe("buildStructuredRunQuestions", () => {
it("creates free-text prompts for each required run input", () => {
expect(buildStructuredRunQuestions(["target_dir", "review_dir"])).toEqual([
{ id: "target_dir", prompt: "Provide target_dir for this run." },
{ id: "review_dir", prompt: "Provide review_dir for this run." },
]);
});
});
describe("canShowRunButton", () => {
it("only exposes Run when a worker session is ready and staged/running", () => {
expect(canShowRunButton("sess-1", true, "staging", true)).toBe(true);
expect(canShowRunButton("sess-1", true, "running", true)).toBe(true);
expect(canShowRunButton("sess-1", true, "planning", true)).toBe(false);
expect(canShowRunButton("sess-1", true, "building", true)).toBe(false);
expect(canShowRunButton("sess-1", false, "staging", true)).toBe(false);
expect(canShowRunButton("sess-1", true, "staging", false)).toBe(false);
expect(canShowRunButton(null, true, "staging", true)).toBe(false);
});
});
describe("trimStructuredRunInputs", () => {
it("drops stale keys that are no longer part of the current schema", () => {
expect(
trimStructuredRunInputs(["target_dir", "word_threshold"], {
target_dir: "/tmp/project",
word_threshold: 800,
stale_key: "old",
}),
).toEqual({
target_dir: "/tmp/project",
word_threshold: 800,
});
});
});
+56
View File
@@ -0,0 +1,56 @@
import type { NodeSpec } from "@/api/types";
import type { GraphNode } from "@/components/graph-types";
const GENERIC_ENTRY_KEYS = new Set(["task", "user_request", "feedback"]);
const RUNNABLE_PHASES = new Set(["staging", "running"]);
type QueenPhase = "planning" | "building" | "staging" | "running";
function isMeaningfulValue(value: unknown): boolean {
if (typeof value === "string") return value.trim().length > 0;
return value !== undefined && value !== null;
}
export function getStructuredRunInputKeys(
nodeSpecs: NodeSpec[],
graphNodes: GraphNode[],
): string[] {
const entryNodeId =
graphNodes.find((node) => node.nodeType !== "trigger")?.id ?? nodeSpecs[0]?.id;
if (!entryNodeId) return [];
const entrySpec = nodeSpecs.find((node) => node.id === entryNodeId) ?? nodeSpecs[0];
return (entrySpec?.input_keys ?? []).filter((key) => !GENERIC_ENTRY_KEYS.has(key));
}
export function hasAllStructuredRunInputs(
keys: string[],
inputData: Record<string, unknown> | null | undefined,
): inputData is Record<string, unknown> {
if (!inputData) return false;
return keys.every((key) => isMeaningfulValue(inputData[key]));
}
export function buildStructuredRunQuestions(keys: string[]) {
return keys.map((key) => ({
id: key,
prompt: `Provide ${key} for this run.`,
}));
}
export function trimStructuredRunInputs(
keys: string[],
inputData: Record<string, unknown> | null | undefined,
): Record<string, unknown> {
if (!inputData) return {};
return Object.fromEntries(keys.flatMap((key) => (key in inputData ? [[key, inputData[key]]] : [])));
}
export function canShowRunButton(
sessionId: string | null | undefined,
ready: boolean | null | undefined,
queenPhase: QueenPhase | null | undefined,
topologyReady: boolean,
): boolean {
return Boolean(sessionId && ready && topologyReady && queenPhase && RUNNABLE_PHASES.has(queenPhase));
}
+140 -22
View File
@@ -18,6 +18,13 @@ import type { LiveSession, AgentEvent, DiscoverEntry, NodeSpec, DraftGraph as Dr
import { sseEventToChatMessage, formatAgentDisplayName } from "@/lib/chat-helpers";
import { topologyToGraphNodes } from "@/lib/graph-converter";
import { cronToLabel } from "@/lib/graphUtils";
import {
buildStructuredRunQuestions,
canShowRunButton,
getStructuredRunInputKeys,
hasAllStructuredRunInputs,
trimStructuredRunInputs,
} from "@/lib/run-inputs";
import { ApiError } from "@/api/client";
const makeId = () => Math.random().toString(36).slice(2, 9);
@@ -351,7 +358,9 @@ interface AgentBackendState {
/** Multiple questions from ask_user_multiple */
pendingQuestions: { id: string; prompt: string; options?: string[] }[] | null;
/** Whether the pending question came from queen, worker, or a structured run prompt */
pendingQuestionSource: "queen" | "worker" | "run" | null;
/** Last structured input payload successfully used to start the worker. */
lastRunInputData: Record<string, unknown> | null;
/** Per-node context window usage (from context_usage_updated events) */
contextUsage: Record<string, { usagePct: number; messageCount: number; estimatedTokens: number; maxTokens: number }>;
/** Whether the queen's LLM supports image content — false disables the attach button */
@@ -393,6 +402,7 @@ function defaultAgentState(): AgentBackendState {
pendingOptions: null,
pendingQuestions: null,
pendingQuestionSource: null,
lastRunInputData: null,
contextUsage: {},
queenSupportsImages: true,
};
@@ -693,15 +703,71 @@ export default function Workspace() {
}
}, [sessionsByAgent, activeSessionByAgent, activeWorker, agentStates]);
const appendSystemMessage = useCallback((agentType: string, content: string) => {
setSessionsByAgent((prev) => {
const sessions = prev[agentType] || [];
const activeId = activeSessionRef.current[agentType] || sessions[0]?.id;
return {
...prev,
[agentType]: sessions.map((s) => {
if (s.id !== activeId) return s;
const errorMsg: ChatMessage = {
id: makeId(),
agent: "System",
agentColor: "",
content,
timestamp: "",
type: "system",
thread: agentType,
createdAt: Date.now(),
};
return { ...s, messages: [...s.messages, errorMsg] };
}),
};
});
}, []);
const handleRun = useCallback(async () => {
const state = agentStates[activeWorker];
if (!state?.sessionId || !state?.ready) return;
const sessions = sessionsRef.current[activeWorker] || [];
const activeId = activeSessionRef.current[activeWorker] || sessions[0]?.id;
const activeSession = sessions.find((s) => s.id === activeId) || sessions[0];
const requiredRunKeys = getStructuredRunInputKeys(
state.nodeSpecs,
activeSession?.graphNodes || [],
);
if (
requiredRunKeys.length > 0 &&
!hasAllStructuredRunInputs(requiredRunKeys, state.lastRunInputData)
) {
updateAgentState(activeWorker, {
awaitingInput: true,
pendingQuestion: null,
pendingOptions: null,
pendingQuestions: buildStructuredRunQuestions(requiredRunKeys),
pendingQuestionSource: "run",
workerRunState: "idle",
});
return;
}
const inputData =
requiredRunKeys.length > 0
? trimStructuredRunInputs(requiredRunKeys, state.lastRunInputData)
: {};
// Reset dismissed banner so a repeated 424 re-shows it
setDismissedBanner(null);
try {
updateAgentState(activeWorker, { workerRunState: "deploying" });
const result = await executionApi.trigger(state.sessionId, "default", {});
updateAgentState(activeWorker, { currentExecutionId: result.execution_id });
const result = await executionApi.trigger(state.sessionId, "default", inputData);
updateAgentState(activeWorker, {
currentExecutionId: result.execution_id,
lastRunInputData: inputData,
});
} catch (err) {
// 424 = credentials required — open the credentials modal
if (err instanceof ApiError && err.status === 424) {
@@ -714,25 +780,23 @@ export default function Workspace() {
}
const errMsg = err instanceof Error ? err.message : String(err);
appendSystemMessage(activeWorker, `Failed to trigger run: ${errMsg}`);
updateAgentState(activeWorker, { workerRunState: "idle" });
}
}, [agentStates, activeWorker, appendSystemMessage, updateAgentState]);
const canRunLoadedWorker = canShowRunButton(
activeAgentState?.sessionId,
activeAgentState?.ready,
activeAgentState?.queenPhase,
Boolean(
activeAgentState?.nodeSpecs?.length ||
sessionsByAgent[activeWorker]?.some(
(session) =>
session.id === activeAgentState?.sessionId && session.graphNodes.length > 0,
),
),
);
// --- Fetch discovered agents for NewTabPopover ---
const [discoverAgents, setDiscoverAgents] = useState<DiscoverEntry[]>([]);
@@ -2826,6 +2890,55 @@ export default function Workspace() {
// --- handleMultiQuestionAnswer: submit answers to ask_user_multiple ---
const handleMultiQuestionAnswer = useCallback((answers: Record<string, string>) => {
const state = agentStates[activeWorker];
if (state?.pendingQuestionSource === "run") {
if (!state.sessionId || !state.ready) return;
updateAgentState(activeWorker, {
pendingQuestion: null,
pendingOptions: null,
pendingQuestions: null,
pendingQuestionSource: null,
awaitingInput: false,
workerRunState: "deploying",
});
const requiredRunKeys = getStructuredRunInputKeys(
state.nodeSpecs,
sessionsRef.current[activeWorker]?.find((s) => s.id === state.sessionId)?.graphNodes || [],
);
const trimmedAnswers = trimStructuredRunInputs(requiredRunKeys, answers);
executionApi.trigger(state.sessionId, "default", trimmedAnswers).then((result) => {
updateAgentState(activeWorker, {
currentExecutionId: result.execution_id,
lastRunInputData: trimmedAnswers,
});
}).catch((err: unknown) => {
if (err instanceof ApiError && err.status === 424) {
const errBody = err.body as Record<string, unknown>;
const credPath = (errBody?.agent_path as string) || null;
if (credPath) setCredentialAgentPath(credPath);
updateAgentState(activeWorker, {
workerRunState: "idle",
error: "credentials_required",
lastRunInputData: trimmedAnswers,
});
setCredentialsOpen(true);
return;
}
const errMsg = err instanceof Error ? err.message : String(err);
appendSystemMessage(activeWorker, `Failed to trigger run: ${errMsg}`);
updateAgentState(activeWorker, {
workerRunState: "idle",
awaitingInput: true,
pendingQuestion: null,
pendingOptions: null,
pendingQuestions: buildStructuredRunQuestions(requiredRunKeys),
pendingQuestionSource: "run",
});
});
return;
}
updateAgentState(activeWorker, {
pendingQuestion: null, pendingOptions: null,
pendingQuestions: null, pendingQuestionSource: null,
@@ -2835,7 +2948,7 @@ export default function Workspace() {
([id, answer]) => `[${id}]: ${answer}`,
);
handleSend(lines.join("\n"), activeWorker);
}, [activeWorker, agentStates, appendSystemMessage, handleSend, updateAgentState]);
// --- handleQuestionDismiss: user closed the question widget without answering ---
// Injects a dismiss signal so the blocked node can continue.
@@ -2854,6 +2967,11 @@ export default function Workspace() {
awaitingInput: false,
});
if (source === "run") {
updateAgentState(activeWorker, { workerRunState: "idle" });
return;
}
// Unblock the waiting node with a dismiss signal
const dismissMsg = `[User dismissed the question: "${question}"]`;
if (source === "worker") {
@@ -3145,7 +3263,7 @@ export default function Workspace() {
: null
}
building={activeAgentState?.queenBuilding}
onRun={canRunLoadedWorker ? handleRun : undefined}
onPause={handlePause}
runState={activeAgentState?.workerRunState ?? "idle"}
flowchartMap={activeAgentState?.flowchartMap ?? undefined}
+1 -1
View File
@@ -33,7 +33,7 @@ async def test_codex_stream():
print(f"extra_kwargs keys: {list(extra_kwargs.keys())}")
print(f"extra_headers: {list(extra_kwargs.get('extra_headers', {}).keys())}")
model = "openai/gpt-5.3-codex"
model = "openai/gpt-5.4"
# Create the provider
provider = LiteLLMProvider(
+1 -1
View File
@@ -33,7 +33,7 @@ async def main():
return
provider = LiteLLMProvider(
model="openai/gpt-5.3-codex",
model="openai/gpt-5.4",
api_key=api_key,
api_base=api_base,
**extra_kwargs,
+306
View File
@@ -0,0 +1,306 @@
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import MagicMock
import pytest
from framework.graph.event_loop_node import EventLoopNode, LoopConfig
from framework.graph.node import NodeContext, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.queen_orchestrator import _client_input_counts_as_planning_ask
from framework.tools.queen_lifecycle_tools import QueenPhaseState
class MockStreamingLLM(LLMProvider):
"""Minimal streaming LLM for Codex-vs-control parity checks."""
def __init__(self, scenarios: list[list[Any]] | None = None):
self.scenarios = scenarios or []
self._call_index = 0
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools=None,
max_tokens: int = 4096,
):
if not self.scenarios:
return
events = self.scenarios[self._call_index % len(self.scenarios)]
self._call_index += 1
for event in events:
yield event
def complete(self, messages, system="", **kwargs):
raise NotImplementedError
def text_scenario(text: str) -> list[Any]:
return [
TextDeltaEvent(content=text, snapshot=text),
FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
]
def tool_call_scenario(
tool_name: str,
tool_input: dict[str, Any],
*,
tool_use_id: str = "call_1",
preamble_text: str = "",
) -> list[Any]:
events: list[Any] = []
if preamble_text:
events.append(TextDeltaEvent(content=preamble_text, snapshot=preamble_text))
events.append(
ToolCallEvent(
tool_use_id=tool_use_id,
tool_name=tool_name,
tool_input=tool_input,
)
)
events.append(
FinishEvent(
stop_reason="tool_calls",
input_tokens=10,
output_tokens=5,
model="mock",
)
)
return events
def build_ctx(spec: NodeSpec, llm: LLMProvider, *, stream_id: str) -> NodeContext:
runtime = MagicMock()
runtime.start_run = MagicMock(return_value=f"session_{stream_id}")
runtime.decide = MagicMock(return_value="dec_1")
runtime.record_outcome = MagicMock()
runtime.end_run = MagicMock()
runtime.report_problem = MagicMock()
runtime.set_node = MagicMock()
return NodeContext(
runtime=runtime,
node_id=spec.id,
node_spec=spec,
memory=SharedMemory(),
input_data={},
llm=llm,
available_tools=[],
stream_id=stream_id,
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("style", "first_turn"),
[
(
"control",
tool_call_scenario(
"ask_user",
{
"question": "What kind of agent should I design for you?",
"options": ["Summarizer"],
},
tool_use_id="ask_1",
),
),
(
"codex",
text_scenario("What kind of agent should I design for you?"),
),
],
)
async def test_codex_and_control_styles_both_count_toward_planning_gate(
style: str,
first_turn: list[Any],
) -> None:
bus = EventBus()
phase_state = QueenPhaseState(phase="planning", event_bus=bus)
received: list[AgentEvent] = []
async def capture(event: AgentEvent) -> None:
received.append(event)
if _client_input_counts_as_planning_ask(event):
phase_state.planning_ask_rounds += 1
bus.subscribe([EventType.CLIENT_INPUT_REQUESTED], capture, filter_stream="queen")
spec = NodeSpec(
id="queen",
name="Queen",
description="planning orchestrator",
node_type="event_loop",
client_facing=True,
output_keys=[],
skip_judge=True,
)
llm = MockStreamingLLM(scenarios=[first_turn])
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
ctx = build_ctx(spec, llm, stream_id="queen")
async def shutdown_after_first_block() -> None:
await asyncio.sleep(0.05)
node.signal_shutdown()
task = asyncio.create_task(shutdown_after_first_block())
result = await node.execute(ctx)
await task
assert result.success is True
assert phase_state.planning_ask_rounds == 1
assert received
if style == "control":
assert received[0].data["prompt"] == "What kind of agent should I design for you?"
assert received[0].data.get("auto_blocked") is not True
else:
assert received[0].data["prompt"] == ""
assert received[0].data["auto_blocked"] is True
@pytest.mark.asyncio
@pytest.mark.parametrize(
("style", "scenarios"),
[
(
"control",
[
tool_call_scenario(
"ask_user",
{
"question": "Paste old and new policy text.",
"options": ["I'll paste both now"],
},
tool_use_id="ask_1",
),
tool_call_scenario(
"set_output",
{
"key": "important_changes",
"value": "- Remote days increased from 2 to 4",
},
tool_use_id="set_1",
),
],
),
(
"codex",
[
text_scenario("Paste old and new policy text."),
tool_call_scenario(
"set_output",
{
"key": "important_changes",
"value": "- Remote days increased from 2 to 4",
},
tool_use_id="set_1",
),
],
),
],
)
async def test_codex_and_control_styles_complete_same_human_in_loop_run(
style: str,
scenarios: list[list[Any]],
) -> None:
spec = NodeSpec(
id=f"policy_diff_{style}",
name="Policy Diff Worker",
description="Compare two policy versions",
node_type="event_loop",
output_keys=["important_changes"],
client_facing=True,
)
llm = MockStreamingLLM(scenarios=scenarios)
node = EventLoopNode(config=LoopConfig(max_iterations=6))
ctx = build_ctx(spec, llm, stream_id=f"worker_{style}")
async def user_responds() -> None:
await asyncio.sleep(0.05)
await node.inject_event("Old policy ... New policy ...")
task = asyncio.create_task(user_responds())
result = await node.execute(ctx)
await task
assert result.success is True
assert result.output["important_changes"] == "- Remote days increased from 2 to 4"
@pytest.mark.asyncio
@pytest.mark.parametrize(
("style", "scenario"),
[
(
"control",
tool_call_scenario(
"ask_user",
{"question": "What would you like to do next?", "options": ["Rerun", "Stop"]},
tool_use_id="ask_1",
preamble_text="Root cause: checkout is failing because the DB pool is exhausted.",
),
),
(
"codex",
tool_call_scenario(
"ask_user",
{
"question": (
"Root cause: checkout is failing because the DB pool is exhausted.\n\n"
"What would you like to do next?"
),
"options": ["Rerun", "Stop"],
},
tool_use_id="ask_1",
),
),
],
)
async def test_codex_and_control_styles_surface_result_before_followup_widget(
style: str,
scenario: list[Any],
) -> None:
spec = NodeSpec(
id=f"queen_{style}",
name="Queen",
description="orchestrator",
node_type="event_loop",
client_facing=True,
output_keys=[],
skip_judge=True,
)
llm = MockStreamingLLM(scenarios=[scenario])
bus = EventBus()
received: list[AgentEvent] = []
async def capture(event: AgentEvent) -> None:
received.append(event)
bus.subscribe(
event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED],
handler=capture,
)
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
ctx = build_ctx(spec, llm, stream_id="queen")
async def shutdown() -> None:
await asyncio.sleep(0.05)
node.signal_shutdown()
task = asyncio.create_task(shutdown())
await node.execute(ctx)
await task
output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA]
input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED]
assert output_events
assert input_events
assert "DB pool is exhausted" in output_events[0].data["snapshot"]
assert input_events[0].data["prompt"] == "What would you like to do next?"
+448
View File
@@ -0,0 +1,448 @@
from __future__ import annotations
import asyncio
import json
from dataclasses import dataclass, field
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from aiohttp.test_utils import TestClient, TestServer
import framework.tools.queen_lifecycle_tools as qlt
from framework.graph.event_loop_node import EventLoopNode, LoopConfig
from framework.graph.node import NodeContext, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider, Tool
from framework.llm.stream_events import FinishEvent, TextDeltaEvent, ToolCallEvent
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.app import create_app, validate_agent_path
from framework.server.session_manager import (
Session,
_run_validation_report_sync,
_validation_blocks_stage_or_run,
)
from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools
REPO_ROOT = Path(__file__).resolve().parents[2]
class MockStreamingLLM(LLMProvider):
"""Minimal streaming LLM for parity-gate regressions."""
def __init__(self, scenarios: list[list[Any]] | None = None):
self.scenarios = scenarios or []
self._call_index = 0
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools=None,
max_tokens: int = 4096,
):
if not self.scenarios:
return
events = self.scenarios[self._call_index % len(self.scenarios)]
self._call_index += 1
for event in events:
yield event
def complete(self, messages, system="", **kwargs):
raise NotImplementedError
def text_scenario(text: str) -> list[Any]:
return [
TextDeltaEvent(content=text, snapshot=text),
FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
]
def tool_call_scenario(
tool_name: str,
tool_input: dict[str, Any],
*,
tool_use_id: str = "call_1",
) -> list[Any]:
return [
ToolCallEvent(tool_use_id=tool_use_id, tool_name=tool_name, tool_input=tool_input),
FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock"),
]
def build_ctx(
spec: NodeSpec,
llm: LLMProvider,
*,
stream_id: str = "worker",
input_data: dict[str, Any] | None = None,
) -> NodeContext:
runtime = MagicMock()
runtime.start_run = MagicMock(return_value="session_codex_parity")
runtime.decide = MagicMock(return_value="dec_1")
runtime.record_outcome = MagicMock()
runtime.end_run = MagicMock()
runtime.report_problem = MagicMock()
runtime.set_node = MagicMock()
return NodeContext(
runtime=runtime,
node_id=spec.id,
node_spec=spec,
memory=SharedMemory(),
input_data=input_data or {},
llm=llm,
available_tools=[],
stream_id=stream_id,
)
@pytest.mark.parametrize(
"agent_ref",
[
"examples/templates/tech_news_reporter",
"examples/templates/vulnerability_assessment",
],
)
def test_codex_parity_existing_templates_validate_for_stage_run(agent_ref: str) -> None:
"""Existing checked-in agents should pass the shared stage/run gate."""
resolved = validate_agent_path(agent_ref)
report = _run_validation_report_sync(agent_ref)
assert resolved.is_dir()
assert report.get("valid") is True
assert _validation_blocks_stage_or_run(report) is False
@pytest.mark.asyncio
async def test_codex_parity_local_only_human_in_loop_run_completes() -> None:
"""A local-only client-facing worker flow should complete end to end."""
spec = NodeSpec(
id="policy_diff_worker",
name="Policy Diff Worker",
description="Compare two policy versions",
node_type="event_loop",
output_keys=["important_changes"],
client_facing=True,
)
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario(
"ask_user",
{"question": "Paste old and new policy text.", "options": ["I'll paste both now"]},
tool_use_id="ask_1",
),
tool_call_scenario(
"set_output",
{
"key": "important_changes",
"value": (
"- Remote days increased from 2 to 4\n"
"- Security training increased from annual to twice yearly"
),
},
tool_use_id="set_1",
),
]
)
node = EventLoopNode(config=LoopConfig(max_iterations=6))
ctx = build_ctx(spec, llm, stream_id="worker")
async def user_responds() -> None:
await asyncio.sleep(0.05)
await node.inject_event("Old policy ... New policy ...")
task = asyncio.create_task(user_responds())
result = await node.execute(ctx)
await task
assert result.success is True
assert "Remote days increased" in result.output["important_changes"]
@pytest.mark.asyncio
async def test_codex_parity_result_is_visible_before_followup_widget() -> None:
"""Long result-bearing queen prompts should stream the result before the widget."""
spec = NodeSpec(
id="queen",
name="Queen",
description="orchestrator",
node_type="event_loop",
client_facing=True,
output_keys=[],
skip_judge=True,
)
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario(
"ask_user",
{
"question": (
"Root cause: checkout is failing because the DB pool is exhausted.\n\n"
"What would you like to do next?"
),
"options": ["Rerun", "Stop"],
},
tool_use_id="ask_1",
)
]
)
bus = EventBus()
received: list[AgentEvent] = []
async def capture(event: AgentEvent) -> None:
received.append(event)
bus.subscribe(
event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED],
handler=capture,
)
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
ctx = build_ctx(spec, llm, stream_id="queen")
async def shutdown() -> None:
await asyncio.sleep(0.05)
node.signal_shutdown()
task = asyncio.create_task(shutdown())
await node.execute(ctx)
await task
output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA]
input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED]
assert output_events
assert input_events
assert "DB pool is exhausted" in output_events[0].data["snapshot"]
assert input_events[0].data["prompt"] == "What would you like to do next?"
@pytest.mark.asyncio
async def test_codex_parity_rerun_reuses_complete_recent_defaults(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
"""Rerun should keep structured inputs stable instead of relying on text reconstruction."""
registry = ToolRegistry()
registry.register(
"validate_agent_package",
Tool(
name="validate_agent_package",
description="fake validator",
parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
),
lambda _inputs: json.dumps({"valid": True, "steps": {}}),
)
monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
monkeypatch.chdir(tmp_path)
sessions_dir = tmp_path / "agent_store" / "sessions"
sessions_dir.mkdir(parents=True)
valid_prior_state = {
"timestamps": {"updated_at": "2026-03-24T20:44:00"},
"input_data": {
"target_dir": "docs",
"review_dir": "docs_reviews",
"word_threshold": 800,
},
}
malformed_recent_state = {
"timestamps": {"updated_at": "2026-03-24T21:20:23"},
"input_data": {
"review_dir": "docs_reviews",
"word_threshold": "800. Validate inputs and continue.",
},
}
for session_name, state in {
"session_20260324_204400_good": valid_prior_state,
"session_20260324_212023_bad": malformed_recent_state,
}.items():
session_dir = sessions_dir / session_name
session_dir.mkdir()
(session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")
runtime = SimpleNamespace(
_session_store=SimpleNamespace(sessions_dir=sessions_dir),
resume_timers=MagicMock(),
trigger=AsyncMock(return_value="exec-rerun"),
graph=SimpleNamespace(
nodes=[],
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
if node_id == "process"
else None
),
),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/local_markdown_review_probe_2"),
runner=None,
)
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-rerun",
phase_state=QueenPhaseState(phase="staging"),
)
result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
result = json.loads(result_raw)
assert result["status"] == "started"
runtime.trigger.assert_awaited_once()
assert runtime.trigger.await_args.kwargs["input_data"] == {
"target_dir": str((tmp_path / "docs").resolve()),
"review_dir": str((tmp_path / "docs_reviews").resolve()),
"word_threshold": 800,
}
assert runtime.trigger.await_args.kwargs["session_state"] is None
@dataclass
class _MockEntryPoint:
id: str = "default"
name: str = "Default"
entry_node: str = "start"
trigger_type: str = "manual"
trigger_config: dict = field(default_factory=dict)
@dataclass
class _MockStream:
is_awaiting_input: bool = False
_execution_tasks: dict = field(default_factory=dict)
_active_executors: dict = field(default_factory=dict)
active_execution_ids: set = field(default_factory=set)
async def cancel_execution(self, execution_id: str) -> bool:
return execution_id in self._execution_tasks
@dataclass
class _MockGraphRegistration:
graph: Any = field(default_factory=lambda: SimpleNamespace(nodes=[], edges=[], entry_node=""))
streams: dict = field(default_factory=dict)
entry_points: dict = field(default_factory=dict)
class _MockRuntime:
def __init__(self):
self._entry_points = [_MockEntryPoint()]
self._mock_streams = {"default": _MockStream()}
self._registration = _MockGraphRegistration(
streams=self._mock_streams,
entry_points={"default": self._entry_points[0]},
)
def list_graphs(self):
return ["primary"]
def get_graph_registration(self, graph_id):
if graph_id == "primary":
return self._registration
return None
def get_entry_points(self):
return self._entry_points
async def trigger(self, ep_id, input_data=None, session_state=None):
return "exec_test_123"
async def inject_input(self, node_id, content, graph_id=None, *, is_client_input=False):
return True
def pause_timers(self):
pass
async def get_goal_progress(self):
return {"progress": 0.5, "criteria": []}
def find_awaiting_node(self):
return None, None
def get_stats(self):
return {"running": True, "executions": 1}
def get_timer_next_fire_in(self, ep_id):
return None
def _make_queen_executor():
mock_node = MagicMock()
mock_node.inject_event = AsyncMock()
executor = MagicMock()
executor.node_registry = {"queen": mock_node}
return executor
def _make_session(agent_id="test_agent") -> Session:
runner = MagicMock()
runner.intro_message = "Test intro"
return Session(
id=agent_id,
event_bus=EventBus(),
llm=MagicMock(),
loaded_at=1000000.0,
queen_executor=_make_queen_executor(),
worker_id=agent_id,
worker_path=Path("/tmp/test_agent"),
runner=runner,
worker_runtime=_MockRuntime(),
worker_info=SimpleNamespace(
name="test_agent",
description="A test agent",
goal_name="test_goal",
node_count=2,
),
worker_validation_report={"valid": True, "steps": {}},
worker_validation_failures=[],
)
def _make_app_with_session(session: Session):
app = create_app()
mgr = app["manager"]
mgr._sessions[session.id] = session
return app
@pytest.mark.asyncio
async def test_codex_parity_done_for_now_parks_queen_without_new_followup() -> None:
"""Terminal stop choices should acknowledge once and park the queen."""
session = _make_session()
session.event_bus.get_history = MagicMock(
return_value=[
AgentEvent(
type=EventType.CLIENT_INPUT_REQUESTED,
stream_id="queen",
node_id="queen",
execution_id=session.id,
data={"options": ["Run again with same input", "Done for now"]},
)
]
)
session.event_bus.emit_client_output_delta = AsyncMock()
app = _make_app_with_session(session)
async with TestClient(TestServer(app)) as client:
resp = await client.post(
"/api/sessions/test_agent/chat",
json={"message": "No, stop here"},
)
assert resp.status == 200
data = await resp.json()
assert data["status"] == "queen"
assert data["delivered"] is True
assert session.queen_executor is None
session.event_bus.emit_client_output_delta.assert_awaited_once()
+194
View File
@@ -0,0 +1,194 @@
from __future__ import annotations
import asyncio
import json
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from framework.graph.event_loop_node import EventLoopNode, LoopConfig
from framework.graph.node import NodeContext, NodeSpec, SharedMemory
from framework.llm.provider import LLMProvider
from framework.llm.stream_events import FinishEvent, TextDeltaEvent
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.queen_orchestrator import _client_input_counts_as_planning_ask
from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools
class MockStreamingLLM(LLMProvider):
"""Minimal streaming LLM for planning-phase regression tests."""
def __init__(self, scenarios: list[list[Any]] | None = None):
self.scenarios = scenarios or []
self._call_index = 0
async def stream(
self,
messages: list[dict[str, Any]],
system: str = "",
tools=None,
max_tokens: int = 4096,
):
if not self.scenarios:
return
events = self.scenarios[self._call_index % len(self.scenarios)]
self._call_index += 1
for event in events:
yield event
def complete(self, messages, system="", **kwargs):
raise NotImplementedError
def text_scenario(text: str) -> list[Any]:
return [
TextDeltaEvent(content=text, snapshot=text),
FinishEvent(stop_reason="stop", input_tokens=10, output_tokens=5, model="mock"),
]
def build_ctx(spec: NodeSpec, llm: LLMProvider) -> NodeContext:
runtime = MagicMock()
runtime.start_run = MagicMock(return_value="session_codex_planning")
runtime.decide = MagicMock(return_value="dec_1")
runtime.record_outcome = MagicMock()
runtime.end_run = MagicMock()
runtime.report_problem = MagicMock()
runtime.set_node = MagicMock()
return NodeContext(
runtime=runtime,
node_id=spec.id,
node_spec=spec,
memory=SharedMemory(),
input_data={"greeting": "Session started."},
llm=llm,
available_tools=[],
stream_id="queen",
)
@pytest.mark.asyncio
async def test_codex_style_text_only_planning_turn_counts_toward_ask_rounds() -> None:
"""Plain-text planning questions should satisfy the ask_rounds gate.
This reproduces the Codex failure mode: the queen asks a planning question
in plain text instead of calling ask_user(), which triggers an auto-blocked
CLIENT_INPUT_REQUESTED event with an empty prompt.
"""
bus = EventBus()
phase_state = QueenPhaseState(phase="planning", event_bus=bus)
received: list[AgentEvent] = []
async def capture(event: AgentEvent) -> None:
received.append(event)
if _client_input_counts_as_planning_ask(event):
phase_state.planning_ask_rounds += 1
bus.subscribe([EventType.CLIENT_INPUT_REQUESTED], capture, filter_stream="queen")
spec = NodeSpec(
id="queen",
name="Queen",
description="planning orchestrator",
node_type="event_loop",
client_facing=True,
output_keys=[],
skip_judge=True,
)
llm = MockStreamingLLM(scenarios=[text_scenario("What kind of agent should I design for you?")])
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
ctx = build_ctx(spec, llm)
async def shutdown_after_first_block() -> None:
await asyncio.sleep(0.05)
node.signal_shutdown()
task = asyncio.create_task(shutdown_after_first_block())
result = await node.execute(ctx)
await task
assert result.success is True
assert len(received) >= 1
assert received[0].data["prompt"] == ""
assert received[0].data["auto_blocked"] is True
assert received[0].data["assistant_text_present"] is True
assert received[0].data["assistant_text_requires_input"] is True
assert phase_state.planning_ask_rounds == 1
@pytest.mark.asyncio
async def test_save_agent_draft_accepts_two_codex_style_planning_rounds() -> None:
"""Two counted auto-blocked planning turns should unlock save_agent_draft()."""
phase_state = QueenPhaseState(phase="planning")
codex_style_event = AgentEvent(
type=EventType.CLIENT_INPUT_REQUESTED,
stream_id="queen",
data={
"prompt": "",
"auto_blocked": True,
"assistant_text_present": True,
"assistant_text_requires_input": True,
},
)
for _ in range(2):
if _client_input_counts_as_planning_ask(codex_style_event):
phase_state.planning_ask_rounds += 1
registry = ToolRegistry()
session = SimpleNamespace(
worker_runtime=None,
event_bus=None,
worker_path=None,
runner=None,
)
register_queen_lifecycle_tools(
registry,
session=session,
session_id="session_codex_planning",
phase_state=phase_state,
)
save_draft = registry._tools["save_agent_draft"].executor
result_raw = await save_draft(
{
"agent_name": "codex_planning_repro",
"goal": "Reproduce the planning gate.",
"nodes": [
{"id": "start"},
{"id": "discover"},
{"id": "plan"},
{"id": "review"},
{"id": "finish"},
],
"edges": [
{"source": "start", "target": "discover"},
{"source": "discover", "target": "plan"},
{"source": "plan", "target": "review"},
{"source": "review", "target": "finish"},
],
}
)
result = json.loads(result_raw)
assert phase_state.planning_ask_rounds == 2
assert result["status"] == "draft_saved"
def test_status_only_auto_block_does_not_count_toward_planning_ask_rounds() -> None:
"""Auto-blocked acknowledgements should not satisfy the planning ask gate."""
event = AgentEvent(
type=EventType.CLIENT_INPUT_REQUESTED,
stream_id="queen",
data={
"prompt": "",
"auto_blocked": True,
"assistant_text_present": True,
"assistant_text_requires_input": False,
},
)
assert _client_input_counts_as_planning_ask(event) is False
+65 -2
View File
@@ -1,8 +1,15 @@
"""Tests for framework/config.py - Hive configuration loading."""
import logging
from unittest.mock import patch
from framework.config import (
get_api_base,
get_hive_config,
get_llm_extra_kwargs,
get_preferred_model,
)
from framework.llm.codex_backend import CODEX_API_BASE, is_codex_api_base, normalize_codex_api_base
class TestGetHiveConfig:
@@ -59,9 +66,65 @@ class TestOpenRouterConfig:
def test_get_api_base_keeps_explicit_openrouter_api_base(self, tmp_path, monkeypatch):
config_file = tmp_path / "configuration.json"
config_file.write_text(
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta","api_base":"https://proxy.example/v1"}}',
(
'{"llm":{"provider":"openrouter","model":"x-ai/grok-4.20-beta",'
'"api_base":"https://proxy.example/v1"}}'
),
encoding="utf-8",
)
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
assert get_api_base() == "https://proxy.example/v1"
class TestCodexConfig:
"""Codex config helpers should share the same transport defaults."""
def test_get_api_base_uses_shared_codex_backend(self, tmp_path, monkeypatch):
config_file = tmp_path / "configuration.json"
config_file.write_text(
'{"llm":{"provider":"openai","model":"gpt-5.4","use_codex_subscription":true}}',
encoding="utf-8",
)
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
assert get_api_base() == CODEX_API_BASE
def test_get_llm_extra_kwargs_uses_shared_codex_transport(self, tmp_path, monkeypatch):
config_file = tmp_path / "configuration.json"
config_file.write_text(
'{"llm":{"provider":"openai","model":"gpt-5.4","use_codex_subscription":true}}',
encoding="utf-8",
)
monkeypatch.setattr("framework.config.HIVE_CONFIG_FILE", config_file)
with (
patch("framework.runner.runner.get_codex_token", return_value="tok_test"),
patch("framework.runner.runner.get_codex_account_id", return_value="acct_123"),
):
kwargs = get_llm_extra_kwargs()
assert kwargs["store"] is False
assert kwargs["allowed_openai_params"] == ["store"]
assert kwargs["extra_headers"] == {
"Authorization": "Bearer tok_test",
"User-Agent": "CodexBar",
"ChatGPT-Account-Id": "acct_123",
}
def test_codex_api_base_detection_requires_real_chatgpt_origin(self):
assert is_codex_api_base("https://chatgpt.com/backend-api/codex")
assert is_codex_api_base("https://chatgpt.com/backend-api/codex/responses")
assert not is_codex_api_base(
"https://proxy.example/v1?target=https://chatgpt.com/backend-api/codex"
)
def test_normalize_codex_api_base_strips_only_real_responses_suffix(self):
assert (
normalize_codex_api_base("https://chatgpt.com/backend-api/codex/responses")
== CODEX_API_BASE
)
assert (
normalize_codex_api_base("https://proxy.example/v1/responses")
== "https://proxy.example/v1/responses"
)
@@ -224,6 +224,12 @@ class TestUpdateSystemPrompt:
conv.update_system_prompt("updated")
assert conv.system_prompt == "updated"
def test_update_replaces_output_keys(self):
conv = NodeConversation(system_prompt="original", output_keys=["brief"])
conv.update_system_prompt("updated", output_keys=["articles_data"])
assert conv.system_prompt == "updated"
assert conv._output_keys == ["articles_data"]
# ===========================================================================
# Conversation threading through executor
@@ -372,6 +378,61 @@ class TestContinuousConversation:
)
assert "PHASE TRANSITION" in all_content
@pytest.mark.asyncio
async def test_transition_marker_uses_next_node_tools_not_stale_previous_tools(self):
runtime = _make_runtime()
web_scrape = _make_tool("web_scrape")
file_tool = _make_tool("read_file")
llm = MockStreamingLLM(
scenarios=[
_text_then_set_output("Intake done.", "brief", "enterprise ai"),
_text_finish(""),
_text_then_set_output("Research done.", "articles_data", '{"articles": []}'),
_text_finish(""),
]
)
node_a = NodeSpec(
id="a",
name="Intake",
description="Collect user preference",
node_type="event_loop",
client_facing=True,
output_keys=["brief"],
)
node_b = NodeSpec(
id="b",
name="Research",
description="Scrape recent articles",
node_type="event_loop",
input_keys=["brief"],
output_keys=["articles_data"],
tools=["web_scrape"],
)
graph = GraphSpec(
id="g1",
goal_id="g1",
entry_node="a",
nodes=[node_a, node_b],
edges=[EdgeSpec(id="e1", source="a", target="b", condition=EdgeCondition.ON_SUCCESS)],
terminal_nodes=["b"],
conversation_mode="continuous",
)
executor = GraphExecutor(runtime=runtime, llm=llm, tools=[file_tool, web_scrape])
result = await executor.execute(graph=graph, goal=_make_goal())
assert result.success
node_b_messages = llm.stream_calls[2]["messages"]
all_content = " ".join(
m.get("content", "") for m in node_b_messages if isinstance(m.get("content"), str)
)
assert "Available tools:" in all_content
assert "web_scrape" in all_content
assert "set_output" in all_content
# ===========================================================================
# Cumulative tools
+544 -2
View File
@@ -7,6 +7,7 @@ that yields pre-programmed StreamEvents to control the loop deterministically.
from __future__ import annotations
import asyncio
import contextvars
from collections.abc import AsyncIterator
from typing import Any
from unittest.mock import AsyncMock, MagicMock
@@ -348,6 +349,143 @@ class TestSetOutput:
assert result.output["result"] == "ok"
assert "bad_key" not in result.output
def test_set_output_rejects_identical_duplicate_value(self):
"""Identical repeated set_output calls should be treated as an error, not progress."""
node = EventLoopNode()
result = node._handle_set_output(
{"key": "result", "value": "42"},
["result"],
missing_keys=["result", "summary"],
current_value=42,
normalized_value=42,
)
assert result.is_error is True
assert "already set to the same value" in result.content
assert "summary" in result.content
@pytest.mark.asyncio
async def test_set_output_auto_completes_non_client_facing_node(
self,
runtime,
node_spec,
memory,
):
"""A worker node should finish immediately once required outputs are set."""
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario("set_output", {"key": "result", "value": "done"}),
]
)
ctx = build_ctx(runtime, node_spec, memory, llm)
node = EventLoopNode(config=LoopConfig(max_iterations=5))
result = await node.execute(ctx)
assert result.success is True
assert result.output["result"] == "done"
assert len(llm.stream_calls) == 1
@pytest.mark.asyncio
async def test_set_output_auto_completes_client_facing_node(
self,
runtime,
memory,
):
"""Client-facing nodes should also finish once required outputs are set."""
spec = NodeSpec(
id="review",
name="Review",
description="client-facing review node",
node_type="event_loop",
output_keys=["decision"],
client_facing=True,
)
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario("set_output", {"key": "decision", "value": "approve"}),
]
)
ctx = build_ctx(runtime, spec, memory, llm)
node = EventLoopNode(config=LoopConfig(max_iterations=5))
result = await node.execute(ctx)
assert result.success is True
assert result.output["decision"] == "approve"
assert len(llm.stream_calls) == 1
@pytest.mark.asyncio
async def test_client_facing_completes_immediately_after_user_reply_sets_all_outputs(
self,
runtime,
memory,
):
"""Client-facing nodes should finish once a post-user-reply turn sets all outputs."""
spec = NodeSpec(
id="findings-review",
name="Findings Review",
description="review findings",
node_type="event_loop",
output_keys=["continue_scanning", "feedback", "all_findings"],
client_facing=True,
)
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario(
"ask_user",
{
"question": "Continue scanning or generate report?",
"options": ["Continue", "Report"],
},
tool_use_id="ask_1",
),
[
ToolCallEvent(
tool_use_id="set_continue",
tool_name="set_output",
tool_input={"key": "continue_scanning", "value": "false"},
),
ToolCallEvent(
tool_use_id="set_feedback",
tool_name="set_output",
tool_input={"key": "feedback", "value": "generate final report"},
),
ToolCallEvent(
tool_use_id="set_all",
tool_name="set_output",
tool_input={"key": "all_findings", "value": '{"ok": true}'},
),
FinishEvent(
stop_reason="tool_calls",
input_tokens=10,
output_tokens=5,
model="mock",
),
],
]
)
node = EventLoopNode(config=LoopConfig(max_iterations=10))
ctx = build_ctx(runtime, spec, memory, llm)
async def user_responds():
await asyncio.sleep(0.05)
await node.inject_event("Generate the report")
task = asyncio.create_task(user_responds())
result = await node.execute(ctx)
await task
assert result.success is True
assert result.output == {
"continue_scanning": False,
"feedback": "generate final report",
"all_findings": {"ok": True},
}
assert len(llm.stream_calls) == 2
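# The expected dict above implies set_output normalizes string values through
# JSON where possible ("false" -> False, '{"ok": true}' -> dict) and keeps
# non-JSON strings raw. A plausible normalizer, offered as an assumption:
import json

def _sketch_normalize_output_value(value):
    if not isinstance(value, str):
        return value
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        return value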
@pytest.mark.asyncio
async def test_missing_keys_triggers_retry(self, runtime, node_spec, memory):
"""Judge accepts but output keys are missing -> retry with hint."""
@@ -399,6 +537,47 @@ class TestStallDetection:
assert result.success is False
assert "stalled" in result.error.lower()
@pytest.mark.asyncio
async def test_no_progress_churn_detection(self, runtime, node_spec, memory):
"""Different text with missing outputs should still fail if nothing progresses."""
llm = MockStreamingLLM(
scenarios=[
text_scenario("Reviewing the logs and thinking through the issue."),
text_scenario("I am narrowing down the likely cause."),
text_scenario("I have more context and am still analyzing."),
]
)
ctx = build_ctx(runtime, node_spec, memory, llm)
node = EventLoopNode(config=LoopConfig(max_iterations=10, stall_detection_threshold=3))
result = await node.execute(ctx)
assert result.success is False
assert "no-progress loop detected" in (result.error or "").lower()
@pytest.mark.asyncio
async def test_no_progress_counter_resets_after_output_progress(
self,
runtime,
node_spec,
memory,
):
"""A real output-setting turn should reset the no-progress churn counter."""
llm = MockStreamingLLM(
scenarios=[
text_scenario("Parsing the problem statement."),
text_scenario("Extracting the key facts now."),
tool_call_scenario("set_output", {"key": "result", "value": "triaged"}),
]
)
ctx = build_ctx(runtime, node_spec, memory, llm)
node = EventLoopNode(config=LoopConfig(max_iterations=10, stall_detection_threshold=3))
result = await node.execute(ctx)
assert result.success is True
assert result.output["result"] == "triaged"
# ===========================================================================
# EventBus lifecycle events
@@ -647,6 +826,57 @@ class TestClientFacingBlocking:
assert len(received) >= 1
assert received[0].type == EventType.CLIENT_INPUT_REQUESTED
@pytest.mark.asyncio
async def test_queen_long_ask_user_prompt_surfaces_result_before_widget(
self, runtime, memory, client_spec
):
"""Long queen prompts should stream visible result text before options."""
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario(
"ask_user",
{
"question": (
"Root cause: checkout requests are failing because the DB pool is "
"exhausted and cart reads are timing out.\n\n"
"What would you like to do next?"
),
"options": ["Rerun", "Stop"],
},
tool_use_id="ask_1",
),
]
)
bus = EventBus()
received = []
async def capture(e):
received.append(e)
bus.subscribe(
event_types=[EventType.CLIENT_OUTPUT_DELTA, EventType.CLIENT_INPUT_REQUESTED],
handler=capture,
)
node = EventLoopNode(event_bus=bus, config=LoopConfig(max_iterations=5))
ctx = build_ctx(runtime, client_spec, memory, llm, stream_id="queen")
async def shutdown():
await asyncio.sleep(0.05)
node.signal_shutdown()
task = asyncio.create_task(shutdown())
await node.execute(ctx)
await task
output_events = [e for e in received if e.type == EventType.CLIENT_OUTPUT_DELTA]
input_events = [e for e in received if e.type == EventType.CLIENT_INPUT_REQUESTED]
assert output_events
assert input_events
assert "Root cause: checkout requests are failing" in output_events[0].data["snapshot"]
assert input_events[0].data["prompt"] == "What would you like to do next?"
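# A heuristic consistent with the asserts above (assumed, not the real
# splitter): stream everything before the final paragraph as visible result
# text and keep only the last paragraph as the short widget prompt.
def _sketch_split_ask_user_prompt(question: str) -> tuple[str, str]:
    parts = [p.strip() for p in question.split("\n\n") if p.strip()]
    if len(parts) < 2:
        return "", question.strip()
    return "\n\n".join(parts[:-1]), parts[-1]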
@pytest.mark.asyncio
@pytest.mark.skip(reason="Hangs in non-interactive shells (client-facing blocks on stdin)")
async def test_ask_user_with_real_tools(self, runtime, memory):
@@ -905,9 +1135,79 @@ class TestEscalate:
assert result.success is True
assert result.output["result"] == "resolved after queen guidance"
assert judge.evaluate.await_count >= 1
assert judge.evaluate.await_count == 0
assert len(client_input_events) == 0
@pytest.mark.asyncio
async def test_escalate_then_complete_outputs_autocompletes_without_extra_turn(
self, runtime, memory
):
"""Worker nodes should still auto-complete after queen guidance once outputs are set."""
spec = NodeSpec(
id="csv-intake",
name="CSV Intake",
description="parse csv",
node_type="event_loop",
output_keys=["original_headers", "parsed_rows", "raw_csv_text"],
)
llm = MockStreamingLLM(
scenarios=[
tool_call_scenario(
"escalate",
{"reason": "need csv", "context": "input malformed"},
tool_use_id="esc_1",
),
[
ToolCallEvent(
tool_use_id="set_headers",
tool_name="set_output",
tool_input={"key": "original_headers", "value": '["name","email"]'},
),
ToolCallEvent(
tool_use_id="set_rows",
tool_name="set_output",
tool_input={
"key": "parsed_rows",
"value": '[{"name":"Alice","email":"alice@example.com"}]',
},
),
ToolCallEvent(
tool_use_id="set_raw",
tool_name="set_output",
tool_input={
"key": "raw_csv_text",
"value": "name,email\\nAlice,alice@example.com",
},
),
FinishEvent(
stop_reason="tool_calls",
input_tokens=10,
output_tokens=5,
model="mock",
),
],
]
)
ctx = build_ctx(runtime, spec, memory, llm, stream_id="worker")
node = EventLoopNode(config=LoopConfig(max_iterations=5))
async def queen_reply():
await asyncio.sleep(0.05)
await node.inject_event("Use the sample CSV and continue.")
task = asyncio.create_task(queen_reply())
result = await node.execute(ctx)
await task
assert result.success is True
assert result.output == {
"original_headers": ["name", "email"],
"parsed_rows": [{"name": "Alice", "email": "alice@example.com"}],
"raw_csv_text": "name,email\\nAlice,alice@example.com",
}
assert len(llm.stream_calls) == 2
# ===========================================================================
# Client-facing: _cf_expecting_work state machine
@@ -1382,6 +1682,39 @@ class TestPauseResume:
assert llm._call_index == 0
class TestToolExecutionContext:
@pytest.mark.asyncio
async def test_execute_tool_preserves_contextvars_in_threadpool(self):
marker = contextvars.ContextVar("marker", default="missing")
def tool_exec(tool_use: ToolUse) -> ToolResult:
return ToolResult(
tool_use_id=tool_use.id,
content=marker.get(),
is_error=False,
)
node = EventLoopNode(
tool_executor=tool_exec,
config=LoopConfig(tool_call_timeout_seconds=5),
)
token = marker.set("present")
try:
result = await node._execute_tool(
ToolCallEvent(
tool_use_id="tool-1",
tool_name="echo_marker",
tool_input={},
)
)
finally:
marker.reset(token)
assert result.is_error is False
assert result.content == "present"
# ===========================================================================
# Stream errors
# ===========================================================================
@@ -1783,6 +2116,17 @@ class TestFingerprintToolCalls:
)
class TestFingerprintSetOutputCalls:
"""Unit tests for _fingerprint_set_output_calls()."""
def test_basic_fingerprint(self):
results = [
{"tool_name": "set_output", "tool_input": {"key": "result", "value": {"a": 1}}},
]
fps = EventLoopNode._fingerprint_set_output_calls(results)
assert fps == [("result", '{"a": 1}')]
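# Plausible helper shape matching the asserted fingerprint (an assumption):
# one (key, canonical-JSON value) tuple per set_output call, so identical
# payloads compare equal across turns.
import json

def _sketch_fingerprint_set_output_calls(results):
    fps = []
    for r in results:
        if r.get("tool_name") != "set_output":
            continue
        inp = r.get("tool_input", {})
        fps.append((inp.get("key"), json.dumps(inp.get("value"), sort_keys=True)))
    return fps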
class TestIsToolDoomLoop:
"""Unit tests for _is_tool_doom_loop()."""
@@ -1821,6 +2165,25 @@ class TestIsToolDoomLoop:
assert is_doom is False
class TestIsOutputDoomLoop:
"""Unit tests for _is_output_doom_loop()."""
def test_at_threshold_identical(self):
node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
fp = [("result", '"done"')]
is_doom, desc = node._is_output_doom_loop([fp, fp, fp])
assert is_doom is True
assert "set_output" in desc
def test_different_values_no_doom(self):
node = EventLoopNode(config=LoopConfig(tool_doom_loop_threshold=3))
fp1 = [("result", '"a"')]
fp2 = [("result", '"b"')]
fp3 = [("result", '"c"')]
is_doom, _ = node._is_output_doom_loop([fp1, fp2, fp3])
assert is_doom is False
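# Sketch of the identical-turn check both cases pin down (hypothetical): doom
# is declared only when the last `threshold` fingerprint lists are non-empty
# and all equal.
def _sketch_is_output_doom_loop(history, threshold=3):
    recent = history[-threshold:]
    if len(recent) < threshold or not recent[0]:
        return False, ""
    if all(fp == recent[0] for fp in recent):
        keys = ", ".join(key for key, _ in recent[0])
        return True, f"set_output doom loop detected on key(s): {keys}"
    return False, ""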
class ToolRepeatLLM(LLMProvider):
"""LLM that produces identical tool calls across outer iterations.
@@ -1879,6 +2242,95 @@ class ToolRepeatLLM(LLMProvider):
)
class SetOutputRepeatLLM(LLMProvider):
"""LLM that repeats the same set_output-only turn across iterations."""
def __init__(self, key: str, value: str, tool_turns: int, final_text: str = "done"):
self.key = key
self.value = value
self.tool_turns = tool_turns
self.final_text = final_text
self._call_index = 0
async def stream(self, messages, system="", tools=None, max_tokens=4096):
idx = self._call_index
self._call_index += 1
outer_iter = idx // 2
is_tool_call = (idx % 2 == 0) and outer_iter < self.tool_turns
if is_tool_call:
yield ToolCallEvent(
tool_use_id=f"set_{outer_iter}",
tool_name="set_output",
tool_input={"key": self.key, "value": self.value},
)
yield FinishEvent(
stop_reason="tool_calls",
input_tokens=10,
output_tokens=5,
model="mock",
)
else:
text = f"{self.final_text} (call {idx})"
yield TextDeltaEvent(content=text, snapshot=text)
yield FinishEvent(
stop_reason="stop",
input_tokens=10,
output_tokens=5,
model="mock",
)
def complete(self, messages, system="", **kwargs) -> LLMResponse:
return LLMResponse(
content="ok",
model="mock",
stop_reason="stop",
)
class VaryingSetOutputRepeatLLM(LLMProvider):
"""LLM that repeats set_output turns with different values across iterations."""
def __init__(self, key: str, values: list[str], final_text: str = "done"):
self.key = key
self.values = values
self.final_text = final_text
self._call_index = 0
async def stream(self, messages, system="", tools=None, max_tokens=4096):
idx = self._call_index
self._call_index += 1
outer_iter = idx // 2
is_tool_call = (idx % 2 == 0) and outer_iter < len(self.values)
if is_tool_call:
yield ToolCallEvent(
tool_use_id=f"set_{outer_iter}",
tool_name="set_output",
tool_input={"key": self.key, "value": self.values[outer_iter]},
)
yield FinishEvent(
stop_reason="tool_calls",
input_tokens=10,
output_tokens=5,
model="mock",
)
else:
text = f"{self.final_text} (call {idx})"
yield TextDeltaEvent(content=text, snapshot=text)
yield FinishEvent(
stop_reason="stop",
input_tokens=10,
output_tokens=5,
model="mock",
)
def complete(self, messages, system="", **kwargs) -> LLMResponse:
return LLMResponse(
content="ok",
model="mock",
stop_reason="stop",
)
class TestToolDoomLoopIntegration:
"""Integration tests for doom loop detection in execute().
@@ -2263,7 +2715,97 @@ class TestToolDoomLoopIntegration:
assert result.success is True
# Doom loop MUST fire for repeatedly-failing tool calls
assert len(doom_events) >= 1
assert "failing_tool" in doom_events[0].data["description"]
@pytest.mark.asyncio
async def test_repeated_identical_set_output_turns_fail_fast(
self,
runtime,
node_spec,
memory,
):
"""Repeated identical set_output-only turns should fail instead of spinning forever."""
node_spec.output_keys = ["result", "review_manifest"]
judge = AsyncMock(spec=JudgeProtocol)
judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))
llm = SetOutputRepeatLLM("result", "same summary", tool_turns=4)
bus = EventBus()
doom_events: list = []
bus.subscribe(
event_types=[EventType.NODE_TOOL_DOOM_LOOP],
handler=lambda e: doom_events.append(e),
)
ctx = build_ctx(runtime, node_spec, memory, llm, tools=[])
node = EventLoopNode(
judge=judge,
event_bus=bus,
config=LoopConfig(
max_iterations=10,
tool_doom_loop_threshold=3,
stall_similarity_threshold=1.0,
),
)
result = await node.execute(ctx)
assert result.success is False
assert "Output doom loop detected" in (result.error or "")
assert doom_events
assert "set_output" in doom_events[0].data["description"]
assert "result" in doom_events[0].data["description"]
@pytest.mark.asyncio
async def test_meta_reset_set_output_turns_fail_fast(
self,
runtime,
node_spec,
memory,
):
"""Fresh-payload reset chatter should trip the output doom loop guard."""
node_spec.output_keys = ["rules", "candidates", "scan_stats"]
judge = AsyncMock(spec=JudgeProtocol)
judge.evaluate = AsyncMock(return_value=JudgeVerdict(action="RETRY"))
llm = VaryingSetOutputRepeatLLM(
"rules",
[
(
"New event acknowledged. Awaiting fresh request payload "
"(phase transition details + structured inputs) to proceed."
),
(
"Context reset complete. Awaiting fresh phase transition payload "
"and structured inputs to proceed."
),
(
"Ready for fresh request payload with phase transition "
"instructions and structured inputs."
),
],
)
bus = EventBus()
doom_events: list = []
bus.subscribe(
event_types=[EventType.NODE_TOOL_DOOM_LOOP],
handler=lambda e: doom_events.append(e),
)
ctx = build_ctx(runtime, node_spec, memory, llm, tools=[])
node = EventLoopNode(
judge=judge,
event_bus=bus,
config=LoopConfig(
max_iterations=10,
tool_doom_loop_threshold=3,
stall_similarity_threshold=1.0,
),
)
result = await node.execute(ctx)
assert result.success is False
assert "fresh payload" in (result.error or "").lower()
assert doom_events
assert "fresh payload" in doom_events[0].data["description"].lower()
# ===========================================================================
+257
View File
@@ -241,6 +241,48 @@ class TestToolConversion:
with pytest.raises(ValueError, match="Failed to parse tool call arguments"):
provider._parse_tool_call_arguments('{"question": foo', "ask_user")
def test_parse_tool_call_arguments_recovers_pythonish_payloads(self):
"""Single-quoted and trailing-comma argument payloads should be recovered."""
provider = LiteLLMProvider(model="openai/gpt-5.4", api_key="test-key")
parsed = provider._parse_tool_call_arguments(
"{'question': 'Continue?', 'options': ['Yes', 'No'],}",
"ask_user",
)
assert parsed == {
"question": "Continue?",
"options": ["Yes", "No"],
}
def test_parse_tool_call_arguments_keeps_null_inside_strings(self):
"""Literal normalization should not mutate quoted text values."""
provider = LiteLLMProvider(model="openai/gpt-5.4", api_key="test-key")
parsed = provider._parse_tool_call_arguments(
"{'hypothesis': 'null hypothesis', 'approved': false}",
"summarize",
)
assert parsed == {
"hypothesis": "null hypothesis",
"approved": False,
}
def test_parse_tool_call_arguments_strips_json_code_fences(self):
"""Fence stripping should remove the language tag before JSON parsing."""
provider = LiteLLMProvider(model="openai/gpt-5.4", api_key="test-key")
parsed = provider._parse_tool_call_arguments(
'```json\n{"question":"Continue?","options":["Yes","No"]}\n```',
"ask_user",
)
assert parsed == {
"question": "Continue?",
"options": ["Yes", "No"],
}
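# A tolerant parser in the spirit of these three tests (illustrative sketch,
# not the provider's code): strip code fences, try strict JSON, then fall back
# to Python literal syntax with bare true/false/null rewritten at the AST
# level so quoted text like 'null hypothesis' is never mutated.
import ast
import json
import re

_JSON_CONSTS = {"true": True, "false": False, "null": None}

class _ConstRewriter(ast.NodeTransformer):
    def visit_Name(self, node):
        if node.id in _JSON_CONSTS:
            return ast.copy_location(ast.Constant(_JSON_CONSTS[node.id]), node)
        return node

def _sketch_parse_tool_call_arguments(raw: str, tool_name: str) -> dict:
    text = re.sub(r"^```[a-zA-Z0-9_-]*\s*|\s*```$", "", raw.strip())
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        try:
            tree = _ConstRewriter().visit(ast.parse(text, mode="eval"))
            return ast.literal_eval(tree)
        except (ValueError, SyntaxError) as exc:
            raise ValueError(f"Failed to parse tool call arguments for {tool_name}") from exc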
class TestAnthropicProviderBackwardCompatibility:
"""Test AnthropicProvider backward compatibility with LiteLLM backend."""
@@ -731,6 +773,221 @@ class TestMiniMaxStreamFallback:
assert not LiteLLMProvider(model="gpt-4o-mini", api_key="x")._is_minimax_model()
class TestCodexEmptyStreamRecovery:
"""Codex empty streams should fall back before surfacing ghost-stream retries."""
@pytest.mark.asyncio
@patch("litellm.acompletion")
async def test_stream_recovers_empty_codex_stream_via_nonstream_completion(
self,
mock_acompletion,
):
"""An empty Codex stream should be salvaged with a non-stream completion."""
from framework.llm.stream_events import FinishEvent, TextDeltaEvent
provider = LiteLLMProvider(
model="openai/gpt-5.4",
api_key="test-key",
api_base="https://chatgpt.com/backend-api/codex",
)
class EmptyStreamResponse:
chunks: list = []
def __aiter__(self):
return self
async def __anext__(self):
raise StopAsyncIteration
recovered = MagicMock()
recovered.choices = [MagicMock()]
recovered.choices[0].message.content = "Recovered via fallback"
recovered.choices[0].message.tool_calls = []
recovered.choices[0].finish_reason = "stop"
recovered.model = provider.model
recovered.usage.prompt_tokens = 12
recovered.usage.completion_tokens = 4
async def side_effect(*args, **kwargs):
if kwargs.get("stream"):
return EmptyStreamResponse()
return recovered
mock_acompletion.side_effect = side_effect
events = []
async for event in provider.stream(messages=[{"role": "user", "content": "hi"}]):
events.append(event)
text_events = [event for event in events if isinstance(event, TextDeltaEvent)]
assert len(text_events) == 1
assert text_events[0].snapshot == "Recovered via fallback"
finish_events = [event for event in events if isinstance(event, FinishEvent)]
assert len(finish_events) == 1
assert finish_events[0].stop_reason == "stop"
assert finish_events[0].input_tokens == 12
assert finish_events[0].output_tokens == 4
assert mock_acompletion.call_count == 2
assert mock_acompletion.call_args_list[0].kwargs["stream"] is True
assert "stream" not in mock_acompletion.call_args_list[1].kwargs
@pytest.mark.asyncio
@patch("litellm.acompletion")
async def test_stream_recovers_empty_codex_stream_with_tool_calls(
self,
mock_acompletion,
):
"""Non-stream fallback should preserve tool calls, not just text."""
from framework.llm.stream_events import FinishEvent, ToolCallEvent
provider = LiteLLMProvider(
model="openai/gpt-5.4",
api_key="test-key",
api_base="https://chatgpt.com/backend-api/codex/responses",
)
class EmptyStreamResponse:
chunks: list = []
def __aiter__(self):
return self
async def __anext__(self):
raise StopAsyncIteration
tc = MagicMock()
tc.id = "tool_1"
tc.function.name = "ask_user"
tc.function.arguments = '{"question":"Continue?","options":["Yes","No"]}'
recovered = MagicMock()
recovered.choices = [MagicMock()]
recovered.choices[0].message.content = ""
recovered.choices[0].message.tool_calls = [tc]
recovered.choices[0].finish_reason = "tool_calls"
recovered.model = provider.model
recovered.usage.prompt_tokens = 14
recovered.usage.completion_tokens = 5
async def side_effect(*args, **kwargs):
if kwargs.get("stream"):
return EmptyStreamResponse()
return recovered
mock_acompletion.side_effect = side_effect
events = []
async for event in provider.stream(
messages=[{"role": "user", "content": "Should we continue?"}],
tools=[
Tool(
name="ask_user",
description="Ask the user",
parameters={"properties": {"question": {"type": "string"}}},
)
],
):
events.append(event)
tool_events = [event for event in events if isinstance(event, ToolCallEvent)]
assert len(tool_events) == 1
assert tool_events[0].tool_name == "ask_user"
assert tool_events[0].tool_input == {
"question": "Continue?",
"options": ["Yes", "No"],
}
finish_events = [event for event in events if isinstance(event, FinishEvent)]
assert len(finish_events) == 1
assert finish_events[0].stop_reason == "tool_calls"
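# Gist of the salvage path (sketch; `acompletion` stands in for
# litellm.acompletion and event synthesis is elided): when the streamed call
# yields zero chunks, retry once with the `stream` kwarg omitted entirely and
# convert the single response into synthetic events.
async def _sketch_stream_with_codex_fallback(acompletion, **kwargs):
    saw_chunk = False
    async for chunk in await acompletion(stream=True, **kwargs):
        saw_chunk = True
        yield chunk
    if not saw_chunk:
        response = await acompletion(**kwargs)  # note: no stream kwarg at all
        yield response  # caller maps message/tool_calls onto stream events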
class TestCodexRequestHardening:
def test_codex_build_completion_kwargs_splits_prompt_and_forces_tool_choice(self):
"""Codex requests should chunk large system prompts and require tools when needed."""
provider = LiteLLMProvider(
model="openai/gpt-5.4",
api_key="test-key",
api_base="https://chatgpt.com/backend-api/codex/responses",
)
kwargs = provider._build_completion_kwargs(
messages=[{"role": "user", "content": "hi"}],
system="# Identity\n" + ("rule\n" * 2000),
tools=[
Tool(
name="ask_user",
description="Ask the user",
parameters={"properties": {"question": {"type": "string"}}},
)
],
max_tokens=256,
response_format=None,
json_mode=False,
stream=True,
)
system_messages = [m for m in kwargs["messages"] if m["role"] == "system"]
assert len(system_messages) >= 2
assert system_messages[0]["content"].startswith("# Codex Execution Contract")
assert kwargs["tool_choice"] == "required"
assert kwargs["store"] is False
assert "max_tokens" not in kwargs
assert "stream_options" not in kwargs
assert kwargs["api_base"] == "https://chatgpt.com/backend-api/codex"
assert "store" in kwargs["allowed_openai_params"]
def test_codex_merge_tool_call_chunk_handles_parallel_calls_with_broken_indexes(self):
"""Codex chunk merging should survive index=0 for multiple parallel tool calls."""
from types import SimpleNamespace
provider = LiteLLMProvider(
model="openai/gpt-5.4",
api_key="test-key",
api_base="https://chatgpt.com/backend-api/codex",
)
acc: dict[int, dict[str, str]] = {}
last_idx = 0
chunks = [
SimpleNamespace(
id="tool_1",
index=0,
function=SimpleNamespace(name="web_search", arguments='{"query":"alpha'),
),
SimpleNamespace(
id="tool_2",
index=0,
function=SimpleNamespace(name="read_file", arguments='{"path":"beta'),
),
SimpleNamespace(
id=None,
index=0,
function=SimpleNamespace(name=None, arguments='"}'),
),
SimpleNamespace(
id=None,
index=0,
function=SimpleNamespace(name=None, arguments='"}'),
),
]
for chunk in chunks:
last_idx = provider._merge_tool_call_chunk(acc, chunk, last_idx)
assert len(acc) == 2
parsed = [
provider._parse_tool_call_arguments(slot["arguments"], slot["name"])
for _, slot in sorted(acc.items())
]
assert parsed == [
{"query": "alpha"},
{"path": "beta"},
]
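# One merge policy that reproduces the asserts above (distilled from the test,
# not the provider's actual code): a fresh `id` always opens a new slot, and a
# continuation chunk goes to the last slot written unless that slot already
# holds complete JSON, in which case it routes to the earliest open slot.
import json

def _sketch_args_complete(args: str) -> bool:
    try:
        json.loads(args)
        return True
    except json.JSONDecodeError:
        return False

def _sketch_merge_tool_call_chunk(acc, chunk, last_idx):
    if chunk.id:
        idx = len(acc)
        acc[idx] = {"name": chunk.function.name or "", "arguments": ""}
    else:
        idx = last_idx
        if _sketch_args_complete(acc[idx]["arguments"]):
            idx = next(
                (i for i in sorted(acc) if not _sketch_args_complete(acc[i]["arguments"])),
                idx,
            )
    if chunk.function.arguments:
        acc[idx]["arguments"] += chunk.function.arguments
    return idx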
class TestOpenRouterToolCompatFallback:
"""OpenRouter models should fall back when native tool use is unavailable."""
+985
View File
@@ -0,0 +1,985 @@
from __future__ import annotations
import json
import os
import time
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
import framework.tools.queen_lifecycle_tools as qlt
from framework.llm.provider import Tool
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import EventBus
from framework.tools.queen_lifecycle_tools import QueenPhaseState, register_queen_lifecycle_tools
def _write_worker_logs(
storage_path: Path,
session_id: str,
*,
session_status: str,
steps: list[dict[str, object]],
) -> Path:
session_dir = storage_path / "sessions" / session_id
logs_dir = session_dir / "logs"
logs_dir.mkdir(parents=True, exist_ok=True)
(session_dir / "state.json").write_text(
json.dumps({"status": session_status}),
encoding="utf-8",
)
log_path = logs_dir / "tool_logs.jsonl"
log_path.write_text(
"".join(json.dumps(step) + "\n" for step in steps),
encoding="utf-8",
)
return log_path
def _register_fake_validator(registry: ToolRegistry, report: dict) -> None:
registry.register(
"validate_agent_package",
Tool(
name="validate_agent_package",
description="fake validator",
parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
),
lambda _inputs: json.dumps(report),
)
def test_parse_validation_report_handles_saved_footer() -> None:
raw = (
'{\n "valid": false,\n "steps": {"tool_validation": {"passed": false}}\n}\n\n'
"[Saved to 'validate.txt']"
)
parsed = qlt._parse_validation_report(raw)
assert parsed == {"valid": False, "steps": {"tool_validation": {"passed": False}}}
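# The footer-tolerant parse this assertion implies (sketch): truncate at the
# last closing brace so trailing lines like "[Saved to 'validate.txt']" are
# dropped before json.loads.
import json

def _sketch_parse_validation_report(raw: str) -> dict:
    text = raw.strip()
    end = text.rfind("}")
    if end != -1:
        text = text[: end + 1]
    return json.loads(text)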
def test_validation_blocks_stage_or_run_ignores_non_blocking_warnings() -> None:
report = {
"steps": {
"behavior_validation": {
"passed": True,
"warnings": ["placeholder prompt"],
"output": "placeholder prompt",
},
"tests": {
"passed": True,
"warnings": ["1 failed"],
"summary": "1 failed",
},
},
}
assert qlt._validation_blocks_stage_or_run(report) is False
def test_invalid_validation_report_blocks_stage_or_run() -> None:
report = qlt._invalid_validation_report("validator returned garbage")
assert report["valid"] is False
assert qlt._validation_blocks_stage_or_run(report) is True
@pytest.mark.asyncio
async def test_get_worker_status_summary_flags_retry_and_judge_pressure() -> None:
registry = ToolRegistry()
bus = EventBus()
await bus.emit_node_retry(
stream_id="worker",
node_id="scan",
retry_count=1,
max_retries=3,
error="still missing required result",
)
for _ in range(4):
await bus.emit_judge_verdict(
stream_id="worker",
node_id="scan",
action="RETRY",
feedback="missing structured output",
)
runtime = SimpleNamespace(
graph_id="worker-graph",
get_graph_registration=lambda _gid: SimpleNamespace(
streams={
"default": SimpleNamespace(
active_execution_ids=["exec-1"],
get_context=lambda _exec_id: SimpleNamespace(started_at=datetime.now()),
get_waiting_nodes=lambda: [],
)
}
),
)
session = SimpleNamespace(worker_runtime=runtime, event_bus=bus, worker_path=None, runner=None)
register_queen_lifecycle_tools(registry, session=session, session_id="sess-status")
summary = await registry._tools["get_worker_status"].executor({})
assert "issue type(s) detected" in summary
@pytest.mark.asyncio
async def test_get_worker_status_issues_reports_judge_pressure() -> None:
registry = ToolRegistry()
bus = EventBus()
for action in ("CONTINUE", "RETRY", "RETRY", "ESCALATE"):
await bus.emit_judge_verdict(
stream_id="worker",
node_id="review",
action=action,
feedback="still not converging",
)
runtime = SimpleNamespace(
graph_id="worker-graph",
get_graph_registration=lambda _gid: SimpleNamespace(streams={}),
)
session = SimpleNamespace(worker_runtime=runtime, event_bus=bus, worker_path=None, runner=None)
register_queen_lifecycle_tools(registry, session=session, session_id="sess-issues")
issues = await registry._tools["get_worker_status"].executor({"focus": "issues"})
assert "Judge pressure detected" in issues
assert "consecutive non-ACCEPT judge verdict" in issues
@pytest.mark.asyncio
async def test_get_worker_status_summary_uses_health_snapshot_signals(tmp_path: Path) -> None:
storage_path = tmp_path / "agent_store"
storage_path.mkdir(parents=True, exist_ok=True)
log_path = _write_worker_logs(
storage_path,
"sess-health",
session_status="running",
steps=[
{"verdict": "CONTINUE", "llm_text": "thinking"},
{"verdict": "RETRY", "llm_text": "retrying"},
{"verdict": "RETRY", "llm_text": "still retrying"},
{"verdict": "ESCALATE", "llm_text": "need help"},
],
)
three_minutes_ago = time.time() - 180
os.utime(log_path, (three_minutes_ago, three_minutes_ago))
registry = ToolRegistry()
bus = EventBus()
runtime = SimpleNamespace(
graph_id="worker-graph",
get_graph_registration=lambda _gid: SimpleNamespace(
streams={
"default": SimpleNamespace(
active_execution_ids=["exec-1"],
get_context=lambda _exec_id: SimpleNamespace(started_at=datetime.now()),
get_waiting_nodes=lambda: [],
)
}
),
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=bus,
worker_path=storage_path,
runner=None,
)
register_queen_lifecycle_tools(registry, session=session, session_id="sess-health")
summary = await registry._tools["get_worker_status"].executor({})
assert "issue signal(s) detected" in summary
assert "judge_pressure" in summary
assert "recent_non_accept_churn" in summary
@pytest.mark.asyncio
async def test_get_worker_status_issues_includes_health_snapshot_signals(tmp_path: Path) -> None:
storage_path = tmp_path / "agent_store"
storage_path.mkdir(parents=True, exist_ok=True)
log_path = _write_worker_logs(
storage_path,
"sess-health",
session_status="running",
steps=[
{"verdict": "CONTINUE", "llm_text": "thinking"},
{"verdict": "RETRY", "llm_text": "retrying"},
{"verdict": "RETRY", "llm_text": "still retrying"},
{"verdict": "ESCALATE", "llm_text": "need help"},
],
)
three_minutes_ago = time.time() - 180
os.utime(log_path, (three_minutes_ago, three_minutes_ago))
registry = ToolRegistry()
bus = EventBus()
runtime = SimpleNamespace(
graph_id="worker-graph",
get_graph_registration=lambda _gid: SimpleNamespace(streams={}),
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=bus,
worker_path=storage_path,
runner=None,
)
register_queen_lifecycle_tools(registry, session=session, session_id="sess-health")
issues = await registry._tools["get_worker_status"].executor({"focus": "issues"})
assert "Health signals:" in issues
assert "slow_progress" in issues
assert "recent_non_accept_churn" in issues
def test_build_worker_input_data_maps_bullet_task_fields_to_entry_inputs(
monkeypatch, tmp_path: Path
) -> None:
monkeypatch.chdir(tmp_path)
(tmp_path / "docs").mkdir()
runtime = SimpleNamespace(
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
graph=SimpleNamespace(
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(
input_keys=[
"docs_dir",
"review_dir",
"word_threshold",
"style_rules",
"target_ratio",
]
)
if node_id == "process"
else None
),
),
)
payload = qlt._build_worker_input_data(
runtime,
(
"Run md_condense_reviewer with the following runtime config:\n"
"- docs_dir: docs/\n"
"- review_dir: docs_reviews/\n"
"- word_threshold: 800\n"
"- target_ratio: 0.6 (default)\n"
"- style_rules: Preserve headings and links.\n\n"
"Execution requirements:\n"
"1) Scan the docs directory.\n"
"2) Write review copies."
),
)
assert payload == {
"docs_dir": str((tmp_path / "docs").resolve()),
"review_dir": str((tmp_path / "docs_reviews").resolve()),
"word_threshold": 800,
"style_rules": "Preserve headings and links.",
"target_ratio": 0.6,
}
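# Approximate mapping the expected payload relies on (hypothetical sketch):
# pull `- key: value` bullet lines, keep only declared entry input_keys, strip
# "(default)" annotations, coerce numerics, and resolve path-like values.
import re
from pathlib import Path

def _sketch_parse_bullet_inputs(task: str, input_keys: list[str]) -> dict:
    payload: dict = {}
    for match in re.finditer(r"^- (\w+):\s*(.+)$", task, re.MULTILINE):
        key, value = match.group(1), match.group(2).strip()
        if key not in input_keys:
            continue
        value = re.sub(r"\s*\(default\)$", "", value)
        if re.fullmatch(r"-?\d+", value):
            payload[key] = int(value)
        elif re.fullmatch(r"-?\d+(\.\d+)?", value):
            payload[key] = float(value)
        elif value.endswith("/"):
            payload[key] = str(Path(value).resolve())
        else:
            payload[key] = value
    return payload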
def test_build_worker_input_data_maps_equals_style_runtime_fields(
monkeypatch, tmp_path: Path
) -> None:
monkeypatch.chdir(tmp_path)
runtime = SimpleNamespace(
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
graph=SimpleNamespace(
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["target_dir", "review_dir_mode", "word_threshold"])
if node_id == "process"
else None
),
),
)
payload = qlt._build_worker_input_data(
runtime,
("Yes, rerun with target_dir=docs review_dir_mode=next_to_source word_threshold=800"),
)
assert payload == {
"target_dir": str((tmp_path / "docs").resolve()),
"review_dir_mode": "next_to_source",
"word_threshold": 800,
}
def test_build_worker_input_data_backfills_missing_fields_from_recent_session(
monkeypatch, tmp_path: Path
) -> None:
monkeypatch.chdir(tmp_path)
sessions_dir = tmp_path / "agent_store" / "sessions"
sessions_dir.mkdir(parents=True)
valid_prior_state = {
"timestamps": {"updated_at": "2026-03-24T20:44:00"},
"input_data": {
"target_dir": "docs",
"review_dir": "docs_reviews",
"word_threshold": 800,
},
}
malformed_recent_state = {
"timestamps": {"updated_at": "2026-03-24T21:20:23"},
"input_data": {
"review_dir": "docs_reviews",
"word_threshold": "800. Validate inputs and continue.",
},
}
for session_name, state in {
"session_20260324_204400_good": valid_prior_state,
"session_20260324_212023_bad": malformed_recent_state,
}.items():
session_dir = sessions_dir / session_name
session_dir.mkdir()
(session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")
runtime = SimpleNamespace(
_session_store=SimpleNamespace(sessions_dir=sessions_dir),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
graph=SimpleNamespace(
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
if node_id == "process"
else None
),
),
)
payload = qlt._build_worker_input_data(
runtime,
("review_dir: docs_reviews\nword_threshold: 800. Validate inputs and continue."),
)
assert payload == {
"target_dir": str((tmp_path / "docs").resolve()),
"review_dir": str((tmp_path / "docs_reviews").resolve()),
"word_threshold": 800,
}
def test_build_worker_input_data_reuses_recent_defaults_for_rerun_phrase(
monkeypatch, tmp_path: Path
) -> None:
monkeypatch.chdir(tmp_path)
sessions_dir = tmp_path / "agent_store" / "sessions"
sessions_dir.mkdir(parents=True)
state = {
"timestamps": {"updated_at": "2026-03-24T21:17:00"},
"input_data": {
"target_dir": "docs",
"review_dir": "docs_reviews",
"word_threshold": 800,
},
}
session_dir = sessions_dir / "session_20260324_211700_prev"
session_dir.mkdir()
(session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")
runtime = SimpleNamespace(
_session_store=SimpleNamespace(sessions_dir=sessions_dir),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
graph=SimpleNamespace(
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
if node_id == "process"
else None
),
),
)
payload = qlt._build_worker_input_data(runtime, "Run again with same defaults")
assert payload == {
"target_dir": str((tmp_path / "docs").resolve()),
"review_dir": str((tmp_path / "docs_reviews").resolve()),
"word_threshold": 800,
}
def test_build_worker_input_data_backfills_from_recent_result_output(
monkeypatch, tmp_path: Path
) -> None:
monkeypatch.chdir(tmp_path)
sessions_dir = tmp_path / "agent_store" / "sessions"
sessions_dir.mkdir(parents=True)
state = {
"timestamps": {"updated_at": "2026-03-24T23:35:19"},
"input_data": {
"review_dir": "docs_reviews",
"word_threshold": 800,
},
"result": {
"output": {
"target_dir": "docs",
"review_dir": "docs_reviews",
"word_threshold": 800,
}
},
}
session_dir = sessions_dir / "session_20260324_233519_prev"
session_dir.mkdir()
(session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")
runtime = SimpleNamespace(
_session_store=SimpleNamespace(sessions_dir=sessions_dir),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
graph=SimpleNamespace(
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
if node_id == "process"
else None
),
),
)
payload = qlt._build_worker_input_data(
runtime,
("review_dir: docs_reviews\nword_threshold: 600"),
)
assert payload == {
"target_dir": str((tmp_path / "docs").resolve()),
"review_dir": str((tmp_path / "docs_reviews").resolve()),
"word_threshold": 600,
}
@pytest.mark.asyncio
async def test_load_built_agent_blocks_invalid_package(monkeypatch, tmp_path: Path) -> None:
registry = ToolRegistry()
captured: dict[str, str] = {}
registry.register(
"validate_agent_package",
Tool(
name="validate_agent_package",
description="fake validator",
parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
),
lambda inputs: (
captured.setdefault("agent_name", inputs["agent_name"]),
json.dumps(
{
"valid": False,
"steps": {
"behavior_validation": {
"passed": False,
"output": (
"Node 'scan-markdown' has a blank or placeholder system_prompt"
),
}
},
}
),
)[1],
)
session = SimpleNamespace(worker_runtime=None, event_bus=None, worker_path=None, runner=None)
fake_manager = SimpleNamespace(
get_session=lambda _sid: None,
unload_worker=AsyncMock(),
load_worker=AsyncMock(),
)
phase_state = QueenPhaseState(phase="building")
register_queen_lifecycle_tools(
registry,
session=session,
session_manager=fake_manager,
manager_session_id="sess-1",
phase_state=phase_state,
)
agent_dir = tmp_path / "broken_agent"
agent_dir.mkdir()
monkeypatch.setattr(qlt, "validate_agent_path", lambda _path: agent_dir)
result_raw = await registry._tools["load_built_agent"].executor({"agent_path": str(agent_dir)})
result = json.loads(result_raw)
assert "Cannot load agent" in result["error"]
assert "behavior_validation" in result["validation_failures"][0]
assert captured["agent_name"] == str(agent_dir)
fake_manager.load_worker.assert_not_called()
fake_manager.unload_worker.assert_not_called()
@pytest.mark.asyncio
async def test_load_built_agent_keeps_current_worker_when_replacement_fails_validation(
monkeypatch, tmp_path: Path
) -> None:
registry = ToolRegistry()
_register_fake_validator(
registry,
{
"valid": False,
"steps": {
"behavior_validation": {
"passed": False,
"output": "Node 'scan' has a blank or placeholder system_prompt",
}
},
},
)
session = SimpleNamespace(
worker_runtime=SimpleNamespace(),
event_bus=None,
worker_path=Path("exports/existing_agent"),
runner=None,
)
fake_manager = SimpleNamespace(
get_session=lambda _sid: None,
unload_worker=AsyncMock(),
load_worker=AsyncMock(),
)
phase_state = QueenPhaseState(phase="building")
register_queen_lifecycle_tools(
registry,
session=session,
session_manager=fake_manager,
manager_session_id="sess-1",
phase_state=phase_state,
)
agent_dir = tmp_path / "broken_agent"
agent_dir.mkdir()
monkeypatch.setattr(qlt, "validate_agent_path", lambda _path: agent_dir)
result_raw = await registry._tools["load_built_agent"].executor({"agent_path": str(agent_dir)})
result = json.loads(result_raw)
assert "Cannot load agent" in result["error"]
fake_manager.unload_worker.assert_not_called()
fake_manager.load_worker.assert_not_called()
@pytest.mark.asyncio
async def test_run_agent_with_input_blocks_loaded_invalid_worker() -> None:
registry = ToolRegistry()
_register_fake_validator(
registry,
{
"valid": False,
"steps": {
"tool_validation": {
"passed": False,
"output": "Scan Markdown Files missing run_command",
}
},
},
)
runtime = SimpleNamespace(
resume_timers=MagicMock(),
trigger=AsyncMock(),
_get_primary_session_state=MagicMock(return_value={}),
graph=SimpleNamespace(nodes=[]),
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/broken_agent"),
runner=None,
)
phase_state = QueenPhaseState(phase="staging")
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-2",
phase_state=phase_state,
)
result_raw = await registry._tools["run_agent_with_input"].executor({"task": "run it"})
result = json.loads(result_raw)
assert "Cannot run agent" in result["error"]
assert "tool_validation" in result["validation_failures"][0]
runtime.trigger.assert_not_called()
@pytest.mark.asyncio
async def test_run_agent_with_input_uses_structured_entry_inputs(
monkeypatch, tmp_path: Path
) -> None:
registry = ToolRegistry()
_register_fake_validator(registry, {"valid": True, "steps": {}})
monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
monkeypatch.chdir(tmp_path)
runtime = SimpleNamespace(
resume_timers=MagicMock(),
trigger=AsyncMock(return_value="exec-1"),
_get_primary_session_state=MagicMock(return_value={}),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
graph=SimpleNamespace(
nodes=[],
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(
input_keys=["docs_path", "review_path", "word_threshold", "style_rules"]
)
if node_id == "process"
else None
),
),
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/markdown_condense_approver"),
runner=None,
)
phase_state = QueenPhaseState(phase="staging")
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-3",
phase_state=phase_state,
)
result_raw = await registry._tools["run_agent_with_input"].executor(
{
"task": (
"docs_path: docs/ review_path: docs_reviews/ word_threshold: 800 "
"style_rules: Preserve headings, keep links intact."
)
}
)
result = json.loads(result_raw)
assert result["status"] == "started"
runtime.trigger.assert_awaited_once()
trigger_kwargs = runtime.trigger.await_args.kwargs
assert trigger_kwargs["input_data"] == {
"docs_path": str((tmp_path / "docs").resolve()),
"review_path": str((tmp_path / "docs_reviews").resolve()),
"word_threshold": 800,
"style_rules": "Preserve headings, keep links intact.",
}
assert trigger_kwargs["session_state"] is None
runtime._get_primary_session_state.assert_not_called()
@pytest.mark.asyncio
async def test_rerun_worker_with_last_input_reuses_complete_recent_defaults(
monkeypatch, tmp_path: Path
) -> None:
registry = ToolRegistry()
_register_fake_validator(registry, {"valid": True, "steps": {}})
monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
monkeypatch.chdir(tmp_path)
sessions_dir = tmp_path / "agent_store" / "sessions"
sessions_dir.mkdir(parents=True)
valid_prior_state = {
"timestamps": {"updated_at": "2026-03-24T20:44:00"},
"input_data": {
"target_dir": "docs",
"review_dir": "docs_reviews",
"word_threshold": 800,
"feedback": "stale",
},
}
malformed_recent_state = {
"timestamps": {"updated_at": "2026-03-24T21:20:23"},
"input_data": {
"review_dir": "docs_reviews",
"word_threshold": "800. Validate inputs and continue.",
},
}
for session_name, state in {
"session_20260324_204400_good": valid_prior_state,
"session_20260324_212023_bad": malformed_recent_state,
}.items():
session_dir = sessions_dir / session_name
session_dir.mkdir()
(session_dir / "state.json").write_text(json.dumps(state), encoding="utf-8")
runtime = SimpleNamespace(
_session_store=SimpleNamespace(sessions_dir=sessions_dir),
resume_timers=MagicMock(),
trigger=AsyncMock(return_value="exec-rerun"),
graph=SimpleNamespace(
nodes=[],
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["target_dir", "review_dir", "word_threshold"])
if node_id == "process"
else None
),
),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/local_markdown_review_probe_2"),
runner=None,
)
phase_state = QueenPhaseState(phase="staging")
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-rerun",
phase_state=phase_state,
)
result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
result = json.loads(result_raw)
assert result["status"] == "started"
runtime.trigger.assert_awaited_once()
trigger_kwargs = runtime.trigger.await_args.kwargs
assert trigger_kwargs["input_data"] == {
"target_dir": str((tmp_path / "docs").resolve()),
"review_dir": str((tmp_path / "docs_reviews").resolve()),
"word_threshold": 800,
}
assert trigger_kwargs["session_state"] is None
@pytest.mark.asyncio
async def test_rerun_worker_with_last_input_preserves_legacy_task_payload(
monkeypatch, tmp_path: Path
) -> None:
registry = ToolRegistry()
_register_fake_validator(registry, {"valid": True, "steps": {}})
monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
monkeypatch.chdir(tmp_path)
sessions_dir = tmp_path / "agent_store" / "sessions"
sessions_dir.mkdir(parents=True)
session_dir = sessions_dir / "session_20260324_204400_task"
session_dir.mkdir()
(session_dir / "state.json").write_text(
json.dumps(
{
"timestamps": {"updated_at": "2026-03-24T20:44:00"},
"input_data": {"task": "re-run the markdown review"},
}
),
encoding="utf-8",
)
runtime = SimpleNamespace(
_session_store=SimpleNamespace(sessions_dir=sessions_dir),
resume_timers=MagicMock(),
trigger=AsyncMock(return_value="exec-rerun"),
graph=SimpleNamespace(
nodes=[],
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["task"]) if node_id == "process" else None
),
),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/legacy_worker"),
runner=None,
)
phase_state = QueenPhaseState(phase="staging")
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-rerun",
phase_state=phase_state,
)
result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
result = json.loads(result_raw)
assert result["status"] == "started"
trigger_kwargs = runtime.trigger.await_args.kwargs
assert trigger_kwargs["input_data"] == {"task": "re-run the markdown review"}
@pytest.mark.asyncio
async def test_rerun_worker_with_last_input_uses_current_session_defaults_only(
monkeypatch, tmp_path: Path
) -> None:
registry = ToolRegistry()
_register_fake_validator(registry, {"valid": True, "steps": {}})
monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
monkeypatch.chdir(tmp_path)
sessions_dir = tmp_path / "agent_store" / "sessions"
current_session_dir = sessions_dir / "sess-rerun-current"
current_session_dir.mkdir(parents=True)
(current_session_dir / "state.json").write_text(
json.dumps(
{
"timestamps": {"updated_at": "2026-03-24T20:44:00"},
"input_data": {"target_dir": str((tmp_path / "current").resolve())},
}
),
encoding="utf-8",
)
other_session_dir = sessions_dir / "session_20260325_204400_other"
other_session_dir.mkdir(parents=True)
(other_session_dir / "state.json").write_text(
json.dumps(
{
"timestamps": {"updated_at": "2026-03-25T20:44:00"},
"input_data": {
"target_dir": str((tmp_path / "other").resolve()),
"review_dir": str((tmp_path / "other_reviews").resolve()),
},
}
),
encoding="utf-8",
)
runtime = SimpleNamespace(
_session_store=SimpleNamespace(sessions_dir=sessions_dir),
resume_timers=MagicMock(),
trigger=AsyncMock(return_value="exec-rerun"),
graph=SimpleNamespace(
nodes=[],
entry_node="process",
get_node=lambda node_id: (
SimpleNamespace(input_keys=["target_dir", "review_dir"])
if node_id == "process"
else None
),
),
get_entry_points=lambda: [SimpleNamespace(id="default", entry_node="process")],
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/docs_reviewer"),
runner=None,
)
phase_state = QueenPhaseState(phase="staging")
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-rerun-current",
phase_state=phase_state,
)
result_raw = await registry._tools["rerun_worker_with_last_input"].executor({})
result = json.loads(result_raw)
assert (
result["error"]
== "No complete previous worker input is available for a same-defaults rerun."
)
assert result["missing_inputs"] == ["review_dir"]
@pytest.mark.asyncio
async def test_run_agent_with_input_blocks_when_validator_output_is_undecodable(
monkeypatch, tmp_path: Path
) -> None:
registry = ToolRegistry()
registry.register(
"validate_agent_package",
Tool(
name="validate_agent_package",
description="fake validator",
parameters={"type": "object", "properties": {"agent_name": {"type": "string"}}},
),
lambda _inputs: "not-json",
)
monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
monkeypatch.chdir(tmp_path)
runtime = SimpleNamespace(
resume_timers=MagicMock(),
trigger=AsyncMock(return_value="exec-2"),
graph=SimpleNamespace(nodes=[]),
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/docs_sanitizer_agent"),
runner=None,
)
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-invalid-validator",
phase_state=QueenPhaseState(phase="staging"),
)
result_raw = await registry._tools["run_agent_with_input"].executor({"task": "run it"})
result = json.loads(result_raw)
assert "validation is failing" in result["error"]
assert result["validation_failures"] == [
"validator_subprocess: validate_agent_package returned an invalid or undecodable report"
]
runtime.trigger.assert_not_awaited()
@pytest.mark.asyncio
async def test_start_worker_starts_fresh_worker_session(monkeypatch, tmp_path: Path) -> None:
registry = ToolRegistry()
monkeypatch.setattr(qlt, "validate_credentials", lambda *args, **kwargs: None)
monkeypatch.chdir(tmp_path)
runtime = SimpleNamespace(
resume_timers=MagicMock(),
trigger=AsyncMock(return_value="exec-2"),
_get_primary_session_state=MagicMock(return_value={"resume_session_id": "old"}),
graph=SimpleNamespace(nodes=[]),
)
session = SimpleNamespace(
worker_runtime=runtime,
event_bus=None,
worker_path=Path("exports/docs_sanitizer_agent"),
runner=None,
)
register_queen_lifecycle_tools(
registry,
session=session,
session_id="sess-4",
phase_state=QueenPhaseState(phase="staging"),
)
result_raw = await registry._tools["start_worker"].executor(
{"task": "run with docs_path: docs/"}
)
result = json.loads(result_raw)
assert result["status"] == "started"
runtime.trigger.assert_awaited_once()
trigger_kwargs = runtime.trigger.await_args.kwargs
assert trigger_kwargs["session_state"] is None
runtime._get_primary_session_state.assert_not_called()
@@ -1,11 +1,14 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from framework.runtime.event_bus import EventBus
import framework.agents.worker_memory as worker_memory
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.server.session_manager import Session, SessionManager
@@ -123,3 +126,52 @@ async def test_stop_session_unsubscribes_worker_handoff() -> None:
reason="after stop",
)
assert queen_node.inject_event.await_count == 1
@pytest.mark.asyncio
async def test_worker_digest_final_completion_does_not_overwrite_terminal_result(
monkeypatch: pytest.MonkeyPatch,
) -> None:
bus = EventBus()
manager = SessionManager()
session = _make_session(bus, session_id="session_digest_final")
session.worker_path = Path("/tmp/log_triage_agent")
queen_node = SimpleNamespace(inject_event=AsyncMock())
session.queen_executor = _make_executor(queen_node)
consolidate = AsyncMock()
monkeypatch.setattr(worker_memory, "consolidate_worker_run", consolidate)
manager._subscribe_worker_digest(session)
bus.get_history = lambda event_type=None, limit=None: [
AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="default",
execution_id="exec_digest",
run_id="run_digest",
)
]
await bus.publish(
AgentEvent(
type=EventType.EXECUTION_STARTED,
stream_id="default",
execution_id="exec_digest",
run_id="run_digest",
)
)
await bus.publish(
AgentEvent(
type=EventType.EXECUTION_COMPLETED,
stream_id="default",
execution_id="exec_digest",
run_id="run_digest",
data={"output": {"result": "final answer"}},
)
)
await asyncio.sleep(0.05)
consolidate.assert_awaited_once()
assert queen_node.inject_event.await_count == 0
+113
View File
@@ -0,0 +1,113 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
import pytest
import framework.runner as runner_mod
from framework.server import session_manager as sm
@pytest.mark.asyncio
async def test_load_worker_blocks_invalid_package_before_runner_load(monkeypatch) -> None:
manager = sm.SessionManager()
session = sm.Session(
id="sess-1",
event_bus=MagicMock(),
llm=MagicMock(),
loaded_at=0.0,
)
manager._sessions[session.id] = session
captured: dict[str, str] = {}
def _fake_validation(agent_ref):
captured["agent_ref"] = str(agent_ref)
return {
"valid": False,
"steps": {
"behavior_validation": {
"passed": False,
"errors": ["Node 'scan' has a blank or placeholder system_prompt"],
}
},
}
monkeypatch.setattr(
sm,
"_run_validation_report_sync",
_fake_validation,
)
called = {"runner_load": False}
class _FakeAgentRunner:
@staticmethod
def load(*args, **kwargs):
called["runner_load"] = True
raise AssertionError("AgentRunner.load should not run for invalid workers")
monkeypatch.setattr(runner_mod, "AgentRunner", _FakeAgentRunner)
with pytest.raises(sm.WorkerValidationError) as exc:
await manager.load_worker(session.id, Path("/tmp/bad_worker"))
assert "blank or placeholder system_prompt" in str(exc.value)
assert Path(captured["agent_ref"]).as_posix() == "/tmp/bad_worker"
assert called["runner_load"] is False
def test_run_validation_report_sync_uses_internal_validator_impl(monkeypatch) -> None:
captured: dict[str, object] = {}
class _Proc:
returncode = 0
stdout = '{"valid": true, "steps": {}}'
stderr = ""
def _fake_run(cmd, **kwargs):
captured["cmd"] = cmd
return _Proc()
monkeypatch.setattr(sm.subprocess, "run", _fake_run)
report = sm._run_validation_report_sync("/tmp/demo_agent")
assert report["valid"] is True
script = captured["cmd"][4]
assert "_validate_agent_package_impl" in script
assert "validate_agent_package(agent_name)" not in script
assert captured["cmd"][6] == "/tmp/demo_agent"
def test_validation_blocks_stage_or_run_ignores_non_blocking_warnings() -> None:
report = {
"steps": {
"behavior_validation": {
"passed": True,
"warnings": ["placeholder prompt"],
"output": "placeholder prompt",
},
"tests": {
"passed": True,
"warnings": ["1 failed"],
"summary": "1 failed",
},
},
}
assert sm._validation_blocks_stage_or_run(report) is False
def test_run_validation_report_sync_handles_subprocess_launcher_errors(monkeypatch) -> None:
def _boom(*args, **kwargs):
raise FileNotFoundError("uv not found")
monkeypatch.setattr(sm.subprocess, "run", _boom)
report = sm._run_validation_report_sync("/tmp/demo_agent")
assert report["valid"] is False
assert report["steps"]["validator_subprocess"]["passed"] is False
assert "uv not found" in report["steps"]["validator_subprocess"]["error"]
+1 -1
View File
@@ -331,7 +331,7 @@ def _make_codex_provider():
if not api_key or not api_base:
return None
return LiteLLMProvider(
model="openai/gpt-5.3-codex",
model="openai/gpt-5.4",
api_key=api_key,
api_base=api_base,
**extra_kwargs,
+22
View File
@@ -123,6 +123,28 @@ class TestValidateAgentPathPositive:
result = validate_agent_path(str(agent_dir))
assert isinstance(result, Path)
def test_repo_relative_path_resolves_from_repo_root_not_cwd(self, tmp_path, monkeypatch):
import framework.server.app as app_module
repo_root = tmp_path / "repo"
examples_root = repo_root / "examples"
agent_dir = examples_root / "some_agent"
agent_dir.mkdir(parents=True)
other_cwd = tmp_path / "elsewhere"
other_cwd.mkdir()
monkeypatch.setattr(app_module, "_REPO_ROOT", repo_root)
app_module._ALLOWED_AGENT_ROOTS = (
repo_root / "exports",
examples_root,
tmp_path / ".hive" / "agents",
)
monkeypatch.chdir(other_cwd)
result = validate_agent_path("examples/some_agent")
assert result == agent_dir.resolve()
# ---------------------------------------------------------------------------
# validate_agent_path: negative cases (should raise ValueError)
+129
View File
@@ -0,0 +1,129 @@
from __future__ import annotations
import json
import os
import time
from pathlib import Path
import pytest
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.event_bus import EventBus
from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools
def _write_session_logs(
storage_path: Path,
session_id: str,
*,
session_status: str,
steps: list[dict],
) -> Path:
session_dir = storage_path / "sessions" / session_id
logs_dir = session_dir / "logs"
logs_dir.mkdir(parents=True, exist_ok=True)
(session_dir / "state.json").write_text(
json.dumps({"status": session_status}),
encoding="utf-8",
)
log_path = logs_dir / "tool_logs.jsonl"
log_path.write_text(
"".join(json.dumps(step) + "\n" for step in steps),
encoding="utf-8",
)
return log_path
@pytest.mark.asyncio
async def test_worker_health_summary_marks_healthy_runs(tmp_path: Path) -> None:
registry = ToolRegistry()
event_bus = EventBus()
storage_path = tmp_path / "agent_store"
storage_path.mkdir(parents=True, exist_ok=True)
_write_session_logs(
storage_path,
"session-healthy",
session_status="running",
steps=[
{"verdict": "RETRY", "llm_text": "first pass"},
{"verdict": "ACCEPT", "llm_text": "done"},
],
)
register_worker_monitoring_tools(
registry,
event_bus,
storage_path,
default_session_id="session-healthy",
)
raw = await registry._tools["get_worker_health_summary"].executor({})
data = json.loads(raw)
assert data["health_status"] == "healthy"
assert data["issue_signals"] == []
assert data["recent_verdicts"] == ["RETRY", "ACCEPT"]
assert data["steps_since_last_accept"] == 0
@pytest.mark.asyncio
async def test_worker_health_summary_flags_stall_and_non_accept_churn(tmp_path: Path) -> None:
registry = ToolRegistry()
event_bus = EventBus()
storage_path = tmp_path / "agent_store"
storage_path.mkdir(parents=True, exist_ok=True)
log_path = _write_session_logs(
storage_path,
"session-stalled",
session_status="running",
steps=[
{"verdict": "CONTINUE", "llm_text": "thinking"},
{"verdict": "RETRY", "llm_text": "still working"},
{"verdict": "RETRY", "llm_text": "trying again"},
{"verdict": "ESCALATE", "llm_text": "blocked"},
],
)
ten_minutes_ago = time.time() - 600
os.utime(log_path, (ten_minutes_ago, ten_minutes_ago))
register_worker_monitoring_tools(
registry,
event_bus,
storage_path,
default_session_id="session-stalled",
)
raw = await registry._tools["get_worker_health_summary"].executor({})
data = json.loads(raw)
assert data["health_status"] == "critical"
assert "stalled" in data["issue_signals"]
assert "judge_pressure" in data["issue_signals"]
assert "recent_non_accept_churn" in data["issue_signals"]
assert data["steps_since_last_accept"] == 4
assert data["stall_minutes"] is not None
@pytest.mark.asyncio
async def test_worker_health_summary_errors_when_default_session_has_no_state(
tmp_path: Path,
) -> None:
registry = ToolRegistry()
event_bus = EventBus()
storage_path = tmp_path / "agent_store"
(storage_path / "sessions" / "session-stale").mkdir(parents=True, exist_ok=True)
register_worker_monitoring_tools(
registry,
event_bus,
storage_path,
default_session_id="session-stale",
)
raw = await registry._tools["get_worker_health_summary"].executor({})
data = json.loads(raw)
assert "error" in data
assert "session-stale" in data["error"]
+3 -3
View File
@@ -1279,9 +1279,9 @@ switch ($num) {
if ($CodexCredDetected) {
$SubscriptionMode = "codex"
$SelectedProviderId = "openai"
$SelectedModel = "gpt-5.3-codex"
$SelectedMaxTokens = 16384
$SelectedMaxContextTokens = 120000
$SelectedModel = "gpt-5.4"
$SelectedMaxTokens = 128000
$SelectedMaxContextTokens = 900000
Write-Host ""
Write-Ok "Using OpenAI Codex subscription"
}
+3 -3
View File
@@ -1306,9 +1306,9 @@ case $choice in
if [ "$CODEX_CRED_DETECTED" = true ]; then
SUBSCRIPTION_MODE="codex"
SELECTED_PROVIDER_ID="openai"
SELECTED_MODEL="gpt-5.3-codex"
SELECTED_MAX_TOKENS=16384
SELECTED_MAX_CONTEXT_TOKENS=120000 # GPT Codex — 128k context window
SELECTED_MODEL="gpt-5.4"
SELECTED_MAX_TOKENS=128000
SELECTED_MAX_CONTEXT_TOKENS=900000 # GPT-5.4 — 1.05M context window
echo ""
echo -e "${GREEN}${NC} Using OpenAI Codex subscription"
fi
+3 -3
View File
@@ -561,9 +561,9 @@ switch ($num) {
if ($CodexCredDetected) {
$SubscriptionMode = "codex"
$SelectedProviderId = "openai"
$SelectedModel = "gpt-5.3-codex"
$SelectedMaxTokens = 16384
$SelectedMaxContextTokens = 120000
$SelectedModel = "gpt-5.4"
$SelectedMaxTokens = 128000
$SelectedMaxContextTokens = 900000
Write-Host ""
Write-Ok "Using OpenAI Codex subscription"
}
+3 -3
View File
@@ -870,9 +870,9 @@ case $choice in
if [ "$CODEX_CRED_DETECTED" = true ]; then
SUBSCRIPTION_MODE="codex"
SELECTED_PROVIDER_ID="openai"
SELECTED_MODEL="gpt-5.3-codex"
SELECTED_MAX_TOKENS=16384
SELECTED_MAX_CONTEXT_TOKENS=120000 # GPT Codex — 128k context window
SELECTED_MODEL="gpt-5.4"
SELECTED_MAX_TOKENS=128000
SELECTED_MAX_CONTEXT_TOKENS=900000 # GPT-5.4 — 1.05M context window
echo ""
echo -e "${GREEN}${NC} Using OpenAI Codex subscription"
fi
+588 -83
View File
@@ -68,6 +68,15 @@ mcp = FastMCP("coder-tools")
PROJECT_ROOT: str = ""
SNAPSHOT_DIR: str = ""
_PLACEHOLDER_MARKERS = (
"TODO",
"TODO:",
"TODO ",
"Add system prompt for this node",
"Add identity prompt",
"Define success criteria",
"Describe what this node does",
)
# ── Path resolution ───────────────────────────────────────────────────────
@@ -138,6 +147,388 @@ def _resolve_path(path: str) -> str:
return resolved
def _is_placeholder_text(value: str | None) -> bool:
text = (value or "").strip()
if not text:
return True
return any(marker in text for marker in _PLACEHOLDER_MARKERS)
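# Illustrative behavior (substring match, so inline TODO mentions also trip it):
#   _is_placeholder_text("")                        -> True  (blank counts)
#   _is_placeholder_text("TODO: write the prompt")  -> True
#   _is_placeholder_text("Summarize repo findings") -> False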
_ENTRY_INTAKE_HINTS = (
"parse the incoming task",
"parse the task text",
"structured runtime task",
"accept structured runtime task",
"read runtime task input",
"validate runtime",
"validate runtime path",
"validate runtime paths",
"intake & validate",
"configuration values",
"intake config",
"infer values conservatively",
)
_ENTRY_DIRECT_WORK_HINTS = (
"scan",
"scanning",
"discover",
"discovery",
"search",
"fetch",
"analyze",
"analyse",
"transform",
"sanitize",
"summarize",
"summarise",
"generate",
"write",
"apply",
"candidate",
)
_TOOL_ALIAS_HINTS = {
"run_command": "execute_command_tool",
}
_OUTPUT_DIRECTORY_INPUT_HINTS = {
"review_dir",
"output_dir",
"destination_dir",
"dest_dir",
"target_dir",
"review_path",
"output_path",
"destination_path",
"dest_path",
"target_path",
}
_SESSION_DATA_TOOLS = frozenset(
{
"save_data",
"load_data",
"list_data_files",
"append_data",
"edit_data",
"serve_file_to_user",
}
)
_WORKSPACE_PATH_HINTS = frozenset(
{
"review_dir",
"review_root",
"output_dir",
"output_root",
"output_path",
"target_dir",
"target_root",
"target_path",
"workspace",
"project folder",
"project folders",
}
)
_SESSION_DATA_TOOL_PATH_OP_RE = re.compile(
r"(save_data|load_data|list_data_files|append_data|edit_data|serve_file_to_user)"
r"[^.\n]{0,200}\b(?:to|into|in|inside|under|within|from|at|on)\s+(?:the\s+)?"
r"(review_dir|review_root|output_dir|output_root|output_path|target_dir|target_root|"
r"target_path|workspace|project folder|project folders)\b|"
r"\b(?:to|into|in|inside|under|within|from|at|on)\s+(?:the\s+)?"
r"(review_dir|review_root|output_dir|output_root|output_path|target_dir|target_root|"
r"target_path|workspace|project folder|project folders)\b[^.\n]{0,200}"
r"(save_data|load_data|list_data_files|append_data|edit_data|serve_file_to_user)",
re.IGNORECASE,
)
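# Sentences this pattern is meant to flag (illustrative; the two alternation
# branches cover tool-then-path and path-then-tool orderings):
#   "Use save_data to write the report into the output_dir."
#   "From the review_dir, call load_data to read the notes file."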
def _contains_hint_word(text: str, hint: str) -> bool:
"""Return True when *hint* appears as a word/phrase, not just a substring."""
if " " in hint:
return hint in text
return re.search(rf"\b{re.escape(hint)}\b", text) is not None
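# e.g. _contains_hint_word("rescan the repo", "scan") -> False (word boundary),
#      _contains_hint_word("scan the repo", "scan")   -> True; multi-word hints
#      such as "project folder" fall back to plain substring containment.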
def _default_intro_message(human_name: str, description: str) -> str:
"""Return a non-placeholder intro line for generated agents."""
desc = (description or "").strip().rstrip(".")
if desc:
return f"{desc}."
return f"Ready to run {human_name}."
def _default_success_metric(index: int) -> str:
"""Return a generic but non-placeholder success metric name."""
return f"criterion_{index}_satisfied"
def _looks_like_agent_path(agent_ref: str) -> bool:
"""Return True when *agent_ref* should be treated as a filesystem path."""
candidate = Path(agent_ref).expanduser()
return candidate.is_absolute() or len(candidate.parts) > 1 or agent_ref.startswith((".", "~"))
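# e.g. "my_agent" -> False (bare package name), while "exports/my_agent",
#      "./my_agent", and "~/agents/my_agent" -> True (multi-part or ./~ prefix).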
def _resolve_agent_package_target(agent_ref: str) -> tuple[Path, str, str]:
"""Resolve a validator target to (agent_dir, package_name, display_ref).
Bare names still target exports/<name> for build-time validation.
Paths are resolved relative to the repository root and must pass the
server allowlist so existing example agents can be staged safely.
"""
ref = (agent_ref or "").strip()
if not ref:
raise ValueError("Agent reference is required")
if not PROJECT_ROOT:
raise ValueError("PROJECT_ROOT is not configured")
if _looks_like_agent_path(ref):
resolved = Path(_resolve_path(ref))
try:
from framework.server.app import validate_agent_path
except ImportError as exc:
raise ValueError("Cannot validate agent path: framework package not available") from exc
resolved = validate_agent_path(str(resolved))
return resolved, resolved.name, str(resolved)
resolved = (Path(PROJECT_ROOT) / "exports" / ref).resolve()
return resolved, ref, f"exports/{ref}"
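# Illustrative results (absolute paths depend on PROJECT_ROOT):
#   _resolve_agent_package_target("my_agent")
#       -> (<root>/exports/my_agent, "my_agent", "exports/my_agent")
#   _resolve_agent_package_target("examples/templates/my_agent")
#       -> (allowlist-validated absolute path, "my_agent", str(resolved path))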
def _default_success_target() -> str:
"""Return a generic non-placeholder success target."""
return "true"
def _node_can_progress_without_declared_tools(node) -> bool:
"""Return True when a node can legitimately work without MCP/local tools.
Runtime supports two common cases that should not be blocked by static
validation:
- ``gcu`` nodes, whose browser tools are injected by the framework.
- pure LLM work nodes that consume inputs and explicitly write outputs via
``set_output`` without needing external tools.
"""
if getattr(node, "node_type", "") == "gcu":
return True
output_keys = list(getattr(node, "output_keys", []) or [])
if not output_keys:
return False
prompt = (getattr(node, "system_prompt", "") or "").lower()
text = " ".join(
filter(
None,
[
getattr(node, "name", "") or "",
getattr(node, "description", "") or "",
getattr(node, "system_prompt", "") or "",
],
)
).lower()
mentions_set_output = (
"set_output(" in prompt
or "call set_output" in prompt
or "use set_output" in prompt
or "set_output " in prompt
)
looks_like_real_work = any(_contains_hint_word(text, hint) for hint in _ENTRY_DIRECT_WORK_HINTS)
return mentions_set_output and looks_like_real_work
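# A tool-less autonomous node passes only when it both names real work and
# commits to set_output, e.g. (illustrative NodeSpec fields):
#   description="Summarize scan results", output_keys=["summary"],
#   system_prompt="Analyze the findings, then call set_output('summary', ...)"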
def _behavior_validation_errors(agent_module) -> list[str]:
"""Return behavior-level validation errors for a generated agent package."""
errors: list[str] = []
nodes = list(getattr(agent_module, "nodes", []) or [])
terminal_ids = set(getattr(agent_module, "terminal_nodes", []) or [])
entry_node_id = getattr(agent_module, "entry_node", None) or ""
identity_prompt = getattr(agent_module, "identity_prompt", "") or ""
metadata = getattr(agent_module, "metadata", None)
goal = getattr(agent_module, "goal", None)
identity_prompt_text = identity_prompt.strip()
if not identity_prompt_text:
errors.append("identity_prompt is blank")
elif any(marker in identity_prompt_text for marker in _PLACEHOLDER_MARKERS):
errors.append("identity_prompt still contains TODO placeholders")
if metadata is not None:
if _is_placeholder_text(getattr(metadata, "description", "") or ""):
errors.append("metadata.description is blank or still contains TODO placeholders")
if _is_placeholder_text(getattr(metadata, "intro_message", "") or ""):
errors.append("metadata.intro_message is blank or still contains TODO placeholders")
if goal is not None:
if _is_placeholder_text(getattr(goal, "description", "") or ""):
errors.append("goal.description is blank or still contains TODO placeholders")
for criterion in list(getattr(goal, "success_criteria", []) or []):
cid = getattr(criterion, "id", "<unknown>")
for attr in ("description", "metric", "target"):
if _is_placeholder_text(getattr(criterion, attr, "") or ""):
errors.append(f"Success criterion '{cid}' has blank or placeholder {attr}")
for constraint in list(getattr(goal, "constraints", []) or []):
cid = getattr(constraint, "id", "<unknown>")
if _is_placeholder_text(getattr(constraint, "description", "") or ""):
errors.append(f"Constraint '{cid}' has blank or placeholder description")
for node in nodes:
node_id = getattr(node, "id", "<unknown>")
node_desc = getattr(node, "description", "") or ""
if _is_placeholder_text(node_desc):
errors.append(f"Node '{node_id}' has a blank or placeholder description")
prompt = getattr(node, "system_prompt", "") or ""
prompt_lower = prompt.lower()
if _is_placeholder_text(prompt):
errors.append(f"Node '{node_id}' has a blank or placeholder system_prompt")
else:
tools = list(getattr(node, "tools", []) or [])
for tool_name in tools:
if isinstance(tool_name, str) and f"{tool_name}(" in prompt:
errors.append(
f"Node '{node_id}' system_prompt uses callable-style tool syntax for "
f"'{tool_name}'. Describe tool usage in prose instead of "
"Python-style calls."
)
for alias, actual in _TOOL_ALIAS_HINTS.items():
if alias in prompt and actual in tools and alias not in tools:
errors.append(
f"Node '{node_id}' system_prompt references unsupported tool alias "
f"'{alias}'. Use the actual registered tool name '{actual}'."
)
data_tools_used = [tool for tool in tools if tool in _SESSION_DATA_TOOLS]
if data_tools_used:
workspace_path_hints = []
if _SESSION_DATA_TOOL_PATH_OP_RE.search(prompt):
workspace_path_hints = [
hint
for hint in _WORKSPACE_PATH_HINTS
if _contains_hint_word(prompt_lower, hint)
]
if workspace_path_hints:
joined_tools = ", ".join(sorted(data_tools_used))
joined_hints = ", ".join(sorted(workspace_path_hints))
errors.append(
f"Node '{node_id}' uses session data tools ({joined_tools}) as if they "
f"can operate on workspace paths ({joined_hints}). Data tools use the "
"framework-managed session data directory; use execute_command_tool for "
"workspace/review/output directories."
)
success_criteria = getattr(node, "success_criteria", "") or ""
if _is_placeholder_text(success_criteria):
errors.append(f"Node '{node_id}' has blank or placeholder success_criteria")
tools = list(getattr(node, "tools", []) or [])
sub_agents = list(getattr(node, "sub_agents", []) or [])
client_facing = bool(getattr(node, "client_facing", False))
if (
node_id not in terminal_ids
and not client_facing
and not tools
and not sub_agents
and not _node_can_progress_without_declared_tools(node)
):
errors.append(f"Autonomous node '{node_id}' has no tools or sub_agents")
if node_id == entry_node_id:
input_keys = list(getattr(node, "input_keys", []) or [])
output_keys = list(getattr(node, "output_keys", []) or [])
text = " ".join(
filter(
None,
[
getattr(node, "name", "") or "",
node_desc,
prompt,
],
)
).lower()
lowered_input_keys = {str(key).lower() for key in input_keys}
lowered_output_keys = {str(key).lower() for key in output_keys}
generic_task_only = len(input_keys) == 1 and input_keys[0] in {
"task",
"user_request",
"raw",
"input",
"request",
"message",
}
intake_like = any(hint in text for hint in _ENTRY_INTAKE_HINTS)
direct_work_like = any(
_contains_hint_word(text, hint) for hint in _ENTRY_DIRECT_WORK_HINTS
)
pass_through_inputs = bool(lowered_input_keys) and (
lowered_input_keys <= lowered_output_keys
)
runtime_normalization_only = any(
hint in text
for hint in (
"validate",
"validation",
"normalize",
"normalise",
"config",
"configuration",
"runtime",
"path",
"paths",
)
)
if (
intake_like
and not direct_work_like
and (generic_task_only or pass_through_inputs or runtime_normalization_only)
):
errors.append(
f"Entry node '{node_id}' appears to be an intake/config parser. "
"The queen handles intake. Make the first real work node consume "
"structured input_keys directly instead of reparsing a generic task string."
)
for input_key in input_keys:
lowered_key = str(input_key).lower()
if lowered_key not in _OUTPUT_DIRECTORY_INPUT_HINTS:
continue
if (
lowered_key in text
and "exist" in text
and ("directory" in text or "directories" in text)
):
errors.append(
f"Entry node '{node_id}' requires output path '{input_key}' to pre-exist. "
"Output/review directories should be created if missing instead of "
"blocking the run during intake validation."
)
break
return errors
def _classify_behavior_validation_errors(errors: list[str]) -> tuple[list[str], list[str]]:
"""Split behavior validation findings into blocking errors and warnings.
Hard failures are reserved for issues that are likely to break runtime
execution or violate framework contracts. Quality/style issues remain
warnings so they can be surfaced without preventing staging/runs.
"""
blocking_markers = (
"identity_prompt still contains TODO placeholders",
"blank or placeholder system_prompt",
"uses session data tools",
"Autonomous node ",
"appears to be an intake/config parser",
"requires output path",
)
blocking: list[str] = []
warnings: list[str] = []
for error in errors:
if any(marker in error for marker in blocking_markers):
blocking.append(error)
else:
warnings.append(error)
return blocking, warnings
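# e.g. "Node 'work' has a blank or placeholder description" stays a warning,
# while "Autonomous node 'work' has no tools or sub_agents" is blocking via
# the "Autonomous node " marker above.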
# ── Git snapshot system (ported from opencode's shadow git) ───────────────
@@ -752,9 +1143,14 @@ def list_agent_tools(
def _validate_agent_tools_impl(agent_path: str) -> dict:
"""Validate that all tools declared in an agent's nodes exist in its MCP servers.
"""Validate that all tools declared in an agent's nodes exist at runtime.
Returns a dict with validation result: pass/fail, missing tools per node, available tools.
Mirrors runtime tool discovery:
1. MCP tools from ``mcp_servers.json`` (when present)
2. Agent-local ``tools.py`` custom tools (when present)
Returns a dict with validation result: pass/fail, missing tools per node,
available tools, and discovery warnings.
"""
try:
resolved = _resolve_path(agent_path)
@@ -781,11 +1177,8 @@ def _validate_agent_tools_impl(agent_path: str) -> dict:
agent_dir = resolved # Keep path; 'resolved' is reused for MCP config in loop
# --- Discover available tools from agent's MCP servers ---
# --- Discover available tools from MCP + local tools.py ---
mcp_config_path = os.path.join(agent_dir, "mcp_servers.json")
if not os.path.isfile(mcp_config_path):
return {"error": f"No mcp_servers.json found in {agent_path}"}
try:
from pathlib import Path
@@ -796,36 +1189,45 @@ def _validate_agent_tools_impl(agent_path: str) -> dict:
available_tools: set[str] = set()
discovery_errors = []
config_dir = Path(mcp_config_path).parent
try:
with open(mcp_config_path, encoding="utf-8") as f:
servers_config = json.load(f)
except (json.JSONDecodeError, OSError) as e:
return {"error": f"Failed to read mcp_servers.json: {e}"}
for server_name, server_conf in servers_config.items():
resolved = ToolRegistry.resolve_mcp_stdio_config(
{"name": server_name, **server_conf}, config_dir
)
if os.path.isfile(mcp_config_path):
config_dir = Path(mcp_config_path).parent
try:
config = MCPServerConfig(
name=server_name,
transport=resolved.get("transport", "stdio"),
command=resolved.get("command"),
args=resolved.get("args", []),
env=resolved.get("env", {}),
cwd=resolved.get("cwd"),
url=resolved.get("url"),
headers=resolved.get("headers", {}),
with open(mcp_config_path, encoding="utf-8") as f:
servers_config = json.load(f)
except (json.JSONDecodeError, OSError) as e:
return {"error": f"Failed to read mcp_servers.json: {e}"}
for server_name, server_conf in servers_config.items():
resolved = ToolRegistry.resolve_mcp_stdio_config(
{"name": server_name, **server_conf}, config_dir
)
client = MCPClient(config)
client.connect()
for tool in client.list_tools():
available_tools.add(tool.name)
client.disconnect()
try:
config = MCPServerConfig(
name=server_name,
transport=resolved.get("transport", "stdio"),
command=resolved.get("command"),
args=resolved.get("args", []),
env=resolved.get("env", {}),
cwd=resolved.get("cwd"),
url=resolved.get("url"),
headers=resolved.get("headers", {}),
)
client = MCPClient(config)
client.connect()
for tool in client.list_tools():
available_tools.add(tool.name)
client.disconnect()
except Exception as e:
discovery_errors.append({"server": server_name, "error": str(e)})
local_tools_path = Path(agent_dir) / "tools.py"
if local_tools_path.is_file():
try:
registry = ToolRegistry()
registry.discover_from_module(local_tools_path)
available_tools.update(registry.get_tools().keys())
except Exception as e:
discovery_errors.append({"server": server_name, "error": str(e)})
discovery_errors.append({"server": "tools.py", "error": str(e)})
# --- Load agent nodes and extract declared tools ---
agent_py = os.path.join(agent_dir, "agent.py")
@@ -1227,7 +1629,7 @@ def get_agent_checkpoint(
def _run_agent_tests_impl(
agent_name: str,
agent_ref: str,
test_types: str = "all",
fail_fast: bool = False,
) -> dict:
@@ -1235,22 +1637,35 @@ def _run_agent_tests_impl(
Returns a dict with summary counts, per-test results, and failure details.
"""
agent_path = Path(PROJECT_ROOT) / "exports" / agent_name
try:
agent_path, agent_name, display_ref = _resolve_agent_package_target(agent_ref)
except ValueError as e:
return {"error": str(e)}
if not agent_path.is_dir():
# Fall back to framework agents
# Fall back to framework agents for bare framework package names.
agent_path = Path(PROJECT_ROOT) / "core" / "framework" / "agents" / agent_name
tests_dir = agent_path / "tests"
if not agent_path.is_dir():
return {
"error": f"Agent not found: {agent_name}",
"error": f"Agent not found: {agent_ref}",
"hint": "Use list_agents() to see available agents.",
}
if not tests_dir.exists():
return {
"error": f"No tests directory: exports/{agent_name}/tests/",
"hint": "Create test files in the tests/ directory first.",
"agent_name": agent_name,
"agent_path": str(agent_path),
"summary": f"No tests directory: {tests_dir}",
"passed": 0,
"failed": 0,
"skipped": 1,
"errors": 0,
"total": 0,
"test_results": [],
"failures": [],
"skipped_all": True,
}
# Parse test types
@@ -1297,7 +1712,10 @@ def _run_agent_tests_impl(
core_path = os.path.join(PROJECT_ROOT, "core")
exports_path = os.path.join(PROJECT_ROOT, "exports")
fw_agents_path = os.path.join(PROJECT_ROOT, "core", "framework", "agents")
package_parent = str(agent_path.parent)
path_parts = [core_path, exports_path, fw_agents_path, PROJECT_ROOT]
if package_parent not in path_parts:
path_parts.insert(1, package_parent)
if pythonpath:
path_parts.append(pythonpath)
env["PYTHONPATH"] = os.pathsep.join(path_parts)
@@ -1387,6 +1805,7 @@ def _run_agent_tests_impl(
return {
"agent_name": agent_name,
"agent_path": str(agent_path),
"summary": summary_text,
"passed": passed,
"failed": failed,
@@ -1425,28 +1844,50 @@ def run_agent_tests(
# ── Meta-agent: Unified agent validation ───────────────────────────────────
@mcp.tool()
def validate_agent_package(agent_name: str) -> str:
def _validate_agent_package_impl(agent_name: str) -> dict[str, object]:
"""Run structural validation checks on a built agent package in one call.
Executes 5 steps and reports all results (does not stop on first failure):
Executes multiple checks and reports all results (does not stop on first failure):
1. Class validation - checks graph structure and entry_points contract
2. Node completeness - every NodeSpec in nodes/ must be in the nodes list,
and GCU nodes must be referenced in a parent's sub_agents
3. Graph validation - loads the agent graph without credential checks
4. Tool validation - checks declared tools exist in MCP servers
5. Tests - runs the agent's pytest suite
4. Behavior validation - rejects placeholder prompts and empty autonomous nodes
5. Tool validation - checks declared tools exist in MCP servers
6. Tests - runs the agent's pytest suite
Note: Credential validation is intentionally skipped here (building phase).
Credentials are validated at run time by run_agent_with_input() preflight.
Args:
agent_name: Agent package name (e.g. 'my_agent'). Must exist in exports/.
agent_name: Agent package name (e.g. 'my_agent') or an allowed
agent path such as examples/templates/my_agent.
Returns:
JSON with per-step results and overall pass/fail summary
Dict with per-step results and overall pass/fail summary
"""
agent_path = f"exports/{agent_name}"
global PROJECT_ROOT, SNAPSHOT_DIR
if not PROJECT_ROOT:
PROJECT_ROOT = _find_project_root()
if not SNAPSHOT_DIR and PROJECT_ROOT:
SNAPSHOT_DIR = os.path.join(
os.path.expanduser("~"),
".hive",
"snapshots",
os.path.basename(PROJECT_ROOT),
)
try:
agent_dir, package_name, display_ref = _resolve_agent_package_target(agent_name)
except ValueError as e:
return {
"valid": False,
"agent_name": agent_name,
"steps": {"target_resolution": {"passed": False, "error": str(e)}},
"summary": "FAIL: 1 of 1 steps failed (target_resolution)",
}
steps: dict[str, dict] = {}
# Set up env for subprocess calls
@@ -1454,8 +1895,11 @@ def validate_agent_package(agent_name: str) -> str:
core_path = os.path.join(PROJECT_ROOT, "core")
exports_path = os.path.join(PROJECT_ROOT, "exports")
fw_agents_path = os.path.join(PROJECT_ROOT, "core", "framework", "agents")
package_parent = str(agent_dir.parent)
pythonpath = env.get("PYTHONPATH", "")
path_parts = [core_path, exports_path, fw_agents_path, PROJECT_ROOT]
if package_parent not in path_parts:
path_parts.insert(1, package_parent)
if pythonpath:
path_parts.append(pythonpath)
env["PYTHONPATH"] = os.pathsep.join(path_parts)
@@ -1464,22 +1908,22 @@ def validate_agent_package(agent_name: str) -> str:
try:
_contract_script = textwrap.dedent("""\
import importlib, json
mod = importlib.import_module('{agent_name}')
mod = importlib.import_module('{package_name}')
missing = [a for a in ('goal', 'nodes', 'edges') if getattr(mod, a, None) is None]
if missing:
print(json.dumps({{
'valid': False,
'error': (
"Module '{agent_name}' is missing module-level attributes: "
"Module '{package_name}' is missing module-level attributes: "
+ ", ".join(missing) + ". "
"Fix: in {agent_name}/__init__.py, add "
"Fix: in {package_name}/__init__.py, add "
"'from .agent import " + ", ".join(missing) + "' "
"so that 'import {agent_name}' exposes them at package level."
"so that 'import {package_name}' exposes them at package level."
)
}}))
else:
print(json.dumps({{'valid': True}}))
""").format(agent_name=agent_name)
""").format(package_name=package_name)
proc = subprocess.run(
["uv", "run", "python", "-c", _contract_script],
capture_output=True,
@@ -1499,8 +1943,8 @@ def validate_agent_package(agent_name: str) -> str:
steps["module_contract"] = {
"passed": False,
"error": (
f"Failed to import '{agent_name}': {proc.stderr.strip()[:1000]}. "
f"Fix: ensure {agent_name}/__init__.py exists and can be imported "
f"Failed to import '{package_name}': {proc.stderr.strip()[:1000]}. "
f"Fix: ensure {package_name}/__init__.py exists and can be imported "
f"without errors (check syntax, missing dependencies, relative imports)."
),
}
@@ -1515,7 +1959,7 @@ def validate_agent_package(agent_name: str) -> str:
"run",
"python",
"-c",
f"from {agent_name} import default_agent; print(default_agent.validate())",
f"from {package_name} import default_agent; print(default_agent.validate())",
],
capture_output=True,
text=True,
@@ -1562,7 +2006,7 @@ def validate_agent_package(agent_name: str) -> str:
)
print(json.dumps({{'valid': len(errors) == 0, 'errors': errors}}))
""")
check_script = _check_template.format(agent_name=agent_name)
check_script = _check_template.format(agent_name=package_name)
proc = subprocess.run(
["uv", "run", "python", "-c", check_script],
capture_output=True,
@@ -1603,7 +2047,7 @@ def validate_agent_package(agent_name: str) -> str:
"python",
"-c",
f"from framework.runner.runner import AgentRunner; "
f'r = AgentRunner.load("exports/{agent_name}", '
f"r = AgentRunner.load({str(display_ref)!r}, "
f"skip_credential_validation=True); "
f'print("AgentRunner.load (graph-only): OK")',
],
@@ -1624,9 +2068,42 @@ def validate_agent_package(agent_name: str) -> str:
except Exception as e:
steps["graph_validation"] = {"passed": False, "error": str(e)}
# Step B2: Behavior validation — reject placeholder prompts and empty work nodes
try:
import importlib
if package_parent not in sys.path:
sys.path.insert(0, package_parent)
stale = [
name
for name in sys.modules
if name == package_name or name.startswith(f"{package_name}.")
]
for name in stale:
del sys.modules[name]
agent_mod = importlib.import_module(package_name)
behavior_errors = _behavior_validation_errors(agent_mod)
behavior_blockers, behavior_warnings = _classify_behavior_validation_errors(behavior_errors)
steps["behavior_validation"] = {
"passed": len(behavior_blockers) == 0,
"output": (
"No placeholder prompts or empty autonomous nodes detected"
if not behavior_errors
else "; ".join(behavior_blockers or behavior_warnings)
),
}
if behavior_blockers:
steps["behavior_validation"]["errors"] = behavior_blockers
if behavior_warnings:
steps["behavior_validation"]["warnings"] = behavior_warnings
except Exception as e:
steps["behavior_validation"] = {"passed": False, "error": str(e)}
# Step C: Tool validation (direct call)
try:
tool_result = _validate_agent_tools_impl(agent_path)
tool_result = _validate_agent_tools_impl(str(agent_dir))
if "error" in tool_result:
steps["tool_validation"] = {"passed": False, "error": tool_result["error"]}
else:
@@ -1641,17 +2118,33 @@ def validate_agent_package(agent_name: str) -> str:
# Step D: Tests (direct call)
try:
test_result = _run_agent_tests_impl(agent_name)
if "error" in test_result:
steps["tests"] = {"passed": False, "error": test_result["error"]}
test_result = _run_agent_tests_impl(str(agent_dir))
if test_result.get("skipped_all"):
steps["tests"] = {
"passed": True,
"skipped": True,
"summary": test_result.get("summary", "No tests directory found; skipped"),
}
elif "error" in test_result:
steps["tests"] = {
"passed": False,
"warning": test_result["error"],
"warnings": [test_result["error"]],
}
else:
all_passed = test_result.get("failed", 0) == 0 and test_result.get("errors", 0) == 0
steps["tests"] = {
"passed": all_passed,
"summary": test_result.get("summary", "unknown"),
}
if not all_passed and test_result.get("failures"):
steps["tests"]["failures"] = test_result["failures"]
if not all_passed:
warning_summary = (
f"Test suite not fully passing: {test_result.get('summary', 'unknown')}"
)
steps["tests"]["warning"] = warning_summary
steps["tests"]["warnings"] = [warning_summary]
if test_result.get("failures"):
steps["tests"]["failures"] = test_result["failures"]
except Exception as e:
steps["tests"] = {"passed": False, "error": str(e)}
@@ -1665,13 +2158,20 @@ def validate_agent_package(agent_name: str) -> str:
else:
summary = f"FAIL: {len(failed_steps)} of {total} steps failed ({', '.join(failed_steps)})"
return {
"valid": valid,
"agent_name": package_name,
"agent_path": str(agent_dir),
"steps": steps,
"summary": summary,
}
@mcp.tool()
def validate_agent_package(agent_name: str) -> str:
"""Run structural validation checks on a built agent package in one call."""
return json.dumps(
{
"valid": valid,
"agent_name": agent_name,
"steps": steps,
"summary": summary,
},
_validate_agent_package_impl(agent_name),
indent=2,
default=str,
)
@@ -1804,10 +2304,10 @@ default_config = RuntimeConfig()
@dataclass
class AgentMetadata:
name: str = "{human_name}"
name: str = {human_name!r}
version: str = "1.0.0"
description: str = "{_draft_desc or "TODO: Add agent description."}"
intro_message: str = "TODO: Add intro message."
description: str = {(_draft_desc or "TODO: Add agent description.")!r}
intro_message: str = {_default_intro_message(human_name, _draft_desc)!r}
metadata = AgentMetadata()
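The switch to !r conversion above matters once drafted names or descriptions contain quotes: repr() emits a correctly escaped Python literal instead of silently breaking the generated file. A quick illustration (values made up):

>>> human_name = 'Data "Q&A" Helper'
>>> f"name: str = {human_name!r}"
'name: str = \'Data "Q&A" Helper\''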
@@ -1913,8 +2413,8 @@ __all__ = {node_var_names!r}
SuccessCriterion(
id="sc-{i + 1}",
description="{sc}",
metric="TODO",
target="TODO",
metric="{_default_success_metric(i + 1)}",
target="{_default_success_target()}",
weight=1.0,
),"""
for i, sc in enumerate(_draft_sc)
@@ -1924,8 +2424,8 @@ __all__ = {node_var_names!r}
SuccessCriterion(
id="sc-1",
description="TODO: Define success criterion.",
metric="TODO",
target="TODO",
metric="criterion_1_satisfied",
target="true",
weight=1.0,
),"""
@@ -1997,7 +2497,10 @@ pause_nodes = []
terminal_nodes = []
conversation_mode = "continuous"
identity_prompt = "TODO: Add identity prompt."
identity_prompt = (
"You are {human_name}, a focused Hive worker that follows the goal, "
"constraints, and node instructions precisely."
)
loop_config = {{
"max_iterations": 100,
"max_tool_calls_per_turn": 30,
@@ -2063,7 +2566,7 @@ class {class_name}:
name="Default",
entry_node=self.entry_node,
trigger_type="manual",
isolation_level="shared",
isolation_level="isolated",
),
],
llm=llm,
@@ -2347,11 +2850,13 @@ def runner_loaded():
"files": all_file_paths,
"next_steps": [
(
"IMPORTANT: All generated files are structurally complete "
"with correct imports, class definition, validate() method, "
"and __init__.py exports. Use edit_file to customize TODO "
"placeholders — do NOT use write_file to rewrite entire files, "
"as this will break imports and structure."
"IMPORTANT: The generated scaffold has correct imports, class "
"definition, validate() method, and __init__.py exports, but "
"it is NOT ready to load or run yet. Replace every TODO / "
"placeholder prompt and make validation pass before staging. "
"Use edit_file to customize placeholders — do NOT use "
"write_file to rewrite entire files, as this will break "
"imports and structure."
),
(
f"Use edit_file to customize system prompts, tools, "
File diff suppressed because it is too large