fix: mcp registry pipeline stage
This commit is contained in:
@@ -470,7 +470,8 @@ The agent.json must include ALL of these in one write:
|
|||||||
`input_keys`, `output_keys`, `success_criteria`
|
`input_keys`, `output_keys`, `success_criteria`
|
||||||
- `edges` — connecting all nodes with proper conditions
|
- `edges` — connecting all nodes with proper conditions
|
||||||
- `entry_node`, `terminal_nodes`
|
- `entry_node`, `terminal_nodes`
|
||||||
- `mcp_servers` — reference by name: `[{"name": "hive-tools"}, {"name": "gcu-tools"}]`
|
- `mcp_servers` — REQUIRED. Always include all three: \
|
||||||
|
`[{"name": "hive-tools"}, {"name": "gcu-tools"}, {"name": "files-tools"}]`
|
||||||
- `loop_config` — `max_iterations`, `max_context_tokens`
|
- `loop_config` — `max_iterations`, `max_context_tokens`
|
||||||
|
|
||||||
**Write the COMPLETE config in one `write_file` call. No TODOs, no placeholders.** \
|
**Write the COMPLETE config in one `write_file` call. No TODOs, no placeholders.** \
|
||||||
|
|||||||
@@ -351,7 +351,12 @@ class AgentHost:
|
|||||||
# Start storage
|
# Start storage
|
||||||
await self._storage.start()
|
await self._storage.start()
|
||||||
|
|
||||||
# Create streams for each entry point
|
# Initialize pipeline stages FIRST -- they inject LLM, tools,
|
||||||
|
# credentials, and skills into the host before streams are created.
|
||||||
|
await self._pipeline.initialize_all()
|
||||||
|
self._apply_pipeline_results()
|
||||||
|
|
||||||
|
# Create streams for each entry point (uses pipeline results)
|
||||||
for ep_id, spec in self._entry_points.items():
|
for ep_id, spec in self._entry_points.items():
|
||||||
stream = ExecutionManager(
|
stream = ExecutionManager(
|
||||||
stream_id=ep_id,
|
stream_id=ep_id,
|
||||||
@@ -804,9 +809,6 @@ class AgentHost:
|
|||||||
# Start skill hot-reload watcher (no-op if watchfiles not installed)
|
# Start skill hot-reload watcher (no-op if watchfiles not installed)
|
||||||
await self._skills_manager.start_watching()
|
await self._skills_manager.start_watching()
|
||||||
|
|
||||||
# Initialize pipeline stages (one-time setup)
|
|
||||||
await self._pipeline.initialize_all()
|
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._timers_paused = False
|
self._timers_paused = False
|
||||||
n_stages = len(self._pipeline.stages)
|
n_stages = len(self._pipeline.stages)
|
||||||
@@ -899,6 +901,49 @@ class AgentHost:
|
|||||||
# Primary graph (also stored in self._streams)
|
# Primary graph (also stored in self._streams)
|
||||||
return self._streams.get(entry_point_id)
|
return self._streams.get(entry_point_id)
|
||||||
|
|
||||||
|
def _apply_pipeline_results(self) -> None:
|
||||||
|
"""Extract tools/LLM/credentials/skills from pipeline stages.
|
||||||
|
|
||||||
|
Called after ``pipeline.initialize_all()`` so stages have finished
|
||||||
|
their async setup (MCP connected, skills discovered, etc.).
|
||||||
|
The host reads stage properties and updates its own state.
|
||||||
|
"""
|
||||||
|
for stage in self._pipeline.stages:
|
||||||
|
stage_name = stage.__class__.__name__
|
||||||
|
|
||||||
|
# McpRegistryStage -> tools
|
||||||
|
if hasattr(stage, "tool_registry") and stage.tool_registry is not None:
|
||||||
|
tools = list(stage.tool_registry.get_tools().values())
|
||||||
|
executor = stage.tool_registry.get_executor()
|
||||||
|
if tools:
|
||||||
|
self._tools = tools
|
||||||
|
self._tool_executor = executor
|
||||||
|
logger.info(
|
||||||
|
"Pipeline injected %d tools from %s",
|
||||||
|
len(tools), stage_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# LlmProviderStage -> LLM
|
||||||
|
if hasattr(stage, "llm") and stage.llm is not None:
|
||||||
|
if self._llm is None:
|
||||||
|
self._llm = stage.llm
|
||||||
|
logger.info(
|
||||||
|
"Pipeline injected LLM from %s", stage_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# CredentialResolverStage -> accounts
|
||||||
|
if hasattr(stage, "accounts_prompt") and stage.accounts_prompt:
|
||||||
|
self._accounts_prompt = stage.accounts_prompt
|
||||||
|
self._accounts_data = getattr(stage, "accounts_data", None)
|
||||||
|
self._tool_provider_map = getattr(
|
||||||
|
stage, "tool_provider_map", None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# SkillRegistryStage -> skills manager
|
||||||
|
if hasattr(stage, "skills_manager") and stage.skills_manager is not None:
|
||||||
|
self._skills_manager = stage.skills_manager
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_pipeline_from_config():
|
def _load_pipeline_from_config():
|
||||||
"""Build pipeline from ``~/.hive/configuration.json`` ``pipeline`` key.
|
"""Build pipeline from ``~/.hive/configuration.json`` ``pipeline`` key.
|
||||||
@@ -1917,6 +1962,7 @@ def create_agent_runtime(
|
|||||||
skills_catalog_prompt: str = "",
|
skills_catalog_prompt: str = "",
|
||||||
protocols_prompt: str = "",
|
protocols_prompt: str = "",
|
||||||
skill_dirs: list[str] | None = None,
|
skill_dirs: list[str] | None = None,
|
||||||
|
pipeline_stages: "list[PipelineStage] | None" = None,
|
||||||
) -> AgentHost:
|
) -> AgentHost:
|
||||||
"""
|
"""
|
||||||
Create and configure an AgentHost with entry points.
|
Create and configure an AgentHost with entry points.
|
||||||
@@ -1980,6 +2026,7 @@ def create_agent_runtime(
|
|||||||
skills_catalog_prompt=skills_catalog_prompt,
|
skills_catalog_prompt=skills_catalog_prompt,
|
||||||
protocols_prompt=protocols_prompt,
|
protocols_prompt=protocols_prompt,
|
||||||
skill_dirs=skill_dirs,
|
skill_dirs=skill_dirs,
|
||||||
|
pipeline_stages=pipeline_stages,
|
||||||
)
|
)
|
||||||
|
|
||||||
for spec in entry_points:
|
for spec in entry_points:
|
||||||
|
|||||||
@@ -1267,82 +1267,7 @@ class AgentLoader:
|
|||||||
os.environ["HIVE_AGENT_NAME"] = agent_path.name
|
os.environ["HIVE_AGENT_NAME"] = agent_path.name
|
||||||
os.environ["HIVE_STORAGE_PATH"] = str(self._storage_path)
|
os.environ["HIVE_STORAGE_PATH"] = str(self._storage_path)
|
||||||
|
|
||||||
# Load MCP servers: prefer agent.json mcp_servers refs -> global registry
|
# MCP tools are loaded by McpRegistryStage in the pipeline during AgentHost.start()
|
||||||
# Fallback to mcp_servers.json if it exists (legacy)
|
|
||||||
mcp_config_path = agent_path / "mcp_servers.json"
|
|
||||||
agent_json_path = agent_path / "agent.json"
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"MCP loading: agent_json=%s, mcp_json=%s",
|
|
||||||
agent_json_path.exists(),
|
|
||||||
mcp_config_path.exists(),
|
|
||||||
)
|
|
||||||
|
|
||||||
mcp_loaded = False
|
|
||||||
# 1. From agent.json mcp_servers field (resolved via global registry)
|
|
||||||
if agent_json_path.exists():
|
|
||||||
try:
|
|
||||||
import json as _json
|
|
||||||
|
|
||||||
agent_data = _json.loads(agent_json_path.read_text(encoding="utf-8"))
|
|
||||||
server_refs = agent_data.get("mcp_servers", [])
|
|
||||||
if server_refs:
|
|
||||||
names = [ref["name"] for ref in server_refs if ref.get("name")]
|
|
||||||
if names:
|
|
||||||
from framework.loader.mcp_registry import MCPRegistry
|
|
||||||
|
|
||||||
registry = MCPRegistry()
|
|
||||||
configs = registry.resolve_for_agent(include=names)
|
|
||||||
if configs:
|
|
||||||
self._tool_registry.load_registry_servers(configs)
|
|
||||||
mcp_loaded = True
|
|
||||||
logger.info(
|
|
||||||
"Loaded %d MCP servers from registry: %s",
|
|
||||||
len(configs),
|
|
||||||
names,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to load MCP servers from agent.json refs", exc_info=True)
|
|
||||||
|
|
||||||
# 2. Legacy: mcp_servers.json file
|
|
||||||
if not mcp_loaded and mcp_config_path.exists():
|
|
||||||
self._load_mcp_servers_from_config(mcp_config_path)
|
|
||||||
mcp_loaded = True
|
|
||||||
|
|
||||||
# 3. Fallback: load all servers from global registry
|
|
||||||
if not mcp_loaded:
|
|
||||||
try:
|
|
||||||
from framework.loader.mcp_registry import MCPRegistry
|
|
||||||
|
|
||||||
registry = MCPRegistry()
|
|
||||||
configs = registry.resolve_for_agent(profile="all")
|
|
||||||
if configs:
|
|
||||||
self._tool_registry.load_registry_servers(configs)
|
|
||||||
logger.info(
|
|
||||||
"Loaded %d MCP servers from global registry (fallback)",
|
|
||||||
len(configs),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to load MCP servers from global registry", exc_info=True)
|
|
||||||
|
|
||||||
# Auto-discover registry-selected MCP servers from mcp_registry.json
|
|
||||||
self._load_registry_mcp_servers(agent_path)
|
|
||||||
|
|
||||||
# Summary: how many tools were loaded from MCP servers?
|
|
||||||
_all_tools = self._tool_registry.get_tools()
|
|
||||||
logger.info(
|
|
||||||
"MCP loading complete: %d tools from tool registry",
|
|
||||||
len(_all_tools),
|
|
||||||
)
|
|
||||||
if not _all_tools:
|
|
||||||
logger.warning(
|
|
||||||
"ZERO tools loaded from MCP servers. Workers will only have "
|
|
||||||
"synthetic tools (set_output, escalate). Check: "
|
|
||||||
"1) agent.json mcp_servers field, "
|
|
||||||
"2) ~/.hive/mcp_registry/installed.json, "
|
|
||||||
"3) MCP server process startup."
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _import_agent_module(agent_path: Path):
|
def _import_agent_module(agent_path: Path):
|
||||||
"""Import an agent package from its directory path.
|
"""Import an agent package from its directory path.
|
||||||
@@ -1695,228 +1620,148 @@ class AgentLoader:
|
|||||||
self._approval_callback = callback
|
self._approval_callback = callback
|
||||||
|
|
||||||
def _setup(self, event_bus=None) -> None:
|
def _setup(self, event_bus=None) -> None:
|
||||||
"""Set up runtime, LLM, and executor."""
|
"""Set up runtime via pipeline stages.
|
||||||
# Configure structured logging (auto-detects JSON vs human-readable)
|
|
||||||
|
Builds a pipeline with the default stages (LLM, credentials, MCP,
|
||||||
|
skills) and passes it to AgentHost. The stages initialize during
|
||||||
|
``AgentHost.start()`` and inject tools/LLM/credentials/skills.
|
||||||
|
"""
|
||||||
from framework.observability import configure_logging
|
from framework.observability import configure_logging
|
||||||
|
from framework.pipeline.stages.credential_resolver import CredentialResolverStage
|
||||||
|
from framework.pipeline.stages.llm_provider import LlmProviderStage
|
||||||
|
from framework.pipeline.stages.mcp_registry import McpRegistryStage
|
||||||
|
from framework.pipeline.stages.skill_registry import SkillRegistryStage
|
||||||
|
from framework.skills.config import SkillsConfig
|
||||||
|
|
||||||
configure_logging(level="INFO", format="auto")
|
configure_logging(level="INFO", format="auto")
|
||||||
|
|
||||||
# Set up session context for tools (agent_id)
|
# Set up session context for tools
|
||||||
agent_id = self.graph.id or "unknown"
|
agent_id = self.graph.id or "unknown"
|
||||||
|
self._tool_registry.set_session_context(agent_id=agent_id)
|
||||||
|
|
||||||
self._tool_registry.set_session_context(
|
# Read MCP server refs from agent.json
|
||||||
agent_id=agent_id,
|
mcp_refs = []
|
||||||
)
|
agent_json = self.agent_path / "agent.json"
|
||||||
|
if agent_json.exists():
|
||||||
|
try:
|
||||||
|
import json as _json
|
||||||
|
|
||||||
# Create LLM provider
|
data = _json.loads(agent_json.read_text(encoding="utf-8"))
|
||||||
# Uses LiteLLM which auto-detects the provider from model name
|
mcp_refs = data.get("mcp_servers", [])
|
||||||
# Skip if already injected (e.g. worker agents with a pre-built LLM)
|
except Exception:
|
||||||
if self._llm is not None:
|
pass
|
||||||
pass # LLM already configured externally
|
|
||||||
elif self.mock_mode:
|
|
||||||
# Use mock LLM for testing without real API calls
|
|
||||||
from framework.llm.mock import MockLLMProvider
|
|
||||||
|
|
||||||
self._llm = MockLLMProvider(model=self.model)
|
# Build default pipeline stages
|
||||||
else:
|
# Default infrastructure stages (always present)
|
||||||
from framework.llm.litellm import LiteLLMProvider
|
pipeline_stages = [
|
||||||
|
LlmProviderStage(
|
||||||
# Check if a subscription mode is configured
|
model=self.model,
|
||||||
config = get_hive_config()
|
mock_mode=self.mock_mode,
|
||||||
llm_config = config.get("llm", {})
|
llm=self._llm,
|
||||||
use_claude_code = llm_config.get("use_claude_code_subscription", False)
|
|
||||||
use_codex = llm_config.get("use_codex_subscription", False)
|
|
||||||
use_kimi_code = llm_config.get("use_kimi_code_subscription", False)
|
|
||||||
use_antigravity = llm_config.get("use_antigravity_subscription", False)
|
|
||||||
api_base = llm_config.get("api_base")
|
|
||||||
|
|
||||||
api_key = None
|
|
||||||
if use_claude_code:
|
|
||||||
# Get OAuth token from Claude Code subscription
|
|
||||||
api_key = get_claude_code_token()
|
|
||||||
if not api_key:
|
|
||||||
logger.warning(
|
|
||||||
"Claude Code subscription configured but no token found. "
|
|
||||||
"Run 'claude' to authenticate, then try again."
|
|
||||||
)
|
|
||||||
elif use_codex:
|
|
||||||
# Get OAuth token from Codex subscription
|
|
||||||
api_key = get_codex_token()
|
|
||||||
if not api_key:
|
|
||||||
logger.warning(
|
|
||||||
"Codex subscription configured but no token found. "
|
|
||||||
"Run 'codex' to authenticate, then try again."
|
|
||||||
)
|
|
||||||
elif use_kimi_code:
|
|
||||||
# Get API key from Kimi Code CLI config (~/.kimi/config.toml)
|
|
||||||
api_key = get_kimi_code_token()
|
|
||||||
if not api_key:
|
|
||||||
logger.warning(
|
|
||||||
"Kimi Code subscription configured but no key found. "
|
|
||||||
"Run 'kimi /login' to authenticate, then try again."
|
|
||||||
)
|
|
||||||
elif use_antigravity:
|
|
||||||
pass # AntigravityProvider handles credentials internally
|
|
||||||
|
|
||||||
if api_key and use_claude_code:
|
|
||||||
# Use litellm's built-in Anthropic OAuth support.
|
|
||||||
# The lowercase "authorization" key triggers OAuth detection which
|
|
||||||
# adds the required anthropic-beta and browser-access headers.
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=self.model,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
extra_headers={"authorization": f"Bearer {api_key}"},
|
|
||||||
)
|
|
||||||
elif api_key and use_codex:
|
|
||||||
# OpenAI Codex subscription routes through the ChatGPT backend
|
|
||||||
# (chatgpt.com/backend-api/codex/responses), NOT the standard
|
|
||||||
# OpenAI API. The consumer OAuth token lacks platform API scopes.
|
|
||||||
extra_headers: dict[str, str] = {
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"User-Agent": "CodexBar",
|
|
||||||
}
|
|
||||||
account_id = get_codex_account_id()
|
|
||||||
if account_id:
|
|
||||||
extra_headers["ChatGPT-Account-Id"] = account_id
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=self.model,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base="https://chatgpt.com/backend-api/codex",
|
|
||||||
extra_headers=extra_headers,
|
|
||||||
store=False,
|
|
||||||
allowed_openai_params=["store"],
|
|
||||||
)
|
|
||||||
elif api_key and use_kimi_code:
|
|
||||||
# Kimi Code subscription uses the Kimi coding API (OpenAI-compatible).
|
|
||||||
# The api_base is set automatically by LiteLLMProvider for kimi/ models.
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=self.model,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
)
|
|
||||||
elif use_antigravity:
|
|
||||||
# Direct OAuth to Google's internal Cloud Code Assist gateway.
|
|
||||||
# No local proxy required — AntigravityProvider handles token
|
|
||||||
# refresh and Gemini-format request/response conversion natively.
|
|
||||||
from framework.llm.antigravity import AntigravityProvider # noqa: PLC0415
|
|
||||||
|
|
||||||
provider = AntigravityProvider(model=self.model)
|
|
||||||
if not provider.has_credentials():
|
|
||||||
print(
|
|
||||||
"Warning: Antigravity credentials not found. "
|
|
||||||
"Run: uv run python core/antigravity_auth.py auth account add"
|
|
||||||
)
|
|
||||||
self._llm = provider
|
|
||||||
else:
|
|
||||||
# Local models (e.g. Ollama) don't need an API key
|
|
||||||
if self._is_local_model(self.model):
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=self.model,
|
|
||||||
api_base=api_base,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Fall back to environment variable
|
|
||||||
# First check api_key_env_var from config (set by quickstart)
|
|
||||||
api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var(
|
|
||||||
self.model
|
|
||||||
)
|
|
||||||
if api_key_env and os.environ.get(api_key_env):
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=self.model,
|
|
||||||
api_key=os.environ[api_key_env],
|
|
||||||
api_base=api_base,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Fall back to credential store
|
|
||||||
api_key = self._get_api_key_from_credential_store()
|
|
||||||
if api_key:
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=self.model, api_key=api_key, api_base=api_base
|
|
||||||
)
|
|
||||||
# Set env var so downstream code (e.g. cleanup LLM in
|
|
||||||
# node._extract_json) can also find it
|
|
||||||
if api_key_env:
|
|
||||||
os.environ[api_key_env] = api_key
|
|
||||||
elif api_key_env:
|
|
||||||
logger.warning(
|
|
||||||
"%s not set. LLM calls will fail. "
|
|
||||||
"Set it with: export %s=your-api-key",
|
|
||||||
api_key_env,
|
|
||||||
api_key_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fail fast if the agent needs an LLM but none was configured
|
|
||||||
if self._llm is None:
|
|
||||||
has_llm_nodes = any(
|
|
||||||
node.node_type in ("event_loop", "gcu") for node in self.graph.nodes
|
|
||||||
)
|
|
||||||
if has_llm_nodes:
|
|
||||||
from framework.credentials.models import CredentialError
|
|
||||||
|
|
||||||
if self._is_local_model(self.model):
|
|
||||||
raise CredentialError(
|
|
||||||
f"Failed to initialize LLM for local model '{self.model}'. "
|
|
||||||
f"Ensure your local LLM server is running "
|
|
||||||
f"(e.g. 'ollama serve' for Ollama)."
|
|
||||||
)
|
|
||||||
api_key_env = self._get_api_key_env_var(self.model)
|
|
||||||
hint = (
|
|
||||||
f"Set it with: export {api_key_env}=your-api-key"
|
|
||||||
if api_key_env
|
|
||||||
else "Configure an API key for your LLM provider."
|
|
||||||
)
|
|
||||||
raise CredentialError(f"LLM API key not found for model '{self.model}'. {hint}")
|
|
||||||
|
|
||||||
# Get tools for runtime
|
|
||||||
# (GCU and file tools are now registered via mcp_servers.json or MCP registry,
|
|
||||||
# not auto-registered here. Agents declare them in agent.json.)
|
|
||||||
tools = list(self._tool_registry.get_tools().values())
|
|
||||||
tool_executor = self._tool_registry.get_executor()
|
|
||||||
|
|
||||||
# Collect connected account info for system prompt injection
|
|
||||||
accounts_prompt = ""
|
|
||||||
accounts_data: list[dict] | None = None
|
|
||||||
tool_provider_map: dict[str, str] | None = None
|
|
||||||
try:
|
|
||||||
from aden_tools.credentials.store_adapter import CredentialStoreAdapter
|
|
||||||
|
|
||||||
if self._credential_store is not None:
|
|
||||||
adapter = CredentialStoreAdapter(store=self._credential_store)
|
|
||||||
else:
|
|
||||||
adapter = CredentialStoreAdapter.default()
|
|
||||||
accounts_data = adapter.get_all_account_info()
|
|
||||||
tool_provider_map = adapter.get_tool_provider_map()
|
|
||||||
if accounts_data:
|
|
||||||
from framework.orchestrator.prompting import build_accounts_prompt
|
|
||||||
|
|
||||||
accounts_prompt = build_accounts_prompt(accounts_data, tool_provider_map)
|
|
||||||
except Exception:
|
|
||||||
pass # Best-effort — agent works without account info
|
|
||||||
|
|
||||||
# Skill configuration — the runtime handles discovery, loading, trust-gating and
|
|
||||||
# prompt rasterization. The runner just builds the config.
|
|
||||||
from framework.skills.config import SkillsConfig
|
|
||||||
from framework.skills.manager import SkillsManagerConfig
|
|
||||||
|
|
||||||
skills_manager_config = SkillsManagerConfig(
|
|
||||||
skills_config=SkillsConfig.from_agent_vars(
|
|
||||||
default_skills=getattr(self, "_agent_default_skills", None),
|
|
||||||
skills=getattr(self, "_agent_skills", None),
|
|
||||||
),
|
),
|
||||||
project_root=self.agent_path,
|
CredentialResolverStage(
|
||||||
interactive=self._interactive,
|
credential_store=self._credential_store,
|
||||||
|
),
|
||||||
|
McpRegistryStage(
|
||||||
|
server_refs=mcp_refs,
|
||||||
|
agent_path=self.agent_path,
|
||||||
|
tool_registry=self._tool_registry,
|
||||||
|
),
|
||||||
|
SkillRegistryStage(
|
||||||
|
project_root=self.agent_path,
|
||||||
|
interactive=self._interactive,
|
||||||
|
skills_config=SkillsConfig.from_agent_vars(
|
||||||
|
default_skills=getattr(self, "_agent_default_skills", None),
|
||||||
|
skills=getattr(self, "_agent_skills", None),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Merge user-configured stages from ~/.hive/configuration.json
|
||||||
|
from framework.config import get_hive_config
|
||||||
|
from framework.pipeline.registry import build_pipeline_from_config
|
||||||
|
|
||||||
|
hive_config = get_hive_config()
|
||||||
|
user_stages_config = hive_config.get("pipeline", {}).get("stages", [])
|
||||||
|
if user_stages_config:
|
||||||
|
user_pipeline = build_pipeline_from_config(user_stages_config)
|
||||||
|
pipeline_stages.extend(user_pipeline.stages)
|
||||||
|
|
||||||
|
# Merge agent-level overrides from agent.json pipeline field
|
||||||
|
if agent_json.exists():
|
||||||
|
try:
|
||||||
|
agent_pipeline = (
|
||||||
|
_json.loads(agent_json.read_text(encoding="utf-8"))
|
||||||
|
.get("pipeline", {})
|
||||||
|
.get("stages", [])
|
||||||
|
)
|
||||||
|
if agent_pipeline:
|
||||||
|
agent_stages = build_pipeline_from_config(agent_pipeline)
|
||||||
|
pipeline_stages.extend(agent_stages.stages)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._setup_agent_runtime_with_pipeline(
|
||||||
|
pipeline_stages=pipeline_stages,
|
||||||
|
event_bus=event_bus,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._setup_agent_runtime(
|
def _setup_agent_runtime_with_pipeline(
|
||||||
tools,
|
self,
|
||||||
tool_executor,
|
pipeline_stages: list,
|
||||||
accounts_prompt=accounts_prompt,
|
event_bus=None,
|
||||||
accounts_data=accounts_data,
|
) -> None:
|
||||||
tool_provider_map=tool_provider_map,
|
"""Create AgentHost with pipeline stages."""
|
||||||
event_bus=event_bus,
|
from framework.host.execution_manager import EntryPointSpec
|
||||||
skills_manager_config=skills_manager_config,
|
from framework.orchestrator.checkpoint_config import CheckpointConfig
|
||||||
|
|
||||||
|
entry_points = []
|
||||||
|
if self.graph.entry_node:
|
||||||
|
entry_points.append(
|
||||||
|
EntryPointSpec(
|
||||||
|
id="default",
|
||||||
|
name="Default",
|
||||||
|
entry_node=self.graph.entry_node,
|
||||||
|
trigger_type="manual",
|
||||||
|
isolation_level="shared",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
log_store = RuntimeLogStore(
|
||||||
|
base_path=self._storage_path / "runtime_logs",
|
||||||
)
|
)
|
||||||
|
checkpoint_config = CheckpointConfig(
|
||||||
|
enabled=True,
|
||||||
|
checkpoint_on_node_start=False,
|
||||||
|
checkpoint_on_node_complete=True,
|
||||||
|
checkpoint_max_age_days=7,
|
||||||
|
async_checkpoint=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime_config = None
|
||||||
|
if self.runtime_config is not None:
|
||||||
|
from framework.host.agent_host import AgentRuntimeConfig
|
||||||
|
|
||||||
|
if isinstance(self.runtime_config, AgentRuntimeConfig):
|
||||||
|
runtime_config = self.runtime_config
|
||||||
|
|
||||||
|
|
||||||
|
self._agent_runtime = create_agent_runtime(
|
||||||
|
graph=self.graph,
|
||||||
|
goal=self.goal,
|
||||||
|
storage_path=self._storage_path,
|
||||||
|
entry_points=entry_points,
|
||||||
|
llm=None, # Injected by LlmProviderStage
|
||||||
|
tools=[], # Injected by McpRegistryStage
|
||||||
|
tool_executor=None, # Injected by McpRegistryStage
|
||||||
|
runtime_log_store=log_store,
|
||||||
|
checkpoint_config=checkpoint_config,
|
||||||
|
config=runtime_config,
|
||||||
|
graph_id=self.graph.id or self.agent_path.name,
|
||||||
|
event_bus=event_bus,
|
||||||
|
pipeline_stages=pipeline_stages,
|
||||||
|
)
|
||||||
|
self._agent_runtime.intro_message = self.intro_message
|
||||||
|
|
||||||
def _get_api_key_env_var(self, model: str) -> str | None:
|
def _get_api_key_env_var(self, model: str) -> str | None:
|
||||||
"""Get the environment variable name for the API key based on model name."""
|
"""Get the environment variable name for the API key based on model name."""
|
||||||
@@ -2009,83 +1854,6 @@ class AgentLoader:
|
|||||||
)
|
)
|
||||||
return model.lower().startswith(LOCAL_PREFIXES)
|
return model.lower().startswith(LOCAL_PREFIXES)
|
||||||
|
|
||||||
def _setup_agent_runtime(
|
|
||||||
self,
|
|
||||||
tools: list,
|
|
||||||
tool_executor: Callable | None,
|
|
||||||
accounts_prompt: str = "",
|
|
||||||
accounts_data: list[dict] | None = None,
|
|
||||||
tool_provider_map: dict[str, str] | None = None,
|
|
||||||
event_bus=None,
|
|
||||||
skills_catalog_prompt: str = "",
|
|
||||||
protocols_prompt: str = "",
|
|
||||||
skill_dirs: list[str] | None = None,
|
|
||||||
skills_manager_config=None,
|
|
||||||
) -> None:
|
|
||||||
"""Set up multi-entry-point execution using AgentRuntime."""
|
|
||||||
entry_points = []
|
|
||||||
|
|
||||||
# Always create a primary entry point for the graph's entry node.
|
|
||||||
# For multi-entry-point agents this ensures the primary path (e.g.
|
|
||||||
# user-facing rule setup) is reachable alongside async entry points.
|
|
||||||
if self.graph.entry_node:
|
|
||||||
entry_points.insert(
|
|
||||||
0,
|
|
||||||
EntryPointSpec(
|
|
||||||
id="default",
|
|
||||||
name="Default",
|
|
||||||
entry_node=self.graph.entry_node,
|
|
||||||
trigger_type="manual",
|
|
||||||
isolation_level="shared",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create AgentRuntime with all entry points
|
|
||||||
log_store = RuntimeLogStore(base_path=self._storage_path / "runtime_logs")
|
|
||||||
|
|
||||||
# Enable checkpointing by default for resumable sessions
|
|
||||||
from framework.orchestrator.checkpoint_config import CheckpointConfig
|
|
||||||
|
|
||||||
checkpoint_config = CheckpointConfig(
|
|
||||||
enabled=True,
|
|
||||||
checkpoint_on_node_start=False, # Only checkpoint after nodes complete
|
|
||||||
checkpoint_on_node_complete=True,
|
|
||||||
checkpoint_max_age_days=7,
|
|
||||||
async_checkpoint=True, # Non-blocking
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle runtime_config - only pass through if it's actually an AgentRuntimeConfig.
|
|
||||||
# Agents may export a RuntimeConfig (LLM settings) or queen-generated custom classes
|
|
||||||
# that would crash AgentRuntime if passed through.
|
|
||||||
runtime_config = None
|
|
||||||
if self.runtime_config is not None:
|
|
||||||
from framework.host.agent_host import AgentRuntimeConfig
|
|
||||||
|
|
||||||
if isinstance(self.runtime_config, AgentRuntimeConfig):
|
|
||||||
runtime_config = self.runtime_config
|
|
||||||
|
|
||||||
self._agent_runtime = create_agent_runtime(
|
|
||||||
graph=self.graph,
|
|
||||||
goal=self.goal,
|
|
||||||
storage_path=self._storage_path,
|
|
||||||
entry_points=entry_points,
|
|
||||||
llm=self._llm,
|
|
||||||
tools=tools,
|
|
||||||
tool_executor=tool_executor,
|
|
||||||
runtime_log_store=log_store,
|
|
||||||
checkpoint_config=checkpoint_config,
|
|
||||||
config=runtime_config,
|
|
||||||
graph_id=self.graph.id or self.agent_path.name,
|
|
||||||
accounts_prompt=accounts_prompt,
|
|
||||||
accounts_data=accounts_data,
|
|
||||||
tool_provider_map=tool_provider_map,
|
|
||||||
event_bus=event_bus,
|
|
||||||
skills_manager_config=skills_manager_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pass intro_message through for TUI display
|
|
||||||
self._agent_runtime.intro_message = self.intro_message
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Execution modes
|
# Execution modes
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -262,15 +262,21 @@ class ToolRegistry:
|
|||||||
is_error=False,
|
is_error=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
registry_ref = self
|
||||||
|
|
||||||
def executor(tool_use: ToolUse) -> ToolResult:
|
def executor(tool_use: ToolUse) -> ToolResult:
|
||||||
if tool_use.name not in self._tools:
|
# Check if credential files changed (lightweight dir listing).
|
||||||
|
# If new OAuth tokens appeared, restarts MCP servers to pick them up.
|
||||||
|
registry_ref.resync_mcp_servers_if_needed()
|
||||||
|
|
||||||
|
if tool_use.name not in registry_ref._tools:
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
tool_use_id=tool_use.id,
|
tool_use_id=tool_use.id,
|
||||||
content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}),
|
content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}),
|
||||||
is_error=True,
|
is_error=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
registered = self._tools[tool_use.name]
|
registered = registry_ref._tools[tool_use.name]
|
||||||
try:
|
try:
|
||||||
result = registered.executor(tool_use.input)
|
result = registered.executor(tool_use.input)
|
||||||
|
|
||||||
@@ -922,6 +928,11 @@ class ToolRegistry:
|
|||||||
clients and re-loads them so the new subprocess picks up the fresh
|
clients and re-loads them so the new subprocess picks up the fresh
|
||||||
credentials.
|
credentials.
|
||||||
|
|
||||||
|
Note: Individual credential TTL/refresh is handled by the MCP server
|
||||||
|
process internally -- it resolves tokens from the credential store
|
||||||
|
on every tool call, not at startup. This method only handles the case
|
||||||
|
where entirely new credential files appear.
|
||||||
|
|
||||||
Returns True if a resync was performed, False otherwise.
|
Returns True if a resync was performed, False otherwise.
|
||||||
"""
|
"""
|
||||||
if not self._mcp_clients or self._mcp_config_path is None:
|
if not self._mcp_clients or self._mcp_config_path is None:
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
"""Credential resolver pipeline stage.
|
"""Credential resolver pipeline stage.
|
||||||
|
|
||||||
Resolves connected accounts from the credential store and builds
|
Resolves connected accounts at startup. Individual credential TTL/refresh
|
||||||
the ``accounts_prompt`` and ``tool_provider_map`` for system prompt
|
is handled by MCP server processes internally -- they resolve tokens from
|
||||||
injection. Replaces the credential resolution block in
|
the credential store on every tool call.
|
||||||
``AgentLoader._setup()`` (lines 1861-1879).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -19,37 +18,41 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
@register("credential_resolver")
|
@register("credential_resolver")
|
||||||
class CredentialResolverStage(PipelineStage):
|
class CredentialResolverStage(PipelineStage):
|
||||||
"""Resolve connected accounts and inject into pipeline context."""
|
"""Resolve connected accounts for system prompt injection."""
|
||||||
|
|
||||||
order = 40 # before MCP (tools need account info for routing)
|
order = 40
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, credential_store: Any = None, **kwargs: Any) -> None:
|
||||||
self._accounts_prompt = ""
|
self._credential_store = credential_store
|
||||||
self._accounts_data: list[dict] | None = None
|
self.accounts_prompt = ""
|
||||||
self._tool_provider_map: dict[str, str] | None = None
|
self.accounts_data: list[dict] | None = None
|
||||||
|
self.tool_provider_map: dict[str, str] | None = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Resolve credentials from the store."""
|
|
||||||
try:
|
try:
|
||||||
from aden_tools.credentials.store_adapter import (
|
from aden_tools.credentials.store_adapter import (
|
||||||
CredentialStoreAdapter,
|
CredentialStoreAdapter,
|
||||||
)
|
)
|
||||||
from framework.orchestrator.prompting import build_accounts_prompt
|
from framework.orchestrator.prompting import build_accounts_prompt
|
||||||
|
|
||||||
adapter = CredentialStoreAdapter.default()
|
if self._credential_store is not None:
|
||||||
self._accounts_data = adapter.get_all_account_info()
|
adapter = CredentialStoreAdapter(store=self._credential_store)
|
||||||
self._tool_provider_map = adapter.get_tool_provider_map()
|
else:
|
||||||
if self._accounts_data:
|
adapter = CredentialStoreAdapter.default()
|
||||||
self._accounts_prompt = build_accounts_prompt(
|
self.accounts_data = adapter.get_all_account_info()
|
||||||
self._accounts_data,
|
self.tool_provider_map = adapter.get_tool_provider_map()
|
||||||
self._tool_provider_map,
|
if self.accounts_data:
|
||||||
|
self.accounts_prompt = build_accounts_prompt(
|
||||||
|
self.accounts_data, self.tool_provider_map,
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"[pipeline] CredentialResolverStage: %d accounts",
|
||||||
|
len(self.accounts_data or []),
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # best-effort -- agent works without account info
|
logger.debug(
|
||||||
|
"Credential resolution failed (non-fatal)", exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||||
"""Inject credential info into pipeline context."""
|
|
||||||
ctx.metadata["accounts_prompt"] = self._accounts_prompt
|
|
||||||
ctx.metadata["accounts_data"] = self._accounts_data
|
|
||||||
ctx.metadata["tool_provider_map"] = self._tool_provider_map
|
|
||||||
return PipelineResult(action="continue")
|
return PipelineResult(action="continue")
|
||||||
|
|||||||
@@ -1,17 +1,7 @@
|
|||||||
"""LLM provider pipeline stage.
|
"""LLM provider pipeline stage.
|
||||||
|
|
||||||
Resolves the LLM provider (model, API key, OAuth token) from the
|
Resolves the LLM provider from global config. This is the ONLY place
|
||||||
global config and injects it into the pipeline context. Replaces
|
the LLM gets created for worker agents.
|
||||||
the 150-line provider resolution block in ``AgentLoader._setup()``.
|
|
||||||
|
|
||||||
Supports all auth methods:
|
|
||||||
- Claude Code subscription (OAuth token from ~/.claude/.credentials.json)
|
|
||||||
- Codex subscription (Keychain / ~/.codex/auth.json)
|
|
||||||
- Kimi Code subscription (~/.kimi/config.toml)
|
|
||||||
- Antigravity (Google Cloud Code Assist OAuth)
|
|
||||||
- Environment variable (ANTHROPIC_API_KEY, OPENAI_API_KEY, etc.)
|
|
||||||
- Key pool (multiple keys with rotation)
|
|
||||||
- Local models (Ollama, no key needed)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -27,84 +17,79 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
@register("llm_provider")
|
@register("llm_provider")
|
||||||
class LlmProviderStage(PipelineStage):
|
class LlmProviderStage(PipelineStage):
|
||||||
"""Resolve LLM provider and inject into pipeline context."""
|
"""Resolve LLM provider and make it available."""
|
||||||
|
|
||||||
order = 10 # earliest -- everything else depends on having an LLM
|
order = 10
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
mock_mode: bool = False,
|
mock_mode: bool = False,
|
||||||
|
llm: Any = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._model = model
|
self._model = model
|
||||||
self._mock_mode = mock_mode
|
self._mock_mode = mock_mode
|
||||||
self._llm: Any = None
|
self.llm = llm # Pre-injected LLM (e.g. from session)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Resolve and create the LLM provider."""
|
if self.llm is not None:
|
||||||
from framework.config import get_api_key, get_api_keys, get_hive_config, get_preferred_model
|
return # Already injected
|
||||||
|
|
||||||
|
from framework.config import (
|
||||||
|
get_api_key,
|
||||||
|
get_api_keys,
|
||||||
|
get_hive_config,
|
||||||
|
get_preferred_model,
|
||||||
|
)
|
||||||
|
|
||||||
model = self._model or get_preferred_model()
|
model = self._model or get_preferred_model()
|
||||||
|
|
||||||
if self._mock_mode:
|
if self._mock_mode:
|
||||||
from framework.llm.mock import MockLLMProvider
|
from framework.llm.mock import MockLLMProvider
|
||||||
|
|
||||||
self._llm = MockLLMProvider(model=model)
|
self.llm = MockLLMProvider(model=model)
|
||||||
return
|
return
|
||||||
|
|
||||||
config = get_hive_config()
|
config = get_hive_config()
|
||||||
llm_config = config.get("llm", {})
|
llm_config = config.get("llm", {})
|
||||||
api_base = llm_config.get("api_base")
|
api_base = llm_config.get("api_base")
|
||||||
api_key = get_api_key()
|
|
||||||
api_keys = get_api_keys()
|
|
||||||
|
|
||||||
from framework.llm.litellm import LiteLLMProvider
|
# Check for Antigravity (special provider)
|
||||||
|
|
||||||
# Check for Antigravity (special provider, not LiteLLM)
|
|
||||||
if llm_config.get("use_antigravity_subscription"):
|
if llm_config.get("use_antigravity_subscription"):
|
||||||
try:
|
try:
|
||||||
from framework.llm.antigravity import AntigravityProvider
|
from framework.llm.antigravity import AntigravityProvider
|
||||||
|
|
||||||
provider = AntigravityProvider(model=model)
|
provider = AntigravityProvider(model=model)
|
||||||
if provider.has_credentials():
|
if provider.has_credentials():
|
||||||
self._llm = provider
|
self.llm = provider
|
||||||
|
logger.info("[pipeline] LlmProviderStage: Antigravity")
|
||||||
return
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Key pool or single key
|
from framework.llm.litellm import LiteLLMProvider
|
||||||
|
|
||||||
|
api_key = get_api_key()
|
||||||
|
api_keys = get_api_keys()
|
||||||
|
|
||||||
if api_keys and len(api_keys) > 1:
|
if api_keys and len(api_keys) > 1:
|
||||||
self._llm = LiteLLMProvider(
|
self.llm = LiteLLMProvider(
|
||||||
model=model,
|
model=model, api_keys=api_keys, api_base=api_base,
|
||||||
api_keys=api_keys,
|
|
||||||
api_base=api_base,
|
|
||||||
)
|
)
|
||||||
elif api_key:
|
elif api_key:
|
||||||
# Detect OAuth subscriptions for special headers
|
extra = {}
|
||||||
is_claude_oauth = api_key.startswith("sk-ant-oat")
|
if api_key.startswith("sk-ant-oat"):
|
||||||
if is_claude_oauth:
|
extra["extra_headers"] = {
|
||||||
self._llm = LiteLLMProvider(
|
"authorization": f"Bearer {api_key}"
|
||||||
model=model,
|
}
|
||||||
api_key=api_key,
|
self.llm = LiteLLMProvider(
|
||||||
api_base=api_base,
|
model=model, api_key=api_key, api_base=api_base, **extra,
|
||||||
extra_headers={"authorization": f"Bearer {api_key}"},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=model,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# No key -- local models or env var fallback
|
|
||||||
self._llm = LiteLLMProvider(
|
|
||||||
model=model,
|
|
||||||
api_base=api_base,
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.llm = LiteLLMProvider(model=model, api_base=api_base)
|
||||||
|
|
||||||
|
logger.info("[pipeline] LlmProviderStage: %s", model)
|
||||||
|
|
||||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||||
"""Inject LLM provider into pipeline context."""
|
|
||||||
if self._llm:
|
|
||||||
ctx.metadata["llm"] = self._llm
|
|
||||||
return PipelineResult(action="continue")
|
return PipelineResult(action="continue")
|
||||||
|
|||||||
@@ -1,21 +1,13 @@
|
|||||||
"""MCP registry pipeline stage.
|
"""MCP registry pipeline stage.
|
||||||
|
|
||||||
Resolves MCP server references from the agent config against the global
|
Resolves MCP server references from the agent config against the global
|
||||||
registry (``~/.hive/mcp_registry/installed.json``) and registers tools.
|
registry and registers tools. This is the ONLY place MCP tools get loaded.
|
||||||
Replaces the per-agent ``mcp_servers.json`` pattern with declarative
|
|
||||||
name-based references.
|
|
||||||
|
|
||||||
Agent config declares servers by name::
|
|
||||||
|
|
||||||
{"mcp_servers": [{"name": "hive-tools"}, {"name": "gcu-tools"}]}
|
|
||||||
|
|
||||||
The stage resolves each name from the global registry at ``initialize()``
|
|
||||||
time and injects the resolved ``ToolRegistry`` into the pipeline context.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -27,12 +19,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
@register("mcp_registry")
|
@register("mcp_registry")
|
||||||
class McpRegistryStage(PipelineStage):
|
class McpRegistryStage(PipelineStage):
|
||||||
"""Resolve MCP tools from the global registry.
|
"""Resolve MCP tools from the global registry."""
|
||||||
|
|
||||||
On ``initialize()``, connects to MCP servers declared in the agent
|
|
||||||
config. On ``process()``, injects ``tools`` and ``tool_executor``
|
|
||||||
into the pipeline context metadata for downstream consumption.
|
|
||||||
"""
|
|
||||||
|
|
||||||
order = 50
|
order = 50
|
||||||
|
|
||||||
@@ -40,55 +27,66 @@ class McpRegistryStage(PipelineStage):
|
|||||||
self,
|
self,
|
||||||
server_refs: list[dict[str, Any]] | None = None,
|
server_refs: list[dict[str, Any]] | None = None,
|
||||||
agent_path: str | Path | None = None,
|
agent_path: str | Path | None = None,
|
||||||
|
tool_registry: Any = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
server_refs: List of ``{"name": "server-name"}`` dicts from
|
|
||||||
the agent config's ``mcp_servers`` field.
|
|
||||||
agent_path: Path to the agent directory. If a ``mcp_servers.json``
|
|
||||||
file exists there, it's loaded as a fallback.
|
|
||||||
"""
|
|
||||||
self._server_refs = server_refs or []
|
self._server_refs = server_refs or []
|
||||||
self._agent_path = Path(agent_path) if agent_path else None
|
self._agent_path = Path(agent_path) if agent_path else None
|
||||||
self._tool_registry: Any = None
|
self._tool_registry = tool_registry
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Connect to MCP servers and discover tools."""
|
"""Connect to MCP servers and discover tools."""
|
||||||
|
if self._tool_registry is None:
|
||||||
|
from framework.loader.tool_registry import ToolRegistry
|
||||||
|
|
||||||
|
self._tool_registry = ToolRegistry()
|
||||||
|
|
||||||
from framework.loader.mcp_registry import MCPRegistry
|
from framework.loader.mcp_registry import MCPRegistry
|
||||||
from framework.loader.tool_registry import ToolRegistry
|
|
||||||
|
|
||||||
self._tool_registry = ToolRegistry()
|
|
||||||
registry = MCPRegistry()
|
registry = MCPRegistry()
|
||||||
|
mcp_loaded = False
|
||||||
|
|
||||||
# 1. Resolve named server refs from global registry
|
# 1. From agent.json mcp_servers refs
|
||||||
if self._server_refs:
|
if self._server_refs:
|
||||||
names = [ref["name"] for ref in self._server_refs if ref.get("name")]
|
names = [ref["name"] for ref in self._server_refs if ref.get("name")]
|
||||||
if names:
|
if names:
|
||||||
configs = registry.resolve_for_agent(include=names)
|
configs = registry.resolve_for_agent(include=names)
|
||||||
if configs:
|
if configs:
|
||||||
self._tool_registry.load_registry_servers(configs)
|
self._tool_registry.load_registry_servers(
|
||||||
|
[asdict(c) for c in configs]
|
||||||
|
)
|
||||||
|
mcp_loaded = True
|
||||||
logger.info(
|
logger.info(
|
||||||
"McpRegistryStage: resolved %d servers from registry",
|
"[pipeline] McpRegistryStage: loaded %d servers: %s",
|
||||||
len(configs),
|
len(configs),
|
||||||
|
names,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Fallback: load mcp_servers.json if it exists (backward compat)
|
# 2. Legacy: mcp_servers.json
|
||||||
if self._agent_path:
|
if not mcp_loaded and self._agent_path:
|
||||||
mcp_json = self._agent_path / "mcp_servers.json"
|
mcp_json = self._agent_path / "mcp_servers.json"
|
||||||
if mcp_json.exists():
|
if mcp_json.exists():
|
||||||
self._tool_registry.load_mcp_config(mcp_json)
|
self._tool_registry.load_mcp_config(mcp_json)
|
||||||
|
mcp_loaded = True
|
||||||
|
|
||||||
|
# 3. Fallback: all servers from global registry
|
||||||
|
if not mcp_loaded:
|
||||||
|
configs = registry.resolve_for_agent(profile="all")
|
||||||
|
if configs:
|
||||||
|
self._tool_registry.load_registry_servers(
|
||||||
|
[asdict(c) for c in configs]
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"McpRegistryStage: loaded mcp_servers.json from %s",
|
"[pipeline] McpRegistryStage: loaded %d servers (fallback)",
|
||||||
self._agent_path.name,
|
len(configs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
total = len(self._tool_registry.get_tools())
|
||||||
|
logger.info("[pipeline] McpRegistryStage: %d tools available", total)
|
||||||
|
|
||||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||||
"""Inject resolved tools into pipeline context."""
|
|
||||||
if self._tool_registry:
|
|
||||||
ctx.metadata["tool_registry"] = self._tool_registry
|
|
||||||
ctx.metadata["tools"] = list(
|
|
||||||
self._tool_registry.get_tools().values()
|
|
||||||
)
|
|
||||||
ctx.metadata["tool_executor"] = self._tool_registry.get_executor()
|
|
||||||
return PipelineResult(action="continue")
|
return PipelineResult(action="continue")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_registry(self):
|
||||||
|
return self._tool_registry
|
||||||
|
|||||||
@@ -1,12 +1,6 @@
|
|||||||
"""Skill registry pipeline stage.
|
"""Skill registry pipeline stage.
|
||||||
|
|
||||||
Discovers and loads skills, injects skill prompts into the pipeline
|
Discovers and loads skills. This is the ONLY place skills get loaded.
|
||||||
context. Replaces the standalone ``SkillsManager`` initialization
|
|
||||||
in ``AgentHost.__init__()``.
|
|
||||||
|
|
||||||
Supports hot-reload: when ``SKILL.md`` files change on disk, the
|
|
||||||
cached prompts are rebuilt and the next pipeline execution picks
|
|
||||||
up the new values.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -23,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
@register("skill_registry")
|
@register("skill_registry")
|
||||||
class SkillRegistryStage(PipelineStage):
|
class SkillRegistryStage(PipelineStage):
|
||||||
"""Discover skills and inject prompts into pipeline context."""
|
"""Discover skills and provide prompts."""
|
||||||
|
|
||||||
order = 60
|
order = 60
|
||||||
|
|
||||||
@@ -31,34 +25,31 @@ class SkillRegistryStage(PipelineStage):
|
|||||||
self,
|
self,
|
||||||
project_root: str | Path | None = None,
|
project_root: str | Path | None = None,
|
||||||
interactive: bool = True,
|
interactive: bool = True,
|
||||||
|
skills_config: Any = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._project_root = Path(project_root) if project_root else None
|
self._project_root = Path(project_root) if project_root else None
|
||||||
self._interactive = interactive
|
self._interactive = interactive
|
||||||
self._skills_manager: Any = None
|
self._skills_config = skills_config
|
||||||
|
self.skills_manager: Any = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Discover skills and start hot-reload watcher."""
|
from framework.skills.config import SkillsConfig
|
||||||
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
from framework.skills.manager import SkillsManager, SkillsManagerConfig
|
||||||
|
|
||||||
config = SkillsManagerConfig(
|
config = SkillsManagerConfig(
|
||||||
|
skills_config=self._skills_config or SkillsConfig(),
|
||||||
project_root=self._project_root,
|
project_root=self._project_root,
|
||||||
interactive=self._interactive,
|
interactive=self._interactive,
|
||||||
)
|
)
|
||||||
self._skills_manager = SkillsManager(config)
|
self.skills_manager = SkillsManager(config)
|
||||||
self._skills_manager.load()
|
self.skills_manager.load()
|
||||||
await self._skills_manager.start_watching()
|
await self.skills_manager.start_watching()
|
||||||
|
logger.info(
|
||||||
|
"[pipeline] SkillRegistryStage: catalog=%d chars, protocols=%d chars",
|
||||||
|
len(self.skills_manager.skills_catalog_prompt),
|
||||||
|
len(self.skills_manager.protocols_prompt),
|
||||||
|
)
|
||||||
|
|
||||||
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
async def process(self, ctx: PipelineContext) -> PipelineResult:
|
||||||
"""Inject skill prompts into pipeline context."""
|
|
||||||
if self._skills_manager:
|
|
||||||
ctx.metadata["skills_catalog_prompt"] = (
|
|
||||||
self._skills_manager.skills_catalog_prompt
|
|
||||||
)
|
|
||||||
ctx.metadata["protocols_prompt"] = (
|
|
||||||
self._skills_manager.protocols_prompt
|
|
||||||
)
|
|
||||||
ctx.metadata["skill_dirs"] = (
|
|
||||||
self._skills_manager.allowlisted_dirs
|
|
||||||
)
|
|
||||||
return PipelineResult(action="continue")
|
return PipelineResult(action="continue")
|
||||||
|
|||||||
Reference in New Issue
Block a user