From db572b9be603571e075772cc73ba1c04d79df483 Mon Sep 17 00:00:00 2001 From: Timothy Date: Tue, 7 Apr 2026 16:15:40 -0700 Subject: [PATCH] fix: mcp registry pipeline stage --- core/framework/agents/queen/nodes/__init__.py | 3 +- core/framework/host/agent_host.py | 55 +- core/framework/loader/agent_loader.py | 494 +++++------------- core/framework/loader/tool_registry.py | 15 +- .../pipeline/stages/credential_resolver.py | 49 +- .../framework/pipeline/stages/llm_provider.py | 89 ++-- .../framework/pipeline/stages/mcp_registry.py | 78 ++- .../pipeline/stages/skill_registry.py | 39 +- 8 files changed, 313 insertions(+), 509 deletions(-) diff --git a/core/framework/agents/queen/nodes/__init__.py b/core/framework/agents/queen/nodes/__init__.py index b1c182df..278c03d5 100644 --- a/core/framework/agents/queen/nodes/__init__.py +++ b/core/framework/agents/queen/nodes/__init__.py @@ -470,7 +470,8 @@ The agent.json must include ALL of these in one write: `input_keys`, `output_keys`, `success_criteria` - `edges` — connecting all nodes with proper conditions - `entry_node`, `terminal_nodes` -- `mcp_servers` — reference by name: `[{"name": "hive-tools"}, {"name": "gcu-tools"}]` +- `mcp_servers` — REQUIRED. Always include all three: \ +`[{"name": "hive-tools"}, {"name": "gcu-tools"}, {"name": "files-tools"}]` - `loop_config` — `max_iterations`, `max_context_tokens` **Write the COMPLETE config in one `write_file` call. No TODOs, no placeholders.** \ diff --git a/core/framework/host/agent_host.py b/core/framework/host/agent_host.py index d858e560..e3d33831 100644 --- a/core/framework/host/agent_host.py +++ b/core/framework/host/agent_host.py @@ -351,7 +351,12 @@ class AgentHost: # Start storage await self._storage.start() - # Create streams for each entry point + # Initialize pipeline stages FIRST -- they inject LLM, tools, + # credentials, and skills into the host before streams are created. + await self._pipeline.initialize_all() + self._apply_pipeline_results() + + # Create streams for each entry point (uses pipeline results) for ep_id, spec in self._entry_points.items(): stream = ExecutionManager( stream_id=ep_id, @@ -804,9 +809,6 @@ class AgentHost: # Start skill hot-reload watcher (no-op if watchfiles not installed) await self._skills_manager.start_watching() - # Initialize pipeline stages (one-time setup) - await self._pipeline.initialize_all() - self._running = True self._timers_paused = False n_stages = len(self._pipeline.stages) @@ -899,6 +901,49 @@ class AgentHost: # Primary graph (also stored in self._streams) return self._streams.get(entry_point_id) + def _apply_pipeline_results(self) -> None: + """Extract tools/LLM/credentials/skills from pipeline stages. + + Called after ``pipeline.initialize_all()`` so stages have finished + their async setup (MCP connected, skills discovered, etc.). + The host reads stage properties and updates its own state. + """ + for stage in self._pipeline.stages: + stage_name = stage.__class__.__name__ + + # McpRegistryStage -> tools + if hasattr(stage, "tool_registry") and stage.tool_registry is not None: + tools = list(stage.tool_registry.get_tools().values()) + executor = stage.tool_registry.get_executor() + if tools: + self._tools = tools + self._tool_executor = executor + logger.info( + "Pipeline injected %d tools from %s", + len(tools), stage_name, + ) + + # LlmProviderStage -> LLM + if hasattr(stage, "llm") and stage.llm is not None: + if self._llm is None: + self._llm = stage.llm + logger.info( + "Pipeline injected LLM from %s", stage_name, + ) + + # CredentialResolverStage -> accounts + if hasattr(stage, "accounts_prompt") and stage.accounts_prompt: + self._accounts_prompt = stage.accounts_prompt + self._accounts_data = getattr(stage, "accounts_data", None) + self._tool_provider_map = getattr( + stage, "tool_provider_map", None, + ) + + # SkillRegistryStage -> skills manager + if hasattr(stage, "skills_manager") and stage.skills_manager is not None: + self._skills_manager = stage.skills_manager + + @staticmethod def _load_pipeline_from_config(): """Build pipeline from ``~/.hive/configuration.json`` ``pipeline`` key. @@ -1917,6 +1962,7 @@ def create_agent_runtime( skills_catalog_prompt: str = "", protocols_prompt: str = "", skill_dirs: list[str] | None = None, + pipeline_stages: "list[PipelineStage] | None" = None, ) -> AgentHost: """ Create and configure an AgentHost with entry points. @@ -1980,6 +2026,7 @@ def create_agent_runtime( skills_catalog_prompt=skills_catalog_prompt, protocols_prompt=protocols_prompt, skill_dirs=skill_dirs, + pipeline_stages=pipeline_stages, ) for spec in entry_points: diff --git a/core/framework/loader/agent_loader.py b/core/framework/loader/agent_loader.py index 829db78c..bcca08ba 100644 --- a/core/framework/loader/agent_loader.py +++ b/core/framework/loader/agent_loader.py @@ -1267,82 +1267,7 @@ class AgentLoader: os.environ["HIVE_AGENT_NAME"] = agent_path.name os.environ["HIVE_STORAGE_PATH"] = str(self._storage_path) - # Load MCP servers: prefer agent.json mcp_servers refs -> global registry - # Fallback to mcp_servers.json if it exists (legacy) - mcp_config_path = agent_path / "mcp_servers.json" - agent_json_path = agent_path / "agent.json" - - logger.info( - "MCP loading: agent_json=%s, mcp_json=%s", - agent_json_path.exists(), - mcp_config_path.exists(), - ) - - mcp_loaded = False - # 1. From agent.json mcp_servers field (resolved via global registry) - if agent_json_path.exists(): - try: - import json as _json - - agent_data = _json.loads(agent_json_path.read_text(encoding="utf-8")) - server_refs = agent_data.get("mcp_servers", []) - if server_refs: - names = [ref["name"] for ref in server_refs if ref.get("name")] - if names: - from framework.loader.mcp_registry import MCPRegistry - - registry = MCPRegistry() - configs = registry.resolve_for_agent(include=names) - if configs: - self._tool_registry.load_registry_servers(configs) - mcp_loaded = True - logger.info( - "Loaded %d MCP servers from registry: %s", - len(configs), - names, - ) - except Exception: - logger.warning("Failed to load MCP servers from agent.json refs", exc_info=True) - - # 2. Legacy: mcp_servers.json file - if not mcp_loaded and mcp_config_path.exists(): - self._load_mcp_servers_from_config(mcp_config_path) - mcp_loaded = True - - # 3. Fallback: load all servers from global registry - if not mcp_loaded: - try: - from framework.loader.mcp_registry import MCPRegistry - - registry = MCPRegistry() - configs = registry.resolve_for_agent(profile="all") - if configs: - self._tool_registry.load_registry_servers(configs) - logger.info( - "Loaded %d MCP servers from global registry (fallback)", - len(configs), - ) - except Exception: - logger.warning("Failed to load MCP servers from global registry", exc_info=True) - - # Auto-discover registry-selected MCP servers from mcp_registry.json - self._load_registry_mcp_servers(agent_path) - - # Summary: how many tools were loaded from MCP servers? - _all_tools = self._tool_registry.get_tools() - logger.info( - "MCP loading complete: %d tools from tool registry", - len(_all_tools), - ) - if not _all_tools: - logger.warning( - "ZERO tools loaded from MCP servers. Workers will only have " - "synthetic tools (set_output, escalate). Check: " - "1) agent.json mcp_servers field, " - "2) ~/.hive/mcp_registry/installed.json, " - "3) MCP server process startup." - ) - + # MCP tools are loaded by McpRegistryStage in the pipeline during AgentHost.start() @staticmethod def _import_agent_module(agent_path: Path): """Import an agent package from its directory path. @@ -1695,228 +1620,148 @@ class AgentLoader: self._approval_callback = callback def _setup(self, event_bus=None) -> None: - """Set up runtime, LLM, and executor.""" - # Configure structured logging (auto-detects JSON vs human-readable) + """Set up runtime via pipeline stages. + + Builds a pipeline with the default stages (LLM, credentials, MCP, + skills) and passes it to AgentHost. The stages initialize during + ``AgentHost.start()`` and inject tools/LLM/credentials/skills. + """ from framework.observability import configure_logging + from framework.pipeline.stages.credential_resolver import CredentialResolverStage + from framework.pipeline.stages.llm_provider import LlmProviderStage + from framework.pipeline.stages.mcp_registry import McpRegistryStage + from framework.pipeline.stages.skill_registry import SkillRegistryStage + from framework.skills.config import SkillsConfig configure_logging(level="INFO", format="auto") - # Set up session context for tools (agent_id) + # Set up session context for tools agent_id = self.graph.id or "unknown" + self._tool_registry.set_session_context(agent_id=agent_id) - self._tool_registry.set_session_context( - agent_id=agent_id, - ) + # Read MCP server refs from agent.json + mcp_refs = [] + agent_json = self.agent_path / "agent.json" + if agent_json.exists(): + try: + import json as _json - # Create LLM provider - # Uses LiteLLM which auto-detects the provider from model name - # Skip if already injected (e.g. worker agents with a pre-built LLM) - if self._llm is not None: - pass # LLM already configured externally - elif self.mock_mode: - # Use mock LLM for testing without real API calls - from framework.llm.mock import MockLLMProvider + data = _json.loads(agent_json.read_text(encoding="utf-8")) + mcp_refs = data.get("mcp_servers", []) + except Exception: + pass - self._llm = MockLLMProvider(model=self.model) - else: - from framework.llm.litellm import LiteLLMProvider - - # Check if a subscription mode is configured - config = get_hive_config() - llm_config = config.get("llm", {}) - use_claude_code = llm_config.get("use_claude_code_subscription", False) - use_codex = llm_config.get("use_codex_subscription", False) - use_kimi_code = llm_config.get("use_kimi_code_subscription", False) - use_antigravity = llm_config.get("use_antigravity_subscription", False) - api_base = llm_config.get("api_base") - - api_key = None - if use_claude_code: - # Get OAuth token from Claude Code subscription - api_key = get_claude_code_token() - if not api_key: - logger.warning( - "Claude Code subscription configured but no token found. " - "Run 'claude' to authenticate, then try again." - ) - elif use_codex: - # Get OAuth token from Codex subscription - api_key = get_codex_token() - if not api_key: - logger.warning( - "Codex subscription configured but no token found. " - "Run 'codex' to authenticate, then try again." - ) - elif use_kimi_code: - # Get API key from Kimi Code CLI config (~/.kimi/config.toml) - api_key = get_kimi_code_token() - if not api_key: - logger.warning( - "Kimi Code subscription configured but no key found. " - "Run 'kimi /login' to authenticate, then try again." - ) - elif use_antigravity: - pass # AntigravityProvider handles credentials internally - - if api_key and use_claude_code: - # Use litellm's built-in Anthropic OAuth support. - # The lowercase "authorization" key triggers OAuth detection which - # adds the required anthropic-beta and browser-access headers. - self._llm = LiteLLMProvider( - model=self.model, - api_key=api_key, - api_base=api_base, - extra_headers={"authorization": f"Bearer {api_key}"}, - ) - elif api_key and use_codex: - # OpenAI Codex subscription routes through the ChatGPT backend - # (chatgpt.com/backend-api/codex/responses), NOT the standard - # OpenAI API. The consumer OAuth token lacks platform API scopes. - extra_headers: dict[str, str] = { - "Authorization": f"Bearer {api_key}", - "User-Agent": "CodexBar", - } - account_id = get_codex_account_id() - if account_id: - extra_headers["ChatGPT-Account-Id"] = account_id - self._llm = LiteLLMProvider( - model=self.model, - api_key=api_key, - api_base="https://chatgpt.com/backend-api/codex", - extra_headers=extra_headers, - store=False, - allowed_openai_params=["store"], - ) - elif api_key and use_kimi_code: - # Kimi Code subscription uses the Kimi coding API (OpenAI-compatible). - # The api_base is set automatically by LiteLLMProvider for kimi/ models. - self._llm = LiteLLMProvider( - model=self.model, - api_key=api_key, - api_base=api_base, - ) - elif use_antigravity: - # Direct OAuth to Google's internal Cloud Code Assist gateway. - # No local proxy required — AntigravityProvider handles token - # refresh and Gemini-format request/response conversion natively. - from framework.llm.antigravity import AntigravityProvider # noqa: PLC0415 - - provider = AntigravityProvider(model=self.model) - if not provider.has_credentials(): - print( - "Warning: Antigravity credentials not found. " - "Run: uv run python core/antigravity_auth.py auth account add" - ) - self._llm = provider - else: - # Local models (e.g. Ollama) don't need an API key - if self._is_local_model(self.model): - self._llm = LiteLLMProvider( - model=self.model, - api_base=api_base, - ) - else: - # Fall back to environment variable - # First check api_key_env_var from config (set by quickstart) - api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var( - self.model - ) - if api_key_env and os.environ.get(api_key_env): - self._llm = LiteLLMProvider( - model=self.model, - api_key=os.environ[api_key_env], - api_base=api_base, - ) - else: - # Fall back to credential store - api_key = self._get_api_key_from_credential_store() - if api_key: - self._llm = LiteLLMProvider( - model=self.model, api_key=api_key, api_base=api_base - ) - # Set env var so downstream code (e.g. cleanup LLM in - # node._extract_json) can also find it - if api_key_env: - os.environ[api_key_env] = api_key - elif api_key_env: - logger.warning( - "%s not set. LLM calls will fail. " - "Set it with: export %s=your-api-key", - api_key_env, - api_key_env, - ) - - # Fail fast if the agent needs an LLM but none was configured - if self._llm is None: - has_llm_nodes = any( - node.node_type in ("event_loop", "gcu") for node in self.graph.nodes - ) - if has_llm_nodes: - from framework.credentials.models import CredentialError - - if self._is_local_model(self.model): - raise CredentialError( - f"Failed to initialize LLM for local model '{self.model}'. " - f"Ensure your local LLM server is running " - f"(e.g. 'ollama serve' for Ollama)." - ) - api_key_env = self._get_api_key_env_var(self.model) - hint = ( - f"Set it with: export {api_key_env}=your-api-key" - if api_key_env - else "Configure an API key for your LLM provider." - ) - raise CredentialError(f"LLM API key not found for model '{self.model}'. {hint}") - - # Get tools for runtime - # (GCU and file tools are now registered via mcp_servers.json or MCP registry, - # not auto-registered here. Agents declare them in agent.json.) - tools = list(self._tool_registry.get_tools().values()) - tool_executor = self._tool_registry.get_executor() - - # Collect connected account info for system prompt injection - accounts_prompt = "" - accounts_data: list[dict] | None = None - tool_provider_map: dict[str, str] | None = None - try: - from aden_tools.credentials.store_adapter import CredentialStoreAdapter - - if self._credential_store is not None: - adapter = CredentialStoreAdapter(store=self._credential_store) - else: - adapter = CredentialStoreAdapter.default() - accounts_data = adapter.get_all_account_info() - tool_provider_map = adapter.get_tool_provider_map() - if accounts_data: - from framework.orchestrator.prompting import build_accounts_prompt - - accounts_prompt = build_accounts_prompt(accounts_data, tool_provider_map) - except Exception: - pass # Best-effort — agent works without account info - - # Skill configuration — the runtime handles discovery, loading, trust-gating and - # prompt rasterization. The runner just builds the config. - from framework.skills.config import SkillsConfig - from framework.skills.manager import SkillsManagerConfig - - skills_manager_config = SkillsManagerConfig( - skills_config=SkillsConfig.from_agent_vars( - default_skills=getattr(self, "_agent_default_skills", None), - skills=getattr(self, "_agent_skills", None), + # Build default pipeline stages + # Default infrastructure stages (always present) + pipeline_stages = [ + LlmProviderStage( + model=self.model, + mock_mode=self.mock_mode, + llm=self._llm, ), - project_root=self.agent_path, - interactive=self._interactive, + CredentialResolverStage( + credential_store=self._credential_store, + ), + McpRegistryStage( + server_refs=mcp_refs, + agent_path=self.agent_path, + tool_registry=self._tool_registry, + ), + SkillRegistryStage( + project_root=self.agent_path, + interactive=self._interactive, + skills_config=SkillsConfig.from_agent_vars( + default_skills=getattr(self, "_agent_default_skills", None), + skills=getattr(self, "_agent_skills", None), + ), + ), + ] + + # Merge user-configured stages from ~/.hive/configuration.json + from framework.config import get_hive_config + from framework.pipeline.registry import build_pipeline_from_config + + hive_config = get_hive_config() + user_stages_config = hive_config.get("pipeline", {}).get("stages", []) + if user_stages_config: + user_pipeline = build_pipeline_from_config(user_stages_config) + pipeline_stages.extend(user_pipeline.stages) + + # Merge agent-level overrides from agent.json pipeline field + if agent_json.exists(): + try: + agent_pipeline = ( + _json.loads(agent_json.read_text(encoding="utf-8")) + .get("pipeline", {}) + .get("stages", []) + ) + if agent_pipeline: + agent_stages = build_pipeline_from_config(agent_pipeline) + pipeline_stages.extend(agent_stages.stages) + except Exception: + pass + + self._setup_agent_runtime_with_pipeline( + pipeline_stages=pipeline_stages, + event_bus=event_bus, ) - self._setup_agent_runtime( - tools, - tool_executor, - accounts_prompt=accounts_prompt, - accounts_data=accounts_data, - tool_provider_map=tool_provider_map, - event_bus=event_bus, - skills_manager_config=skills_manager_config, + def _setup_agent_runtime_with_pipeline( + self, + pipeline_stages: list, + event_bus=None, + ) -> None: + """Create AgentHost with pipeline stages.""" + from framework.host.execution_manager import EntryPointSpec + from framework.orchestrator.checkpoint_config import CheckpointConfig + + entry_points = [] + if self.graph.entry_node: + entry_points.append( + EntryPointSpec( + id="default", + name="Default", + entry_node=self.graph.entry_node, + trigger_type="manual", + isolation_level="shared", + ), + ) + + log_store = RuntimeLogStore( + base_path=self._storage_path / "runtime_logs", ) + checkpoint_config = CheckpointConfig( + enabled=True, + checkpoint_on_node_start=False, + checkpoint_on_node_complete=True, + checkpoint_max_age_days=7, + async_checkpoint=True, + ) + + runtime_config = None + if self.runtime_config is not None: + from framework.host.agent_host import AgentRuntimeConfig + + if isinstance(self.runtime_config, AgentRuntimeConfig): + runtime_config = self.runtime_config + + + self._agent_runtime = create_agent_runtime( + graph=self.graph, + goal=self.goal, + storage_path=self._storage_path, + entry_points=entry_points, + llm=None, # Injected by LlmProviderStage + tools=[], # Injected by McpRegistryStage + tool_executor=None, # Injected by McpRegistryStage + runtime_log_store=log_store, + checkpoint_config=checkpoint_config, + config=runtime_config, + graph_id=self.graph.id or self.agent_path.name, + event_bus=event_bus, + pipeline_stages=pipeline_stages, + ) + self._agent_runtime.intro_message = self.intro_message def _get_api_key_env_var(self, model: str) -> str | None: """Get the environment variable name for the API key based on model name.""" @@ -2009,83 +1854,6 @@ class AgentLoader: ) return model.lower().startswith(LOCAL_PREFIXES) - def _setup_agent_runtime( - self, - tools: list, - tool_executor: Callable | None, - accounts_prompt: str = "", - accounts_data: list[dict] | None = None, - tool_provider_map: dict[str, str] | None = None, - event_bus=None, - skills_catalog_prompt: str = "", - protocols_prompt: str = "", - skill_dirs: list[str] | None = None, - skills_manager_config=None, - ) -> None: - """Set up multi-entry-point execution using AgentRuntime.""" - entry_points = [] - - # Always create a primary entry point for the graph's entry node. - # For multi-entry-point agents this ensures the primary path (e.g. - # user-facing rule setup) is reachable alongside async entry points. - if self.graph.entry_node: - entry_points.insert( - 0, - EntryPointSpec( - id="default", - name="Default", - entry_node=self.graph.entry_node, - trigger_type="manual", - isolation_level="shared", - ), - ) - - # Create AgentRuntime with all entry points - log_store = RuntimeLogStore(base_path=self._storage_path / "runtime_logs") - - # Enable checkpointing by default for resumable sessions - from framework.orchestrator.checkpoint_config import CheckpointConfig - - checkpoint_config = CheckpointConfig( - enabled=True, - checkpoint_on_node_start=False, # Only checkpoint after nodes complete - checkpoint_on_node_complete=True, - checkpoint_max_age_days=7, - async_checkpoint=True, # Non-blocking - ) - - # Handle runtime_config - only pass through if it's actually an AgentRuntimeConfig. - # Agents may export a RuntimeConfig (LLM settings) or queen-generated custom classes - # that would crash AgentRuntime if passed through. - runtime_config = None - if self.runtime_config is not None: - from framework.host.agent_host import AgentRuntimeConfig - - if isinstance(self.runtime_config, AgentRuntimeConfig): - runtime_config = self.runtime_config - - self._agent_runtime = create_agent_runtime( - graph=self.graph, - goal=self.goal, - storage_path=self._storage_path, - entry_points=entry_points, - llm=self._llm, - tools=tools, - tool_executor=tool_executor, - runtime_log_store=log_store, - checkpoint_config=checkpoint_config, - config=runtime_config, - graph_id=self.graph.id or self.agent_path.name, - accounts_prompt=accounts_prompt, - accounts_data=accounts_data, - tool_provider_map=tool_provider_map, - event_bus=event_bus, - skills_manager_config=skills_manager_config, - ) - - # Pass intro_message through for TUI display - self._agent_runtime.intro_message = self.intro_message - # ------------------------------------------------------------------ # Execution modes # diff --git a/core/framework/loader/tool_registry.py b/core/framework/loader/tool_registry.py index 164c2654..4c862e44 100644 --- a/core/framework/loader/tool_registry.py +++ b/core/framework/loader/tool_registry.py @@ -262,15 +262,21 @@ class ToolRegistry: is_error=False, ) + registry_ref = self + def executor(tool_use: ToolUse) -> ToolResult: - if tool_use.name not in self._tools: + # Check if credential files changed (lightweight dir listing). + # If new OAuth tokens appeared, restarts MCP servers to pick them up. + registry_ref.resync_mcp_servers_if_needed() + + if tool_use.name not in registry_ref._tools: return ToolResult( tool_use_id=tool_use.id, content=json.dumps({"error": f"Unknown tool: {tool_use.name}"}), is_error=True, ) - registered = self._tools[tool_use.name] + registered = registry_ref._tools[tool_use.name] try: result = registered.executor(tool_use.input) @@ -922,6 +928,11 @@ class ToolRegistry: clients and re-loads them so the new subprocess picks up the fresh credentials. + Note: Individual credential TTL/refresh is handled by the MCP server + process internally -- it resolves tokens from the credential store + on every tool call, not at startup. This method only handles the case + where entirely new credential files appear. + Returns True if a resync was performed, False otherwise. """ if not self._mcp_clients or self._mcp_config_path is None: diff --git a/core/framework/pipeline/stages/credential_resolver.py b/core/framework/pipeline/stages/credential_resolver.py index 333dbd66..b76df37f 100644 --- a/core/framework/pipeline/stages/credential_resolver.py +++ b/core/framework/pipeline/stages/credential_resolver.py @@ -1,9 +1,8 @@ """Credential resolver pipeline stage. -Resolves connected accounts from the credential store and builds -the ``accounts_prompt`` and ``tool_provider_map`` for system prompt -injection. Replaces the credential resolution block in -``AgentLoader._setup()`` (lines 1861-1879). +Resolves connected accounts at startup. Individual credential TTL/refresh +is handled by MCP server processes internally -- they resolve tokens from +the credential store on every tool call. """ from __future__ import annotations @@ -19,37 +18,41 @@ logger = logging.getLogger(__name__) @register("credential_resolver") class CredentialResolverStage(PipelineStage): - """Resolve connected accounts and inject into pipeline context.""" + """Resolve connected accounts for system prompt injection.""" - order = 40 # before MCP (tools need account info for routing) + order = 40 - def __init__(self, **kwargs: Any) -> None: - self._accounts_prompt = "" - self._accounts_data: list[dict] | None = None - self._tool_provider_map: dict[str, str] | None = None + def __init__(self, credential_store: Any = None, **kwargs: Any) -> None: + self._credential_store = credential_store + self.accounts_prompt = "" + self.accounts_data: list[dict] | None = None + self.tool_provider_map: dict[str, str] | None = None async def initialize(self) -> None: - """Resolve credentials from the store.""" try: from aden_tools.credentials.store_adapter import ( CredentialStoreAdapter, ) from framework.orchestrator.prompting import build_accounts_prompt - adapter = CredentialStoreAdapter.default() - self._accounts_data = adapter.get_all_account_info() - self._tool_provider_map = adapter.get_tool_provider_map() - if self._accounts_data: - self._accounts_prompt = build_accounts_prompt( - self._accounts_data, - self._tool_provider_map, + if self._credential_store is not None: + adapter = CredentialStoreAdapter(store=self._credential_store) + else: + adapter = CredentialStoreAdapter.default() + self.accounts_data = adapter.get_all_account_info() + self.tool_provider_map = adapter.get_tool_provider_map() + if self.accounts_data: + self.accounts_prompt = build_accounts_prompt( + self.accounts_data, self.tool_provider_map, ) + logger.info( + "[pipeline] CredentialResolverStage: %d accounts", + len(self.accounts_data or []), + ) except Exception: - pass # best-effort -- agent works without account info + logger.debug( + "Credential resolution failed (non-fatal)", exc_info=True, + ) async def process(self, ctx: PipelineContext) -> PipelineResult: - """Inject credential info into pipeline context.""" - ctx.metadata["accounts_prompt"] = self._accounts_prompt - ctx.metadata["accounts_data"] = self._accounts_data - ctx.metadata["tool_provider_map"] = self._tool_provider_map return PipelineResult(action="continue") diff --git a/core/framework/pipeline/stages/llm_provider.py b/core/framework/pipeline/stages/llm_provider.py index d3463401..899342f2 100644 --- a/core/framework/pipeline/stages/llm_provider.py +++ b/core/framework/pipeline/stages/llm_provider.py @@ -1,17 +1,7 @@ """LLM provider pipeline stage. -Resolves the LLM provider (model, API key, OAuth token) from the -global config and injects it into the pipeline context. Replaces -the 150-line provider resolution block in ``AgentLoader._setup()``. - -Supports all auth methods: -- Claude Code subscription (OAuth token from ~/.claude/.credentials.json) -- Codex subscription (Keychain / ~/.codex/auth.json) -- Kimi Code subscription (~/.kimi/config.toml) -- Antigravity (Google Cloud Code Assist OAuth) -- Environment variable (ANTHROPIC_API_KEY, OPENAI_API_KEY, etc.) -- Key pool (multiple keys with rotation) -- Local models (Ollama, no key needed) +Resolves the LLM provider from global config. This is the ONLY place +the LLM gets created for worker agents. """ from __future__ import annotations @@ -27,84 +17,79 @@ logger = logging.getLogger(__name__) @register("llm_provider") class LlmProviderStage(PipelineStage): - """Resolve LLM provider and inject into pipeline context.""" + """Resolve LLM provider and make it available.""" - order = 10 # earliest -- everything else depends on having an LLM + order = 10 def __init__( self, model: str | None = None, mock_mode: bool = False, + llm: Any = None, **kwargs: Any, ) -> None: self._model = model self._mock_mode = mock_mode - self._llm: Any = None + self.llm = llm # Pre-injected LLM (e.g. from session) async def initialize(self) -> None: - """Resolve and create the LLM provider.""" - from framework.config import get_api_key, get_api_keys, get_hive_config, get_preferred_model + if self.llm is not None: + return # Already injected + + from framework.config import ( + get_api_key, + get_api_keys, + get_hive_config, + get_preferred_model, + ) model = self._model or get_preferred_model() if self._mock_mode: from framework.llm.mock import MockLLMProvider - self._llm = MockLLMProvider(model=model) + self.llm = MockLLMProvider(model=model) return config = get_hive_config() llm_config = config.get("llm", {}) api_base = llm_config.get("api_base") - api_key = get_api_key() - api_keys = get_api_keys() - from framework.llm.litellm import LiteLLMProvider - - # Check for Antigravity (special provider, not LiteLLM) + # Check for Antigravity (special provider) if llm_config.get("use_antigravity_subscription"): try: from framework.llm.antigravity import AntigravityProvider provider = AntigravityProvider(model=model) if provider.has_credentials(): - self._llm = provider + self.llm = provider + logger.info("[pipeline] LlmProviderStage: Antigravity") return except Exception: pass - # Key pool or single key + from framework.llm.litellm import LiteLLMProvider + + api_key = get_api_key() + api_keys = get_api_keys() + if api_keys and len(api_keys) > 1: - self._llm = LiteLLMProvider( - model=model, - api_keys=api_keys, - api_base=api_base, + self.llm = LiteLLMProvider( + model=model, api_keys=api_keys, api_base=api_base, ) elif api_key: - # Detect OAuth subscriptions for special headers - is_claude_oauth = api_key.startswith("sk-ant-oat") - if is_claude_oauth: - self._llm = LiteLLMProvider( - model=model, - api_key=api_key, - api_base=api_base, - extra_headers={"authorization": f"Bearer {api_key}"}, - ) - else: - self._llm = LiteLLMProvider( - model=model, - api_key=api_key, - api_base=api_base, - ) - else: - # No key -- local models or env var fallback - self._llm = LiteLLMProvider( - model=model, - api_base=api_base, + extra = {} + if api_key.startswith("sk-ant-oat"): + extra["extra_headers"] = { + "authorization": f"Bearer {api_key}" + } + self.llm = LiteLLMProvider( + model=model, api_key=api_key, api_base=api_base, **extra, ) + else: + self.llm = LiteLLMProvider(model=model, api_base=api_base) + + logger.info("[pipeline] LlmProviderStage: %s", model) async def process(self, ctx: PipelineContext) -> PipelineResult: - """Inject LLM provider into pipeline context.""" - if self._llm: - ctx.metadata["llm"] = self._llm return PipelineResult(action="continue") diff --git a/core/framework/pipeline/stages/mcp_registry.py b/core/framework/pipeline/stages/mcp_registry.py index 215781d0..989cfd98 100644 --- a/core/framework/pipeline/stages/mcp_registry.py +++ b/core/framework/pipeline/stages/mcp_registry.py @@ -1,21 +1,13 @@ """MCP registry pipeline stage. Resolves MCP server references from the agent config against the global -registry (``~/.hive/mcp_registry/installed.json``) and registers tools. -Replaces the per-agent ``mcp_servers.json`` pattern with declarative -name-based references. - -Agent config declares servers by name:: - - {"mcp_servers": [{"name": "hive-tools"}, {"name": "gcu-tools"}]} - -The stage resolves each name from the global registry at ``initialize()`` -time and injects the resolved ``ToolRegistry`` into the pipeline context. +registry and registers tools. This is the ONLY place MCP tools get loaded. """ from __future__ import annotations import logging +from dataclasses import asdict from pathlib import Path from typing import Any @@ -27,12 +19,7 @@ logger = logging.getLogger(__name__) @register("mcp_registry") class McpRegistryStage(PipelineStage): - """Resolve MCP tools from the global registry. - - On ``initialize()``, connects to MCP servers declared in the agent - config. On ``process()``, injects ``tools`` and ``tool_executor`` - into the pipeline context metadata for downstream consumption. - """ + """Resolve MCP tools from the global registry.""" order = 50 @@ -40,55 +27,66 @@ class McpRegistryStage(PipelineStage): self, server_refs: list[dict[str, Any]] | None = None, agent_path: str | Path | None = None, + tool_registry: Any = None, **kwargs: Any, ) -> None: - """ - Args: - server_refs: List of ``{"name": "server-name"}`` dicts from - the agent config's ``mcp_servers`` field. - agent_path: Path to the agent directory. If a ``mcp_servers.json`` - file exists there, it's loaded as a fallback. - """ self._server_refs = server_refs or [] self._agent_path = Path(agent_path) if agent_path else None - self._tool_registry: Any = None + self._tool_registry = tool_registry async def initialize(self) -> None: """Connect to MCP servers and discover tools.""" + if self._tool_registry is None: + from framework.loader.tool_registry import ToolRegistry + + self._tool_registry = ToolRegistry() + from framework.loader.mcp_registry import MCPRegistry - from framework.loader.tool_registry import ToolRegistry - self._tool_registry = ToolRegistry() registry = MCPRegistry() + mcp_loaded = False - # 1. Resolve named server refs from global registry + # 1. From agent.json mcp_servers refs if self._server_refs: names = [ref["name"] for ref in self._server_refs if ref.get("name")] if names: configs = registry.resolve_for_agent(include=names) if configs: - self._tool_registry.load_registry_servers(configs) + self._tool_registry.load_registry_servers( + [asdict(c) for c in configs] + ) + mcp_loaded = True logger.info( - "McpRegistryStage: resolved %d servers from registry", + "[pipeline] McpRegistryStage: loaded %d servers: %s", len(configs), + names, ) - # 2. Fallback: load mcp_servers.json if it exists (backward compat) - if self._agent_path: + # 2. Legacy: mcp_servers.json + if not mcp_loaded and self._agent_path: mcp_json = self._agent_path / "mcp_servers.json" if mcp_json.exists(): self._tool_registry.load_mcp_config(mcp_json) + mcp_loaded = True + + # 3. Fallback: all servers from global registry + if not mcp_loaded: + configs = registry.resolve_for_agent(profile="all") + if configs: + self._tool_registry.load_registry_servers( + [asdict(c) for c in configs] + ) logger.info( - "McpRegistryStage: loaded mcp_servers.json from %s", - self._agent_path.name, + "[pipeline] McpRegistryStage: loaded %d servers (fallback)", + len(configs), ) + total = len(self._tool_registry.get_tools()) + logger.info("[pipeline] McpRegistryStage: %d tools available", total) + async def process(self, ctx: PipelineContext) -> PipelineResult: - """Inject resolved tools into pipeline context.""" - if self._tool_registry: - ctx.metadata["tool_registry"] = self._tool_registry - ctx.metadata["tools"] = list( - self._tool_registry.get_tools().values() - ) - ctx.metadata["tool_executor"] = self._tool_registry.get_executor() return PipelineResult(action="continue") + + @property + def tool_registry(self): + return self._tool_registry diff --git a/core/framework/pipeline/stages/skill_registry.py b/core/framework/pipeline/stages/skill_registry.py index e9ea7ebf..71a73a69 100644 --- a/core/framework/pipeline/stages/skill_registry.py +++ b/core/framework/pipeline/stages/skill_registry.py @@ -1,12 +1,6 @@ """Skill registry pipeline stage. -Discovers and loads skills, injects skill prompts into the pipeline -context. Replaces the standalone ``SkillsManager`` initialization -in ``AgentHost.__init__()``. - -Supports hot-reload: when ``SKILL.md`` files change on disk, the -cached prompts are rebuilt and the next pipeline execution picks -up the new values. +Discovers and loads skills. This is the ONLY place skills get loaded. """ from __future__ import annotations @@ -23,7 +17,7 @@ logger = logging.getLogger(__name__) @register("skill_registry") class SkillRegistryStage(PipelineStage): - """Discover skills and inject prompts into pipeline context.""" + """Discover skills and provide prompts.""" order = 60 @@ -31,34 +25,31 @@ class SkillRegistryStage(PipelineStage): self, project_root: str | Path | None = None, interactive: bool = True, + skills_config: Any = None, **kwargs: Any, ) -> None: self._project_root = Path(project_root) if project_root else None self._interactive = interactive - self._skills_manager: Any = None + self._skills_config = skills_config + self.skills_manager: Any = None async def initialize(self) -> None: - """Discover skills and start hot-reload watcher.""" + from framework.skills.config import SkillsConfig from framework.skills.manager import SkillsManager, SkillsManagerConfig config = SkillsManagerConfig( + skills_config=self._skills_config or SkillsConfig(), project_root=self._project_root, interactive=self._interactive, ) - self._skills_manager = SkillsManager(config) - self._skills_manager.load() - await self._skills_manager.start_watching() + self.skills_manager = SkillsManager(config) + self.skills_manager.load() + await self.skills_manager.start_watching() + logger.info( + "[pipeline] SkillRegistryStage: catalog=%d chars, protocols=%d chars", + len(self.skills_manager.skills_catalog_prompt), + len(self.skills_manager.protocols_prompt), + ) async def process(self, ctx: PipelineContext) -> PipelineResult: - """Inject skill prompts into pipeline context.""" - if self._skills_manager: - ctx.metadata["skills_catalog_prompt"] = ( - self._skills_manager.skills_catalog_prompt - ) - ctx.metadata["protocols_prompt"] = ( - self._skills_manager.protocols_prompt - ) - ctx.metadata["skill_dirs"] = ( - self._skills_manager.allowlisted_dirs - ) return PipelineResult(action="continue")