fix: mcp registry pipeline stage
This commit is contained in:
@@ -470,7 +470,8 @@ The agent.json must include ALL of these in one write:
|
||||
`input_keys`, `output_keys`, `success_criteria`
|
||||
- `edges` — connecting all nodes with proper conditions
|
||||
- `entry_node`, `terminal_nodes`
|
||||
- `mcp_servers` — reference by name: `[{"name": "hive-tools"}, {"name": "gcu-tools"}]`
|
||||
- `mcp_servers` — REQUIRED. Always include all three: \
|
||||
`[{"name": "hive-tools"}, {"name": "gcu-tools"}, {"name": "files-tools"}]`
|
||||
- `loop_config` — `max_iterations`, `max_context_tokens`
|
||||
|
||||
**Write the COMPLETE config in one `write_file` call. No TODOs, no placeholders.** \
|
||||
|
||||
@@ -351,7 +351,12 @@ class AgentHost:
|
||||
# Start storage
|
||||
await self._storage.start()
|
||||
|
||||
# Create streams for each entry point
|
||||
# Initialize pipeline stages FIRST -- they inject LLM, tools,
|
||||
# credentials, and skills into the host before streams are created.
|
||||
await self._pipeline.initialize_all()
|
||||
self._apply_pipeline_results()
|
||||
|
||||
# Create streams for each entry point (uses pipeline results)
|
||||
for ep_id, spec in self._entry_points.items():
|
||||
stream = ExecutionManager(
|
||||
stream_id=ep_id,
|
||||
@@ -804,9 +809,6 @@ class AgentHost:
|
||||
# Start skill hot-reload watcher (no-op if watchfiles not installed)
|
||||
await self._skills_manager.start_watching()
|
||||
|
||||
# Initialize pipeline stages (one-time setup)
|
||||
await self._pipeline.initialize_all()
|
||||
|
||||
self._running = True
|
||||
self._timers_paused = False
|
||||
n_stages = len(self._pipeline.stages)
|
||||
@@ -899,6 +901,49 @@ class AgentHost:
|
||||
# Primary graph (also stored in self._streams)
|
||||
return self._streams.get(entry_point_id)
|
||||
|
||||
def _apply_pipeline_results(self) -> None:
|
||||
"""Extract tools/LLM/credentials/skills from pipeline stages.
|
||||
|
||||
Called after ``pipeline.initialize_all()`` so stages have finished
|
||||
their async setup (MCP connected, skills discovered, etc.).
|
||||
The host reads stage properties and updates its own state.
|
||||
"""
|
||||
for stage in self._pipeline.stages:
|
||||
stage_name = stage.__class__.__name__
|
||||
|
||||
# McpRegistryStage -> tools
|
||||
if hasattr(stage, "tool_registry") and stage.tool_registry is not None:
|
||||
tools = list(stage.tool_registry.get_tools().values())
|
||||
executor = stage.tool_registry.get_executor()
|
||||
if tools:
|
||||
self._tools = tools
|
||||
self._tool_executor = executor
|
||||
logger.info(
|
||||
"Pipeline injected %d tools from %s",
|
||||
len(tools), stage_name,
|
||||
)
|
||||
|
||||
# LlmProviderStage -> LLM
|
||||
if hasattr(stage, "llm") and stage.llm is not None:
|
||||
if self._llm is None:
|
||||
self._llm = stage.llm
|
||||
logger.info(
|
||||
"Pipeline injected LLM from %s", stage_name,
|
||||
)
|
||||
|
||||
# CredentialResolverStage -> accounts
|
||||
if hasattr(stage, "accounts_prompt") and stage.accounts_prompt:
|
||||
self._accounts_prompt = stage.accounts_prompt
|
||||
self._accounts_data = getattr(stage, "accounts_data", None)
|
||||
self._tool_provider_map = getattr(
|
||||
stage, "tool_provider_map", None,
|
||||
)
|
||||
|
||||
# SkillRegistryStage -> skills manager
|
||||
if hasattr(stage, "skills_manager") and stage.skills_manager is not None:
|
||||
self._skills_manager = stage.skills_manager
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _load_pipeline_from_config():
|
||||
"""Build pipeline from ``~/.hive/configuration.json`` ``pipeline`` key.
|
||||
@@ -1917,6 +1962,7 @@ def create_agent_runtime(
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
pipeline_stages: "list[PipelineStage] | None" = None,
|
||||
) -> AgentHost:
|
||||
"""
|
||||
Create and configure an AgentHost with entry points.
|
||||
@@ -1980,6 +2026,7 @@ def create_agent_runtime(
|
||||
skills_catalog_prompt=skills_catalog_prompt,
|
||||
protocols_prompt=protocols_prompt,
|
||||
skill_dirs=skill_dirs,
|
||||
pipeline_stages=pipeline_stages,
|
||||
)
|
||||
|
||||
for spec in entry_points:
|
||||
|
||||
@@ -1267,82 +1267,7 @@ class AgentLoader:
|
||||
os.environ["HIVE_AGENT_NAME"] = agent_path.name
|
||||
os.environ["HIVE_STORAGE_PATH"] = str(self._storage_path)
|
||||
|
||||
# Load MCP servers: prefer agent.json mcp_servers refs -> global registry
|
||||
# Fallback to mcp_servers.json if it exists (legacy)
|
||||
mcp_config_path = agent_path / "mcp_servers.json"
|
||||
agent_json_path = agent_path / "agent.json"
|
||||
|
||||
logger.info(
|
||||
"MCP loading: agent_json=%s, mcp_json=%s",
|
||||
agent_json_path.exists(),
|
||||
mcp_config_path.exists(),
|
||||
)
|
||||
|
||||
mcp_loaded = False
|
||||
# 1. From agent.json mcp_servers field (resolved via global registry)
|
||||
if agent_json_path.exists():
|
||||
try:
|
||||
import json as _json
|
||||
|
||||
agent_data = _json.loads(agent_json_path.read_text(encoding="utf-8"))
|
||||
server_refs = agent_data.get("mcp_servers", [])
|
||||
if server_refs:
|
||||
names = [ref["name"] for ref in server_refs if ref.get("name")]
|
||||
if names:
|
||||
from framework.loader.mcp_registry import MCPRegistry
|
||||
|
||||
registry = MCPRegistry()
|
||||
configs = registry.resolve_for_agent(include=names)
|
||||
if configs:
|
||||
self._tool_registry.load_registry_servers(configs)
|
||||
mcp_loaded = True
|
||||
logger.info(
|
||||
"Loaded %d MCP servers from registry: %s",
|
||||
len(configs),
|
||||
names,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to load MCP servers from agent.json refs", exc_info=True)
|
||||
|
||||
# 2. Legacy: mcp_servers.json file
|
||||
if not mcp_loaded and mcp_config_path.exists():
|
||||
self._load_mcp_servers_from_config(mcp_config_path)
|
||||
mcp_loaded = True
|
||||
|
||||
# 3. Fallback: load all servers from global registry
|
||||
if not mcp_loaded:
|
||||
try:
|
||||
from framework.loader.mcp_registry import MCPRegistry
|
||||
|
||||
registry = MCPRegistry()
|
||||
configs = registry.resolve_for_agent(profile="all")
|
||||
if configs:
|
||||
self._tool_registry.load_registry_servers(configs)
|
||||
logger.info(
|
||||
"Loaded %d MCP servers from global registry (fallback)",
|
||||
len(configs),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to load MCP servers from global registry", exc_info=True)
|
||||
|
||||
# Auto-discover registry-selected MCP servers from mcp_registry.json
|
||||
self._load_registry_mcp_servers(agent_path)
|
||||
|
||||
# Summary: how many tools were loaded from MCP servers?
|
||||
_all_tools = self._tool_registry.get_tools()
|
||||
logger.info(
|
||||
"MCP loading complete: %d tools from tool registry",
|
||||
len(_all_tools),
|
||||
)
|
||||
if not _all_tools:
|
||||
logger.warning(
|
||||
"ZERO tools loaded from MCP servers. Workers will only have "
|
||||
"synthetic tools (set_output, escalate). Check: "
|
||||
"1) agent.json mcp_servers field, "
|
||||
"2) ~/.hive/mcp_registry/installed.json, "
|
||||
"3) MCP server process startup."
|
||||
)
|
||||
|
||||
# MCP tools are loaded by McpRegistryStage in the pipeline during AgentHost.start()
|
||||
@staticmethod
|
||||
def _import_agent_module(agent_path: Path):
|
||||
"""Import an agent package from its directory path.
|
||||
@@ -1695,228 +1620,148 @@ class AgentLoader:
|
||||
self._approval_callback = callback
|
||||
|
||||
def _setup(self, event_bus=None) -> None:
|
||||
"""Set up runtime, LLM, and executor."""
|
||||
# Configure structured logging (auto-detects JSON vs human-readable)
|
||||
"""Set up runtime via pipeline stages.
|
||||
|
||||
Builds a pipeline with the default stages (LLM, credentials, MCP,
|
||||
skills) and passes it to AgentHost. The stages initialize during
|
||||
``AgentHost.start()`` and inject tools/LLM/credentials/skills.
|
||||
"""
|
||||
from framework.observability import configure_logging
|
||||
from framework.pipeline.stages.credential_resolver import CredentialResolverStage
|
||||
from framework.pipeline.stages.llm_provider import LlmProviderStage
|
||||
from framework.pipeline.stages.mcp_registry import McpRegistryStage
|
||||
from framework.pipeline.stages.skill_registry import SkillRegistryStage
|
||||
from framework.skills.config import SkillsConfig
|
||||
|
||||
configure_logging(level="INFO", format="auto")
|
||||
|
||||
# Set up session context for tools (agent_id)
|
||||
# Set up session context for tools
|
||||
agent_id = self.graph.id or "unknown"
|
||||
self._tool_registry.set_session_context(agent_id=agent_id)
|
||||
|
||||
self._tool_registry.set_session_context(
|
||||
agent_id=agent_id,
|
||||
)
|
||||
# Read MCP server refs from agent.json
|
||||
mcp_refs = []
|
||||
agent_json = self.agent_path / "agent.json"
|
||||
if agent_json.exists():
|
||||
try:
|
||||
import json as _json
|
||||
|
||||
# Create LLM provider
|
||||
# Uses LiteLLM which auto-detects the provider from model name
|
||||
# Skip if already injected (e.g. worker agents with a pre-built LLM)
|
||||
if self._llm is not None:
|
||||
pass # LLM already configured externally
|
||||
elif self.mock_mode:
|
||||
# Use mock LLM for testing without real API calls
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
data = _json.loads(agent_json.read_text(encoding="utf-8"))
|
||||
mcp_refs = data.get("mcp_servers", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._llm = MockLLMProvider(model=self.model)
|
||||
else:
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
# Check if a subscription mode is configured
|
||||
config = get_hive_config()
|
||||
llm_config = config.get("llm", {})
|
||||
use_claude_code = llm_config.get("use_claude_code_subscription", False)
|
||||
use_codex = llm_config.get("use_codex_subscription", False)
|
||||
use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
|
||||
use_antigravity = llm_config.get("use_antigravity_subscription", False)
|
||||
api_base = llm_config.get("api_base")
|
||||
|
||||
api_key = None
|
||||
if use_claude_code:
|
||||
# Get OAuth token from Claude Code subscription
|
||||
api_key = get_claude_code_token()
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
"Claude Code subscription configured but no token found. "
|
||||
"Run 'claude' to authenticate, then try again."
|
||||
)
|
||||
elif use_codex:
|
||||
# Get OAuth token from Codex subscription
|
||||
api_key = get_codex_token()
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
"Codex subscription configured but no token found. "
|
||||
"Run 'codex' to authenticate, then try again."
|
||||
)
|
||||
elif use_kimi_code:
|
||||
# Get API key from Kimi Code CLI config (~/.kimi/config.toml)
|
||||
api_key = get_kimi_code_token()
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
"Kimi Code subscription configured but no key found. "
|
||||
"Run 'kimi /login' to authenticate, then try again."
|
||||
)
|
||||
elif use_antigravity:
|
||||
pass # AntigravityProvider handles credentials internally
|
||||
|
||||
if api_key and use_claude_code:
|
||||
# Use litellm's built-in Anthropic OAuth support.
|
||||
# The lowercase "authorization" key triggers OAuth detection which
|
||||
# adds the required anthropic-beta and browser-access headers.
|
||||
self._llm = LiteLLMProvider(
|
||||
model=self.model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers={"authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
elif api_key and use_codex:
|
||||
# OpenAI Codex subscription routes through the ChatGPT backend
|
||||
# (chatgpt.com/backend-api/codex/responses), NOT the standard
|
||||
# OpenAI API. The consumer OAuth token lacks platform API scopes.
|
||||
extra_headers: dict[str, str] = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"User-Agent": "CodexBar",
|
||||
}
|
||||
account_id = get_codex_account_id()
|
||||
if account_id:
|
||||
extra_headers["ChatGPT-Account-Id"] = account_id
|
||||
self._llm = LiteLLMProvider(
|
||||
model=self.model,
|
||||
api_key=api_key,
|
||||
api_base="https://chatgpt.com/backend-api/codex",
|
||||
extra_headers=extra_headers,
|
||||
store=False,
|
||||
allowed_openai_params=["store"],
|
||||
)
|
||||
elif api_key and use_kimi_code:
|
||||
# Kimi Code subscription uses the Kimi coding API (OpenAI-compatible).
|
||||
# The api_base is set automatically by LiteLLMProvider for kimi/ models.
|
||||
self._llm = LiteLLMProvider(
|
||||
model=self.model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
elif use_antigravity:
|
||||
# Direct OAuth to Google's internal Cloud Code Assist gateway.
|
||||
# No local proxy required — AntigravityProvider handles token
|
||||
# refresh and Gemini-format request/response conversion natively.
|
||||
from framework.llm.antigravity import AntigravityProvider # noqa: PLC0415
|
||||
|
||||
provider = AntigravityProvider(model=self.model)
|
||||
if not provider.has_credentials():
|
||||
print(
|
||||
"Warning: Antigravity credentials not found. "
|
||||
"Run: uv run python core/antigravity_auth.py auth account add"
|
||||
)
|
||||
self._llm = provider
|
||||
else:
|
||||
# Local models (e.g. Ollama) don't need an API key
|
||||
if self._is_local_model(self.model):
|
||||
self._llm = LiteLLMProvider(
|
||||
model=self.model,
|
||||
api_base=api_base,
|
||||
)
|
||||
else:
|
||||
# Fall back to environment variable
|
||||
# First check api_key_env_var from config (set by quickstart)
|
||||
api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var(
|
||||
self.model
|
||||
)
|
||||
if api_key_env and os.environ.get(api_key_env):
|
||||
self._llm = LiteLLMProvider(
|
||||
model=self.model,
|
||||
api_key=os.environ[api_key_env],
|
||||
api_base=api_base,
|
||||
)
|
||||
else:
|
||||
# Fall back to credential store
|
||||
api_key = self._get_api_key_from_credential_store()
|
||||
if api_key:
|
||||
self._llm = LiteLLMProvider(
|
||||
model=self.model, api_key=api_key, api_base=api_base
|
||||
)
|
||||
# Set env var so downstream code (e.g. cleanup LLM in
|
||||
# node._extract_json) can also find it
|
||||
if api_key_env:
|
||||
os.environ[api_key_env] = api_key
|
||||
elif api_key_env:
|
||||
logger.warning(
|
||||
"%s not set. LLM calls will fail. "
|
||||
"Set it with: export %s=your-api-key",
|
||||
api_key_env,
|
||||
api_key_env,
|
||||
)
|
||||
|
||||
# Fail fast if the agent needs an LLM but none was configured
|
||||
if self._llm is None:
|
||||
has_llm_nodes = any(
|
||||
node.node_type in ("event_loop", "gcu") for node in self.graph.nodes
|
||||
)
|
||||
if has_llm_nodes:
|
||||
from framework.credentials.models import CredentialError
|
||||
|
||||
if self._is_local_model(self.model):
|
||||
raise CredentialError(
|
||||
f"Failed to initialize LLM for local model '{self.model}'. "
|
||||
f"Ensure your local LLM server is running "
|
||||
f"(e.g. 'ollama serve' for Ollama)."
|
||||
)
|
||||
api_key_env = self._get_api_key_env_var(self.model)
|
||||
hint = (
|
||||
f"Set it with: export {api_key_env}=your-api-key"
|
||||
if api_key_env
|
||||
else "Configure an API key for your LLM provider."
|
||||
)
|
||||
raise CredentialError(f"LLM API key not found for model '{self.model}'. {hint}")
|
||||
|
||||
# Get tools for runtime
|
||||
# (GCU and file tools are now registered via mcp_servers.json or MCP registry,
|
||||
# not auto-registered here. Agents declare them in agent.json.)
|
||||
tools = list(self._tool_registry.get_tools().values())
|
||||
tool_executor = self._tool_registry.get_executor()
|
||||
|
||||
# Collect connected account info for system prompt injection
|
||||
accounts_prompt = ""
|
||||
accounts_data: list[dict] | None = None
|
||||
tool_provider_map: dict[str, str] | None = None
|
||||
try:
|
||||
from aden_tools.credentials.store_adapter import CredentialStoreAdapter
|
||||
|
||||
if self._credential_store is not None:
|
||||
adapter = CredentialStoreAdapter(store=self._credential_store)
|
||||
else:
|
||||
adapter = CredentialStoreAdapter.default()
|
||||
accounts_data = adapter.get_all_account_info()
|
||||
tool_provider_map = adapter.get_tool_provider_map()
|
||||
if accounts_data:
|
||||
from framework.orchestrator.prompting import build_accounts_prompt
|
||||
|
||||
accounts_prompt = build_accounts_prompt(accounts_data, tool_provider_map)
|
||||
except Exception:
|
||||
pass # Best-effort — agent works without account info
|
||||
|
||||
# Skill configuration — the runtime handles discovery, loading, trust-gating and
|
||||
# prompt rasterization. The runner just builds the config.
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.manager import SkillsManagerConfig
|
||||
|
||||
skills_manager_config = SkillsManagerConfig(
|
||||
skills_config=SkillsConfig.from_agent_vars(
|
||||
default_skills=getattr(self, "_agent_default_skills", None),
|
||||
skills=getattr(self, "_agent_skills", None),
|
||||
# Build default pipeline stages
|
||||
# Default infrastructure stages (always present)
|
||||
pipeline_stages = [
|
||||
LlmProviderStage(
|
||||
model=self.model,
|
||||
mock_mode=self.mock_mode,
|
||||
llm=self._llm,
|
||||
),
|
||||
project_root=self.agent_path,
|
||||
interactive=self._interactive,
|
||||
CredentialResolverStage(
|
||||
credential_store=self._credential_store,
|
||||
),
|
||||
McpRegistryStage(
|
||||
server_refs=mcp_refs,
|
||||
agent_path=self.agent_path,
|
||||
tool_registry=self._tool_registry,
|
||||
),
|
||||
SkillRegistryStage(
|
||||
project_root=self.agent_path,
|
||||
interactive=self._interactive,
|
||||
skills_config=SkillsConfig.from_agent_vars(
|
||||
default_skills=getattr(self, "_agent_default_skills", None),
|
||||
skills=getattr(self, "_agent_skills", None),
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
# Merge user-configured stages from ~/.hive/configuration.json
|
||||
from framework.config import get_hive_config
|
||||
from framework.pipeline.registry import build_pipeline_from_config
|
||||
|
||||
hive_config = get_hive_config()
|
||||
user_stages_config = hive_config.get("pipeline", {}).get("stages", [])
|
||||
if user_stages_config:
|
||||
user_pipeline = build_pipeline_from_config(user_stages_config)
|
||||
pipeline_stages.extend(user_pipeline.stages)
|
||||
|
||||
# Merge agent-level overrides from agent.json pipeline field
|
||||
if agent_json.exists():
|
||||
try:
|
||||
agent_pipeline = (
|
||||
_json.loads(agent_json.read_text(encoding="utf-8"))
|
||||
.get("pipeline", {})
|
||||
.get("stages", [])
|
||||
)
|
||||
if agent_pipeline:
|
||||
agent_stages = build_pipeline_from_config(agent_pipeline)
|
||||
pipeline_stages.extend(agent_stages.stages)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._setup_agent_runtime_with_pipeline(
|
||||
pipeline_stages=pipeline_stages,
|
||||
event_bus=event_bus,
|
||||
)
|
||||
|
||||
self._setup_agent_runtime(
|
||||
tools,
|
||||
tool_executor,
|
||||
accounts_prompt=accounts_prompt,
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
def _setup_agent_runtime_with_pipeline(
|
||||
self,
|
||||
pipeline_stages: list,
|
||||
event_bus=None,
|
||||
) -> None:
|
||||
"""Create AgentHost with pipeline stages."""
|
||||
from framework.host.execution_manager import EntryPointSpec
|
||||
from framework.orchestrator.checkpoint_config import CheckpointConfig
|
||||
|
||||
entry_points = []
|
||||
if self.graph.entry_node:
|
||||
entry_points.append(
|
||||
EntryPointSpec(
|
||||
id="default",
|
||||
name="Default",
|
||||
entry_node=self.graph.entry_node,
|
||||
trigger_type="manual",
|
||||
isolation_level="shared",
|
||||
),
|
||||
)
|
||||
|
||||
log_store = RuntimeLogStore(
|
||||
base_path=self._storage_path / "runtime_logs",
|
||||
)
|
||||
checkpoint_config = CheckpointConfig(
|
||||
enabled=True,
|
||||
checkpoint_on_node_start=False,
|
||||
checkpoint_on_node_complete=True,
|
||||
checkpoint_max_age_days=7,
|
||||
async_checkpoint=True,
|
||||
)
|
||||
|
||||
runtime_config = None
|
||||
if self.runtime_config is not None:
|
||||
from framework.host.agent_host import AgentRuntimeConfig
|
||||
|
||||
if isinstance(self.runtime_config, AgentRuntimeConfig):
|
||||
runtime_config = self.runtime_config
|
||||
|
||||
|
||||
self._agent_runtime = create_agent_runtime(
|
||||
graph=self.graph,
|
||||
goal=self.goal,
|
||||
storage_path=self._storage_path,
|
||||
entry_points=entry_points,
|
||||
llm=None, # Injected by LlmProviderStage
|
||||
tools=[], # Injected by McpRegistryStage
|
||||
tool_executor=None, # Injected by McpRegistryStage
|
||||
runtime_log_store=log_store,
|
||||
checkpoint_config=checkpoint_config,
|
||||
config=runtime_config,
|
||||
graph_id=self.graph.id or self.agent_path.name,
|
||||
event_bus=event_bus,
|
||||
pipeline_stages=pipeline_stages,
|
||||
)
|
||||
self._agent_runtime.intro_message = self.intro_message
|
||||
|
||||
def _get_api_key_env_var(self, model: str) -> str | None:
|
||||
"""Get the environment variable name for the API key based on model name."""
|
||||
@@ -2009,83 +1854,6 @@ class AgentLoader:
|
||||
)
|
||||
return model.lower().startswith(LOCAL_PREFIXES)
|
||||
|
||||
def _setup_agent_runtime(
|
||||
self,
|
||||
tools: list,
|
||||
tool_executor: Callable | None,
|
||||
accounts_prompt: str = "",
|
||||
accounts_data: list[dict] | None = None,
|
||||
tool_provider_map: dict[str, str] | None = None,
|
||||
event_bus=None,
|
||||
skills_catalog_prompt: str = "",
|
||||
protocols_prompt: str = "",
|
||||
skill_dirs: list[str] | None = None,
|
||||
skills_manager_config=None,
|
||||
) -> None:
|
||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
||||
entry_points = []
|
||||
|
||||
# Always create a primary entry point for the graph's entry node.
|
||||
# For multi-entry-point agents this ensures the primary path (e.g.
|
||||
# user-facing rule setup) is reachable alongside async entry points.
|
||||
if self.graph.entry_node:
|
||||
entry_points.insert(
|
||||
0,
|
||||
EntryPointSpec(
|
||||
id="default",
|
||||
name="Default",
|
||||
entry_node=self.graph.entry_node,
|
||||
trigger_type="manual",
|
||||
isolation_level="shared",
|
||||
),
|
||||
)
|
||||
|
||||
# Create AgentRuntime with all entry points
|
||||
log_store = RuntimeLogStore(base_path=self._storage_path / "runtime_logs")
|
||||
|
||||
# Enable checkpointing by default for resumable sessions
|
||||
from framework.orchestrator.checkpoint_config import CheckpointConfig
|
||||
|
||||
checkpoint_config = CheckpointConfig(
|
||||
enabled=True,
|
||||
checkpoint_on_node_start=False, # Only checkpoint after nodes complete
|
||||
checkpoint_on_node_complete=True,
|
||||
checkpoint_max_age_days=7,
|
||||
async_checkpoint=True, # Non-blocking
|
||||
)
|
||||
|
||||
# Handle runtime_config - only pass through if it's actually an AgentRuntimeConfig.
|
||||
# Agents may export a RuntimeConfig (LLM settings) or queen-generated custom classes
|
||||
# that would crash AgentRuntime if passed through.
|
||||
runtime_config = None
|
||||
if self.runtime_config is not None:
|
||||
from framework.host.agent_host import AgentRuntimeConfig
|
||||
|
||||
if isinstance(self.runtime_config, AgentRuntimeConfig):
|
||||
runtime_config = self.runtime_config
|
||||
|
||||
self._agent_runtime = create_agent_runtime(
|
||||
graph=self.graph,
|
||||
goal=self.goal,
|
||||
storage_path=self._storage_path,
|
||||
entry_points=entry_points,
|
||||
llm=self._llm,
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
runtime_log_store=log_store,
|
||||
checkpoint_config=checkpoint_config,
|
||||
config=runtime_config,
|
||||
graph_id=self.graph.id or self.agent_path.name,
|
||||
accounts_prompt=accounts_prompt,
|
||||
accounts_data=accounts_data,
|
||||
tool_provider_map=tool_provider_map,
|
||||
event_bus=event_bus,
|
||||
skills_manager_config=skills_manager_config,
|
||||
)
|
||||
|
||||
# Pass intro_message through for TUI display
|
||||
self._agent_runtime.intro_message = self.intro_message
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Execution modes
|
||||
#
|
||||
|
||||
@@ -262,15 +262,21 @@ class ToolRegistry:
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
registry_ref = self
|
||||
|
||||
def executor(tool_use: ToolUse) -> ToolResult:
|
||||
if tool_use.name not in self._tools:
|
||||
# Check if credential files changed (lightweight dir listing).
|
||||
# If new OAuth tokens appeared, restarts MCP servers to pick them up.
|
||||
registry_ref.resync_mcp_servers_if_needed()
|
||||
|
||||
if tool_use.name not in registry_ref._tools:
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id,
|
||||
content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}),
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
registered = self._tools[tool_use.name]
|
||||
registered = registry_ref._tools[tool_use.name]
|
||||
try:
|
||||
result = registered.executor(tool_use.input)
|
||||
|
||||
@@ -922,6 +928,11 @@ class ToolRegistry:
|
||||
clients and re-loads them so the new subprocess picks up the fresh
|
||||
credentials.
|
||||
|
||||
Note: Individual credential TTL/refresh is handled by the MCP server
|
||||
process internally -- it resolves tokens from the credential store
|
||||
on every tool call, not at startup. This method only handles the case
|
||||
where entirely new credential files appear.
|
||||
|
||||
Returns True if a resync was performed, False otherwise.
|
||||
"""
|
||||
if not self._mcp_clients or self._mcp_config_path is None:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Credential resolver pipeline stage.
|
||||
|
||||
Resolves connected accounts from the credential store and builds
|
||||
the ``accounts_prompt`` and ``tool_provider_map`` for system prompt
|
||||
injection. Replaces the credential resolution block in
|
||||
``AgentLoader._setup()`` (lines 1861-1879).
|
||||
Resolves connected accounts at startup. Individual credential TTL/refresh
|
||||
is handled by MCP server processes internally -- they resolve tokens from
|
||||
the credential store on every tool call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -19,37 +18,41 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@register("credential_resolver")
|
||||
class CredentialResolverStage(PipelineStage):
|
||||
"""Resolve connected accounts and inject into pipeline context."""
|
||||
"""Resolve connected accounts for system prompt injection."""
|
||||
|
||||
order = 40 # before MCP (tools need account info for routing)
|
||||
order = 40
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self._accounts_prompt = ""
|
||||
self._accounts_data: list[dict] | None = None
|
||||
self._tool_provider_map: dict[str, str] | None = None
|
||||
def __init__(self, credential_store: Any = None, **kwargs: Any) -> None:
|
||||
self._credential_store = credential_store
|
||||
self.accounts_prompt = ""
|
||||
self.accounts_data: list[dict] | None = None
|
||||
self.tool_provider_map: dict[str, str] | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Resolve credentials from the store."""
|
||||
try:
|
||||
from aden_tools.credentials.store_adapter import (
|
||||
CredentialStoreAdapter,
|
||||
)
|
||||
from framework.orchestrator.prompting import build_accounts_prompt
|
||||
|
||||
adapter = CredentialStoreAdapter.default()
|
||||
self._accounts_data = adapter.get_all_account_info()
|
||||
self._tool_provider_map = adapter.get_tool_provider_map()
|
||||
if self._accounts_data:
|
||||
self._accounts_prompt = build_accounts_prompt(
|
||||
self._accounts_data,
|
||||
self._tool_provider_map,
|
||||
if self._credential_store is not None:
|
||||
adapter = CredentialStoreAdapter(store=self._credential_store)
|
||||
else:
|
||||
adapter = CredentialStoreAdapter.default()
|
||||
self.accounts_data = adapter.get_all_account_info()
|
||||
self.tool_provider_map = adapter.get_tool_provider_map()
|
||||
if self.accounts_data:
|
||||
self.accounts_prompt = build_accounts_prompt(
|
||||
self.accounts_data, self.tool_provider_map,
|
||||
)
|
||||
logger.info(
|
||||
"[pipeline] CredentialResolverStage: %d accounts",
|
||||
len(self.accounts_data or []),
|
||||
)
|
||||
except Exception:
|
||||
pass # best-effort -- agent works without account info
|
||||
logger.debug(
|
||||
"Credential resolution failed (non-fatal)", exc_info=True,
|
||||
)
|
||||
|
||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""Inject credential info into pipeline context."""
|
||||
ctx.metadata["accounts_prompt"] = self._accounts_prompt
|
||||
ctx.metadata["accounts_data"] = self._accounts_data
|
||||
ctx.metadata["tool_provider_map"] = self._tool_provider_map
|
||||
return PipelineResult(action="continue")
|
||||
|
||||
@@ -1,17 +1,7 @@
|
||||
"""LLM provider pipeline stage.
|
||||
|
||||
Resolves the LLM provider (model, API key, OAuth token) from the
|
||||
global config and injects it into the pipeline context. Replaces
|
||||
the 150-line provider resolution block in ``AgentLoader._setup()``.
|
||||
|
||||
Supports all auth methods:
|
||||
- Claude Code subscription (OAuth token from ~/.claude/.credentials.json)
|
||||
- Codex subscription (Keychain / ~/.codex/auth.json)
|
||||
- Kimi Code subscription (~/.kimi/config.toml)
|
||||
- Antigravity (Google Cloud Code Assist OAuth)
|
||||
- Environment variable (ANTHROPIC_API_KEY, OPENAI_API_KEY, etc.)
|
||||
- Key pool (multiple keys with rotation)
|
||||
- Local models (Ollama, no key needed)
|
||||
Resolves the LLM provider from global config. This is the ONLY place
|
||||
the LLM gets created for worker agents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -27,84 +17,79 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@register("llm_provider")
|
||||
class LlmProviderStage(PipelineStage):
|
||||
"""Resolve LLM provider and inject into pipeline context."""
|
||||
"""Resolve LLM provider and make it available."""
|
||||
|
||||
order = 10 # earliest -- everything else depends on having an LLM
|
||||
order = 10
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str | None = None,
|
||||
mock_mode: bool = False,
|
||||
llm: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._model = model
|
||||
self._mock_mode = mock_mode
|
||||
self._llm: Any = None
|
||||
self.llm = llm # Pre-injected LLM (e.g. from session)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Resolve and create the LLM provider."""
|
||||
from framework.config import get_api_key, get_api_keys, get_hive_config, get_preferred_model
|
||||
if self.llm is not None:
|
||||
return # Already injected
|
||||
|
||||
from framework.config import (
|
||||
get_api_key,
|
||||
get_api_keys,
|
||||
get_hive_config,
|
||||
get_preferred_model,
|
||||
)
|
||||
|
||||
model = self._model or get_preferred_model()
|
||||
|
||||
if self._mock_mode:
|
||||
from framework.llm.mock import MockLLMProvider
|
||||
|
||||
self._llm = MockLLMProvider(model=model)
|
||||
self.llm = MockLLMProvider(model=model)
|
||||
return
|
||||
|
||||
config = get_hive_config()
|
||||
llm_config = config.get("llm", {})
|
||||
api_base = llm_config.get("api_base")
|
||||
api_key = get_api_key()
|
||||
api_keys = get_api_keys()
|
||||
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
# Check for Antigravity (special provider, not LiteLLM)
|
||||
# Check for Antigravity (special provider)
|
||||
if llm_config.get("use_antigravity_subscription"):
|
||||
try:
|
||||
from framework.llm.antigravity import AntigravityProvider
|
||||
|
||||
provider = AntigravityProvider(model=model)
|
||||
if provider.has_credentials():
|
||||
self._llm = provider
|
||||
self.llm = provider
|
||||
logger.info("[pipeline] LlmProviderStage: Antigravity")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Key pool or single key
|
||||
from framework.llm.litellm import LiteLLMProvider
|
||||
|
||||
api_key = get_api_key()
|
||||
api_keys = get_api_keys()
|
||||
|
||||
if api_keys and len(api_keys) > 1:
|
||||
self._llm = LiteLLMProvider(
|
||||
model=model,
|
||||
api_keys=api_keys,
|
||||
api_base=api_base,
|
||||
self.llm = LiteLLMProvider(
|
||||
model=model, api_keys=api_keys, api_base=api_base,
|
||||
)
|
||||
elif api_key:
|
||||
# Detect OAuth subscriptions for special headers
|
||||
is_claude_oauth = api_key.startswith("sk-ant-oat")
|
||||
if is_claude_oauth:
|
||||
self._llm = LiteLLMProvider(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers={"authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
else:
|
||||
self._llm = LiteLLMProvider(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
)
|
||||
else:
|
||||
# No key -- local models or env var fallback
|
||||
self._llm = LiteLLMProvider(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
extra = {}
|
||||
if api_key.startswith("sk-ant-oat"):
|
||||
extra["extra_headers"] = {
|
||||
"authorization": f"Bearer {api_key}"
|
||||
}
|
||||
self.llm = LiteLLMProvider(
|
||||
model=model, api_key=api_key, api_base=api_base, **extra,
|
||||
)
|
||||
else:
|
||||
self.llm = LiteLLMProvider(model=model, api_base=api_base)
|
||||
|
||||
logger.info("[pipeline] LlmProviderStage: %s", model)
|
||||
|
||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""Inject LLM provider into pipeline context."""
|
||||
if self._llm:
|
||||
ctx.metadata["llm"] = self._llm
|
||||
return PipelineResult(action="continue")
|
||||
|
||||
@@ -1,21 +1,13 @@
|
||||
"""MCP registry pipeline stage.
|
||||
|
||||
Resolves MCP server references from the agent config against the global
|
||||
registry (``~/.hive/mcp_registry/installed.json``) and registers tools.
|
||||
Replaces the per-agent ``mcp_servers.json`` pattern with declarative
|
||||
name-based references.
|
||||
|
||||
Agent config declares servers by name::
|
||||
|
||||
{"mcp_servers": [{"name": "hive-tools"}, {"name": "gcu-tools"}]}
|
||||
|
||||
The stage resolves each name from the global registry at ``initialize()``
|
||||
time and injects the resolved ``ToolRegistry`` into the pipeline context.
|
||||
registry and registers tools. This is the ONLY place MCP tools get loaded.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -27,12 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@register("mcp_registry")
|
||||
class McpRegistryStage(PipelineStage):
|
||||
"""Resolve MCP tools from the global registry.
|
||||
|
||||
On ``initialize()``, connects to MCP servers declared in the agent
|
||||
config. On ``process()``, injects ``tools`` and ``tool_executor``
|
||||
into the pipeline context metadata for downstream consumption.
|
||||
"""
|
||||
"""Resolve MCP tools from the global registry."""
|
||||
|
||||
order = 50
|
||||
|
||||
@@ -40,55 +27,66 @@ class McpRegistryStage(PipelineStage):
|
||||
self,
|
||||
server_refs: list[dict[str, Any]] | None = None,
|
||||
agent_path: str | Path | None = None,
|
||||
tool_registry: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
server_refs: List of ``{"name": "server-name"}`` dicts from
|
||||
the agent config's ``mcp_servers`` field.
|
||||
agent_path: Path to the agent directory. If a ``mcp_servers.json``
|
||||
file exists there, it's loaded as a fallback.
|
||||
"""
|
||||
self._server_refs = server_refs or []
|
||||
self._agent_path = Path(agent_path) if agent_path else None
|
||||
self._tool_registry: Any = None
|
||||
self._tool_registry = tool_registry
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Connect to MCP servers and discover tools."""
|
||||
if self._tool_registry is None:
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
|
||||
self._tool_registry = ToolRegistry()
|
||||
|
||||
from framework.loader.mcp_registry import MCPRegistry
|
||||
from framework.loader.tool_registry import ToolRegistry
|
||||
|
||||
self._tool_registry = ToolRegistry()
|
||||
registry = MCPRegistry()
|
||||
mcp_loaded = False
|
||||
|
||||
# 1. Resolve named server refs from global registry
|
||||
# 1. From agent.json mcp_servers refs
|
||||
if self._server_refs:
|
||||
names = [ref["name"] for ref in self._server_refs if ref.get("name")]
|
||||
if names:
|
||||
configs = registry.resolve_for_agent(include=names)
|
||||
if configs:
|
||||
self._tool_registry.load_registry_servers(configs)
|
||||
self._tool_registry.load_registry_servers(
|
||||
[asdict(c) for c in configs]
|
||||
)
|
||||
mcp_loaded = True
|
||||
logger.info(
|
||||
"McpRegistryStage: resolved %d servers from registry",
|
||||
"[pipeline] McpRegistryStage: loaded %d servers: %s",
|
||||
len(configs),
|
||||
names,
|
||||
)
|
||||
|
||||
# 2. Fallback: load mcp_servers.json if it exists (backward compat)
|
||||
if self._agent_path:
|
||||
# 2. Legacy: mcp_servers.json
|
||||
if not mcp_loaded and self._agent_path:
|
||||
mcp_json = self._agent_path / "mcp_servers.json"
|
||||
if mcp_json.exists():
|
||||
self._tool_registry.load_mcp_config(mcp_json)
|
||||
mcp_loaded = True
|
||||
|
||||
# 3. Fallback: all servers from global registry
|
||||
if not mcp_loaded:
|
||||
configs = registry.resolve_for_agent(profile="all")
|
||||
if configs:
|
||||
self._tool_registry.load_registry_servers(
|
||||
[asdict(c) for c in configs]
|
||||
)
|
||||
logger.info(
|
||||
"McpRegistryStage: loaded mcp_servers.json from %s",
|
||||
self._agent_path.name,
|
||||
"[pipeline] McpRegistryStage: loaded %d servers (fallback)",
|
||||
len(configs),
|
||||
)
|
||||
|
||||
total = len(self._tool_registry.get_tools())
|
||||
logger.info("[pipeline] McpRegistryStage: %d tools available", total)
|
||||
|
||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""Inject resolved tools into pipeline context."""
|
||||
if self._tool_registry:
|
||||
ctx.metadata["tool_registry"] = self._tool_registry
|
||||
ctx.metadata["tools"] = list(
|
||||
self._tool_registry.get_tools().values()
|
||||
)
|
||||
ctx.metadata["tool_executor"] = self._tool_registry.get_executor()
|
||||
return PipelineResult(action="continue")
|
||||
|
||||
@property
|
||||
def tool_registry(self):
|
||||
return self._tool_registry
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
"""Skill registry pipeline stage.
|
||||
|
||||
Discovers and loads skills, injects skill prompts into the pipeline
|
||||
context. Replaces the standalone ``SkillsManager`` initialization
|
||||
in ``AgentHost.__init__()``.
|
||||
|
||||
Supports hot-reload: when ``SKILL.md`` files change on disk, the
|
||||
cached prompts are rebuilt and the next pipeline execution picks
|
||||
up the new values.
|
||||
Discovers and loads skills. This is the ONLY place skills get loaded.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -23,7 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@register("skill_registry")
|
||||
class SkillRegistryStage(PipelineStage):
|
||||
"""Discover skills and inject prompts into pipeline context."""
|
||||
"""Discover skills and provide prompts."""
|
||||
|
||||
order = 60
|
||||
|
||||
@@ -31,34 +25,31 @@ class SkillRegistryStage(PipelineStage):
|
||||
self,
|
||||
project_root: str | Path | None = None,
|
||||
interactive: bool = True,
|
||||
skills_config: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._project_root = Path(project_root) if project_root else None
|
||||
self._interactive = interactive
|
||||
self._skills_manager: Any = None
|
||||
self._skills_config = skills_config
|
||||
self.skills_manager: Any = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Discover skills and start hot-reload watcher."""
|
||||
from framework.skills.config import SkillsConfig
|
||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||
|
||||
config = SkillsManagerConfig(
|
||||
skills_config=self._skills_config or SkillsConfig(),
|
||||
project_root=self._project_root,
|
||||
interactive=self._interactive,
|
||||
)
|
||||
self._skills_manager = SkillsManager(config)
|
||||
self._skills_manager.load()
|
||||
await self._skills_manager.start_watching()
|
||||
self.skills_manager = SkillsManager(config)
|
||||
self.skills_manager.load()
|
||||
await self.skills_manager.start_watching()
|
||||
logger.info(
|
||||
"[pipeline] SkillRegistryStage: catalog=%d chars, protocols=%d chars",
|
||||
len(self.skills_manager.skills_catalog_prompt),
|
||||
len(self.skills_manager.protocols_prompt),
|
||||
)
|
||||
|
||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||
"""Inject skill prompts into pipeline context."""
|
||||
if self._skills_manager:
|
||||
ctx.metadata["skills_catalog_prompt"] = (
|
||||
self._skills_manager.skills_catalog_prompt
|
||||
)
|
||||
ctx.metadata["protocols_prompt"] = (
|
||||
self._skills_manager.protocols_prompt
|
||||
)
|
||||
ctx.metadata["skill_dirs"] = (
|
||||
self._skills_manager.allowlisted_dirs
|
||||
)
|
||||
return PipelineResult(action="continue")
|
||||
|
||||
Reference in New Issue
Block a user