Files
hive/core/framework/loader/agent_loader.py
T
2026-04-28 19:16:35 -07:00

2109 lines
75 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Agent Runner - loads and runs exported agents."""
import json
import logging
import os
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import UTC
from pathlib import Path
from typing import Any
from framework.config import get_hive_config, get_preferred_model
from framework.credentials.validation import (
ensure_credential_key_env as _ensure_credential_key_env,
)
from framework.host.agent_host import AgentHost, AgentRuntimeConfig
from framework.host.execution_manager import EntryPointSpec
from framework.llm.provider import LLMProvider, Tool
from framework.loader.preload_validation import run_preload_validation
from framework.loader.tool_registry import ToolRegistry
from framework.orchestrator import Goal
from framework.orchestrator.edge import (
EdgeCondition,
EdgeSpec,
GraphSpec,
)
from framework.orchestrator.node import NodeSpec
from framework.orchestrator.orchestrator import ExecutionResult
logger = logging.getLogger(__name__)
CLAUDE_CREDENTIALS_FILE = Path.home() / ".claude" / ".credentials.json"
CLAUDE_OAUTH_TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"
CLAUDE_OAUTH_CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
CLAUDE_KEYCHAIN_SERVICE = "Claude Code-credentials"
# Buffer in seconds before token expiry to trigger a proactive refresh
_TOKEN_REFRESH_BUFFER_SECS = 300 # 5 minutes
# Codex (OpenAI) subscription auth
CODEX_AUTH_FILE = Path.home() / ".codex" / "auth.json"
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
CODEX_KEYCHAIN_SERVICE = "Codex Auth"
_CODEX_TOKEN_LIFETIME_SECS = 3600 # 1 hour (no explicit expiry field)
def _read_claude_keychain() -> dict | None:
"""Read Claude Code credentials from macOS Keychain.
Returns the parsed JSON dict, or None if not on macOS or entry missing.
"""
import getpass
import platform
import subprocess
if platform.system() != "Darwin":
return None
try:
account = getpass.getuser()
result = subprocess.run(
[
"security",
"find-generic-password",
"-s",
CLAUDE_KEYCHAIN_SERVICE,
"-a",
account,
"-w",
],
capture_output=True,
encoding="utf-8",
timeout=5,
)
if result.returncode != 0:
return None
raw = result.stdout.strip()
if not raw:
return None
return json.loads(raw)
except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError) as exc:
logger.debug("Claude keychain read failed: %s", exc)
return None
def _save_claude_keychain(creds: dict) -> bool:
"""Write Claude Code credentials to macOS Keychain. Returns True on success."""
import getpass
import platform
import subprocess
if platform.system() != "Darwin":
return False
try:
account = getpass.getuser()
data = json.dumps(creds)
result = subprocess.run(
[
"security",
"add-generic-password",
"-U",
"-s",
CLAUDE_KEYCHAIN_SERVICE,
"-a",
account,
"-w",
data,
],
capture_output=True,
timeout=5,
)
return result.returncode == 0
except (subprocess.TimeoutExpired, OSError) as exc:
logger.debug("Claude keychain write failed: %s", exc)
return False
def _read_claude_credentials() -> dict | None:
"""Read Claude Code credentials from Keychain (macOS) or file (Linux/Windows)."""
# Try macOS Keychain first
creds = _read_claude_keychain()
if creds:
return creds
# Fall back to file
if not CLAUDE_CREDENTIALS_FILE.exists():
return None
try:
with open(CLAUDE_CREDENTIALS_FILE, encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, OSError):
return None
def _refresh_claude_code_token(refresh_token: str) -> dict | None:
"""Refresh the Claude Code OAuth token using the refresh token.
POSTs to the Anthropic OAuth token endpoint with form-urlencoded data
(per OAuth 2.0 RFC 6749 Section 4.1.3).
Returns:
Dict with new token data (access_token, refresh_token, expires_in)
on success, None on failure.
"""
import urllib.error
import urllib.parse
import urllib.request
data = urllib.parse.urlencode(
{
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": CLAUDE_OAUTH_CLIENT_ID,
}
).encode("utf-8")
req = urllib.request.Request(
CLAUDE_OAUTH_TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=15) as resp:
return json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
logger.debug("Claude Code token refresh failed: %s", exc)
return None
def _save_refreshed_credentials(token_data: dict) -> None:
"""Write refreshed token data back to Keychain (macOS) or credentials file."""
import time
creds = _read_claude_credentials()
if not creds:
return
try:
oauth = creds.get("claudeAiOauth", {})
oauth["accessToken"] = token_data["access_token"]
if "refresh_token" in token_data:
oauth["refreshToken"] = token_data["refresh_token"]
if "expires_in" in token_data:
oauth["expiresAt"] = int((time.time() + token_data["expires_in"]) * 1000)
creds["claudeAiOauth"] = oauth
# Try Keychain first (macOS), fall back to file
if _save_claude_keychain(creds):
logger.debug("Claude Code credentials refreshed in Keychain")
return
if CLAUDE_CREDENTIALS_FILE.exists():
with open(CLAUDE_CREDENTIALS_FILE, "w", encoding="utf-8") as f:
json.dump(creds, f, indent=2)
logger.debug("Claude Code credentials refreshed in file")
except (json.JSONDecodeError, OSError, KeyError) as exc:
logger.debug("Failed to save refreshed credentials: %s", exc)
def get_claude_code_token() -> str | None:
"""Get the OAuth token from Claude Code subscription with auto-refresh.
Reads from macOS Keychain (on Darwin) or ~/.claude/.credentials.json
(on Linux/Windows), as created by the Claude Code CLI.
If the token is expired or close to expiry, attempts an automatic
refresh using the stored refresh token.
Returns:
The access token if available, None otherwise.
"""
import time
creds = _read_claude_credentials()
if not creds:
return None
oauth = creds.get("claudeAiOauth", {})
access_token = oauth.get("accessToken")
if not access_token:
return None
# Check token expiry (expiresAt is in milliseconds)
expires_at_ms = oauth.get("expiresAt", 0)
now_ms = int(time.time() * 1000)
buffer_ms = _TOKEN_REFRESH_BUFFER_SECS * 1000
if expires_at_ms > now_ms + buffer_ms:
# Token is still valid
return access_token
# Token is expired or near expiry — attempt refresh
refresh_token = oauth.get("refreshToken")
if not refresh_token:
logger.warning("Claude Code token expired and no refresh token available")
return access_token # Return expired token; it may still work briefly
logger.info("Claude Code token expired or near expiry, refreshing...")
token_data = _refresh_claude_code_token(refresh_token)
if token_data and "access_token" in token_data:
_save_refreshed_credentials(token_data)
return token_data["access_token"]
# Refresh failed — return the existing token and warn
logger.warning("Claude Code token refresh failed. Run 'claude' to re-authenticate.")
return access_token
# ---------------------------------------------------------------------------
# Codex (OpenAI) subscription token helpers
# ---------------------------------------------------------------------------
def _get_codex_keychain_account() -> str:
"""Compute the macOS Keychain account name used by the Codex CLI.
The Codex CLI stores credentials under the account
``cli|<sha256(~/.codex)[:16]>`` in the ``Codex Auth`` service.
"""
import hashlib
codex_dir = str(Path.home() / ".codex")
digest = hashlib.sha256(codex_dir.encode()).hexdigest()[:16]
return f"cli|{digest}"
def _read_codex_keychain() -> dict | None:
"""Read Codex auth data from macOS Keychain (macOS only).
Returns the parsed JSON from the Keychain entry, or None if not
available (wrong platform, entry missing, etc.).
"""
import platform
import subprocess
if platform.system() != "Darwin":
return None
try:
account = _get_codex_keychain_account()
result = subprocess.run(
[
"security",
"find-generic-password",
"-s",
CODEX_KEYCHAIN_SERVICE,
"-a",
account,
"-w",
],
capture_output=True,
encoding="utf-8",
timeout=5,
)
if result.returncode != 0:
return None
raw = result.stdout.strip()
if not raw:
return None
return json.loads(raw)
except (subprocess.TimeoutExpired, json.JSONDecodeError, OSError) as exc:
logger.debug("Codex keychain read failed: %s", exc)
return None
def _read_codex_auth_file() -> dict | None:
"""Read Codex auth data from ~/.codex/auth.json (fallback)."""
if not CODEX_AUTH_FILE.exists():
return None
try:
with open(CODEX_AUTH_FILE, encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, OSError):
return None
def _is_codex_token_expired(auth_data: dict) -> bool:
"""Check whether the Codex token is expired or close to expiry.
The Codex auth.json has no explicit ``expiresAt`` field, so we infer
expiry as ``last_refresh + _CODEX_TOKEN_LIFETIME_SECS``. Falls back
to the file mtime when ``last_refresh`` is absent.
"""
import time
from datetime import datetime
now = time.time()
last_refresh = auth_data.get("last_refresh")
if last_refresh is None:
# Fall back to file modification time
try:
last_refresh = CODEX_AUTH_FILE.stat().st_mtime
except OSError:
# Cannot determine age — assume expired
return True
elif isinstance(last_refresh, str):
# Codex stores last_refresh as an ISO 8601 timestamp string —
# convert to Unix epoch float for arithmetic.
try:
last_refresh = datetime.fromisoformat(last_refresh.replace("Z", "+00:00")).timestamp()
except (ValueError, TypeError):
return True
expires_at = last_refresh + _CODEX_TOKEN_LIFETIME_SECS
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
def _refresh_codex_token(refresh_token: str) -> dict | None:
"""Refresh the Codex OAuth token using the refresh token.
POSTs to the OpenAI auth endpoint with form-urlencoded data.
Returns:
Dict with new token data on success, None on failure.
"""
import urllib.error
import urllib.parse
import urllib.request
data = urllib.parse.urlencode(
{
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": CODEX_OAUTH_CLIENT_ID,
}
).encode("utf-8")
req = urllib.request.Request(
CODEX_OAUTH_TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=15) as resp:
return json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
logger.debug("Codex token refresh failed: %s", exc)
return None
def _save_refreshed_codex_credentials(auth_data: dict, token_data: dict) -> None:
"""Write refreshed tokens back to ~/.codex/auth.json only (not Keychain).
The Codex CLI manages its own Keychain entries, so we only update the
file-based credentials.
"""
from datetime import datetime
try:
tokens = auth_data.get("tokens", {})
tokens["access_token"] = token_data["access_token"]
if "refresh_token" in token_data:
tokens["refresh_token"] = token_data["refresh_token"]
if "id_token" in token_data:
tokens["id_token"] = token_data["id_token"]
auth_data["tokens"] = tokens
auth_data["last_refresh"] = datetime.now(UTC).isoformat()
CODEX_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
fd = os.open(CODEX_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(auth_data, f, indent=2)
logger.debug("Codex credentials refreshed successfully")
except (OSError, KeyError) as exc:
logger.debug("Failed to save refreshed Codex credentials: %s", exc)
def get_codex_token() -> str | None:
"""Get the OAuth token from Codex subscription with auto-refresh.
Reads from macOS Keychain first, then falls back to
``~/.codex/auth.json``. If the token is expired or close to
expiry, attempts an automatic refresh.
Returns:
The access token if available, None otherwise.
"""
# Try Keychain first, then file
auth_data = _read_codex_keychain() or _read_codex_auth_file()
if not auth_data:
return None
tokens = auth_data.get("tokens", {})
access_token = tokens.get("access_token")
if not access_token:
return None
# Check if token is still valid
if not _is_codex_token_expired(auth_data):
return access_token
# Token is expired or near expiry — attempt refresh
refresh_token = tokens.get("refresh_token")
if not refresh_token:
logger.warning("Codex token expired and no refresh token available")
return access_token # Return expired token; it may still work briefly
logger.info("Codex token expired or near expiry, refreshing...")
token_data = _refresh_codex_token(refresh_token)
if token_data and "access_token" in token_data:
_save_refreshed_codex_credentials(auth_data, token_data)
return token_data["access_token"]
# Refresh failed — return the existing token and warn
logger.warning("Codex token refresh failed. Run 'codex' to re-authenticate.")
return access_token
def _get_account_id_from_jwt(access_token: str) -> str | None:
"""Extract the ChatGPT account_id from the access token JWT.
The OpenAI access token JWT contains a claim at
``https://api.openai.com/auth`` with a ``chatgpt_account_id`` field.
This is used as a fallback when the auth.json doesn't store the
account_id explicitly.
"""
import base64
try:
parts = access_token.split(".")
if len(parts) != 3:
return None
payload = parts[1]
# Add base64 padding
padding = 4 - len(payload) % 4
if padding != 4:
payload += "=" * padding
decoded = base64.urlsafe_b64decode(payload)
claims = json.loads(decoded)
auth = claims.get("https://api.openai.com/auth")
if isinstance(auth, dict):
account_id = auth.get("chatgpt_account_id")
if isinstance(account_id, str) and account_id:
return account_id
except Exception:
pass
return None
def get_codex_account_id() -> str | None:
"""Extract the account ID from Codex auth data for the ChatGPT-Account-Id header.
Checks the ``tokens.account_id`` field first, then falls back to
decoding the account ID from the access token JWT.
Returns:
The account_id string if available, None otherwise.
"""
auth_data = _read_codex_keychain() or _read_codex_auth_file()
if not auth_data:
return None
tokens = auth_data.get("tokens", {})
account_id = tokens.get("account_id")
if account_id:
return account_id
# Fallback: extract from JWT
access_token = tokens.get("access_token")
if access_token:
return _get_account_id_from_jwt(access_token)
return None
# ---------------------------------------------------------------------------
# Kimi Code subscription token helpers
# ---------------------------------------------------------------------------
def get_kimi_code_token() -> str | None:
"""Get the API key from a Kimi Code CLI installation.
Reads the API key from ``~/.kimi/config.toml``, which is created when
the user runs ``kimi /login`` in the Kimi Code CLI.
Returns:
The API key if available, None otherwise.
"""
import tomllib
config_path = Path.home() / ".kimi" / "config.toml"
if not config_path.exists():
return None
try:
with open(config_path, "rb") as f:
config = tomllib.load(f)
providers = config.get("providers", {})
# kimi-cli stores credentials under providers.kimi-for-coding
for provider_cfg in providers.values():
if isinstance(provider_cfg, dict):
key = provider_cfg.get("api_key")
if key:
return key
except Exception:
pass
return None
# ---------------------------------------------------------------------------
# Antigravity subscription token helpers
# ---------------------------------------------------------------------------
# Antigravity IDE (native macOS/Linux app) stores OAuth tokens in its
# VSCode-style SQLite state database under the key
# "antigravityUnifiedStateSync.oauthToken" as a base64-encoded protobuf blob.
ANTIGRAVITY_IDE_STATE_DB = (
Path.home() / "Library" / "Application Support" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
)
# Linux fallback for the IDE state DB
ANTIGRAVITY_IDE_STATE_DB_LINUX = Path.home() / ".config" / "Antigravity" / "User" / "globalStorage" / "state.vscdb"
# Antigravity credentials stored by native OAuth implementation
from framework.config import HIVE_HOME as _HIVE_HOME
ANTIGRAVITY_AUTH_FILE = _HIVE_HOME / "antigravity-accounts.json"
ANTIGRAVITY_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
_ANTIGRAVITY_TOKEN_LIFETIME_SECS = 3600 # Google access tokens expire in 1 hour
_ANTIGRAVITY_IDE_STATE_DB_KEY = "antigravityUnifiedStateSync.oauthToken"
def _read_antigravity_ide_credentials() -> dict | None:
"""Read credentials from the Antigravity IDE's SQLite state database.
The Antigravity desktop IDE (VSCode-based) stores its OAuth token as a
base64-encoded protobuf blob in a SQLite database. The access token is
a standard Google OAuth ``ya29.*`` bearer token.
Returns:
Dict with ``accessToken`` and optionally ``refreshToken`` keys,
plus ``_source: "ide"`` to skip file-based save on refresh.
Returns None if the database is absent or the key is not found.
"""
import re
import sqlite3
for db_path in (ANTIGRAVITY_IDE_STATE_DB, ANTIGRAVITY_IDE_STATE_DB_LINUX):
if not db_path.exists():
continue
try:
con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
try:
row = con.execute(
"SELECT value FROM ItemTable WHERE key = ?",
(_ANTIGRAVITY_IDE_STATE_DB_KEY,),
).fetchone()
finally:
con.close()
if not row:
continue
import base64
blob = base64.b64decode(row[0])
# The protobuf blob contains the access token (ya29.*) and
# refresh token (1//*) as length-prefixed UTF-8 strings.
# Decode the inner base64 layer and extract with regex.
inner_b64_candidates = re.findall(rb"[A-Za-z0-9+/=_\-]{40,}", blob)
access_token: str | None = None
refresh_token: str | None = None
for candidate in inner_b64_candidates:
try:
padded = candidate + b"=" * (-len(candidate) % 4)
inner = base64.urlsafe_b64decode(padded)
except Exception:
continue
if not access_token:
m = re.search(rb"ya29\.[A-Za-z0-9_\-\.]+", inner)
if m:
access_token = m.group(0).decode("ascii")
if not refresh_token:
m = re.search(rb"1//[A-Za-z0-9_\-\.]+", inner)
if m:
refresh_token = m.group(0).decode("ascii")
if access_token and refresh_token:
break
if access_token:
return {
"accounts": [
{
"accessToken": access_token,
"refreshToken": refresh_token or "",
}
],
"_source": "ide",
"_db_path": str(db_path),
}
except Exception as exc:
logger.debug("Failed to read Antigravity IDE state DB: %s", exc)
continue
return None
def _read_antigravity_credentials() -> dict | None:
"""Read Antigravity auth data from all supported credential sources.
Checks in order:
1. Antigravity IDE SQLite state database (native macOS/Linux app)
2. Native OAuth credentials file (~/.hive/antigravity-accounts.json)
Returns:
Auth data dict with an ``accounts`` list on success, None otherwise.
"""
# 1. Native Antigravity IDE (primary on macOS)
ide_creds = _read_antigravity_ide_credentials()
if ide_creds:
return ide_creds
# 2. Native OAuth credentials file
if ANTIGRAVITY_AUTH_FILE.exists():
try:
with open(ANTIGRAVITY_AUTH_FILE, encoding="utf-8") as f:
data = json.load(f)
accounts = data.get("accounts", [])
if accounts and isinstance(accounts[0], dict):
return data
except (json.JSONDecodeError, OSError):
pass
return None
def _is_antigravity_token_expired(auth_data: dict) -> bool:
"""Check whether the Antigravity access token is expired or near expiry.
For IDE-sourced credentials: uses the state DB's mtime as last_refresh
since the IDE keeps the DB fresh while it's running.
For JSON-sourced credentials: uses the ``last_refresh`` field or file mtime.
"""
import time
from datetime import datetime
now = time.time()
if auth_data.get("_source") == "ide":
# The IDE refreshes tokens automatically while running.
# Use the DB file's mtime as a proxy for when the token was last updated.
try:
db_path = Path(auth_data.get("_db_path", str(ANTIGRAVITY_IDE_STATE_DB)))
last_refresh: float = db_path.stat().st_mtime
except OSError:
return True
expires_at = last_refresh + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
last_refresh_val: float | str | None = auth_data.get("last_refresh")
if last_refresh_val is None:
try:
last_refresh_val = ANTIGRAVITY_AUTH_FILE.stat().st_mtime
except OSError:
return True
elif isinstance(last_refresh_val, str):
try:
last_refresh_val = datetime.fromisoformat(last_refresh_val.replace("Z", "+00:00")).timestamp()
except (ValueError, TypeError):
return True
expires_at = float(last_refresh_val) + _ANTIGRAVITY_TOKEN_LIFETIME_SECS
return now >= (expires_at - _TOKEN_REFRESH_BUFFER_SECS)
def _refresh_antigravity_token(refresh_token: str) -> dict | None:
"""Refresh the Antigravity access token via Google OAuth.
POSTs form-encoded ``grant_type=refresh_token`` to the Google token
endpoint using Antigravity's public OAuth client ID.
Returns:
Parsed response dict (containing ``access_token``) on success,
None on any error.
"""
import urllib.error
import urllib.parse
import urllib.request
from framework.config import get_antigravity_client_id, get_antigravity_client_secret
client_id = get_antigravity_client_id()
client_secret = get_antigravity_client_secret()
params: dict = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": client_id,
}
if client_secret:
params["client_secret"] = client_secret
data = urllib.parse.urlencode(params).encode("utf-8")
req = urllib.request.Request(
ANTIGRAVITY_OAUTH_TOKEN_URL,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=15) as resp: # noqa: S310
return json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, TimeoutError, OSError) as exc:
logger.debug("Antigravity token refresh failed: %s", exc)
return None
def _save_refreshed_antigravity_credentials(auth_data: dict, token_data: dict) -> None:
"""Write refreshed tokens back to the Antigravity JSON credentials file.
Skipped for IDE-sourced credentials (the IDE manages its own DB).
Updates ``accounts[0].accessToken`` (and ``refreshToken`` if present),
then persists ``last_refresh`` as an ISO-8601 UTC string.
"""
from datetime import datetime
# IDE manages its own state — we do not write back to its SQLite DB
if auth_data.get("_source") == "ide":
return
try:
accounts = auth_data.get("accounts", [])
if not accounts:
return
account = accounts[0]
account["accessToken"] = token_data["access_token"]
if "refresh_token" in token_data:
account["refreshToken"] = token_data["refresh_token"]
auth_data["accounts"] = accounts
auth_data["last_refresh"] = datetime.now(UTC).isoformat()
ANTIGRAVITY_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
fd = os.open(ANTIGRAVITY_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as f:
json.dump(auth_data, f, indent=2)
logger.debug("Antigravity credentials refreshed and saved")
except (OSError, KeyError) as exc:
logger.debug("Failed to save refreshed Antigravity credentials: %s", exc)
def get_antigravity_token() -> str | None:
"""Get the OAuth access token from an Antigravity subscription.
Credential sources checked in order:
1. Antigravity IDE SQLite state DB (native app, macOS/Linux)
2. antigravity-auth CLI JSON file
For IDE credentials the token is read directly (the IDE refreshes it
automatically while running). For JSON credentials an automatic OAuth
refresh is attempted when the token is near expiry.
Returns:
The ``ya29.*`` Google OAuth access token, or None if unavailable.
"""
auth_data = _read_antigravity_credentials()
if not auth_data:
return None
accounts = auth_data.get("accounts", [])
if not accounts:
return None
account = accounts[0]
access_token = account.get("accessToken")
if not access_token:
return None
if not _is_antigravity_token_expired(auth_data):
return access_token
# Token is expired or near expiry — attempt a refresh
refresh_token = account.get("refreshToken")
if not refresh_token:
logger.warning(
"Antigravity token expired and no refresh token available. "
"Re-open the Antigravity IDE to refresh, or run 'antigravity-auth accounts add'."
)
return access_token # return stale token; proxy may still accept it briefly
logger.info("Antigravity token expired or near expiry, refreshing...")
token_data = _refresh_antigravity_token(refresh_token)
if token_data and "access_token" in token_data:
_save_refreshed_antigravity_credentials(auth_data, token_data)
return token_data["access_token"]
logger.warning(
"Antigravity token refresh failed. Re-open the Antigravity IDE or run 'antigravity-auth accounts add'."
)
return access_token
@dataclass
class AgentInfo:
"""Information about an exported agent."""
name: str
description: str
goal_name: str
goal_description: str
node_count: int
edge_count: int
nodes: list[dict]
edges: list[dict]
entry_node: str
terminal_nodes: list[str]
success_criteria: list[dict]
constraints: list[dict]
required_tools: list[str]
has_tools_module: bool
@dataclass
class ValidationResult:
"""Result of agent validation."""
valid: bool
errors: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
missing_tools: list[str] = field(default_factory=list)
missing_credentials: list[str] = field(default_factory=list)
def _resolve_template_vars(text: str | None, variables: dict[str, str]) -> str | None:
"""Resolve ``{{variable_name}}`` placeholders in *text*."""
if text is None or not variables:
return text
import re
def _replace(m: re.Match) -> str:
key = m.group(1).strip()
return variables.get(key, m.group(0))
return re.sub(r"\{\{(.+?)\}\}", _replace, text)
def load_agent_config(data: str | dict) -> tuple[GraphSpec, Goal]:
"""Load ``GraphSpec`` and ``Goal`` from a declarative :class:`AgentConfig`.
The declarative format uses a ``name`` key at the top level, unlike the
legacy export format which uses ``graph``/``goal`` keys. The runner
auto-detects the format in :meth:`AgentLoader.load`.
Template variables in ``config.variables`` are resolved in all
``system_prompt`` and ``identity_prompt`` fields via ``{{var_name}}``.
Returns:
Tuple of (GraphSpec, Goal)
"""
from framework.orchestrator.edge import EdgeCondition, EdgeSpec
from framework.orchestrator.goal import Constraint, Goal as GoalModel, SuccessCriterion
from framework.schemas.agent_config import AgentConfig
if isinstance(data, str):
data = json.loads(data)
config = AgentConfig.model_validate(data)
tvars = config.variables
# Build Goal
success_criteria = [
SuccessCriterion(
id=f"sc-{i}",
description=sc,
metric="llm_judge",
target="",
)
for i, sc in enumerate(config.goal.success_criteria)
]
constraints = [
Constraint(
id=f"c-{i}",
description=c,
constraint_type="hard",
category="general",
)
for i, c in enumerate(config.goal.constraints)
]
goal = GoalModel(
id=f"{config.name}-goal",
name=config.name,
description=config.goal.description,
success_criteria=success_criteria,
constraints=constraints,
)
# Build nodes
condition_map = {
"always": EdgeCondition.ALWAYS,
"on_success": EdgeCondition.ON_SUCCESS,
"on_failure": EdgeCondition.ON_FAILURE,
"conditional": EdgeCondition.CONDITIONAL,
"llm_decide": EdgeCondition.LLM_DECIDE,
}
nodes = []
for nc in config.nodes:
# Resolve tool access: node-level config -> agent-level fallback
if nc.tools.policy == "explicit" and nc.tools.allowed:
tools_list = nc.tools.allowed
tool_policy = "explicit"
elif nc.tools.policy == "none":
tools_list = []
tool_policy = "none"
else:
# Inherit agent-level tool config
if config.tools.policy == "explicit" and config.tools.allowed:
tools_list = config.tools.allowed
else:
tools_list = []
tool_policy = config.tools.policy
node_kwargs: dict = {
"id": nc.id,
"name": nc.name or nc.id,
"description": nc.description or "",
"node_type": nc.node_type,
"system_prompt": _resolve_template_vars(nc.system_prompt, tvars),
"tools": tools_list,
"tool_access_policy": tool_policy,
"model": nc.model,
"input_keys": nc.input_keys,
"output_keys": nc.output_keys,
"nullable_output_keys": nc.nullable_output_keys,
"max_iterations": nc.max_iterations,
"success_criteria": nc.success_criteria,
"skip_judge": nc.skip_judge,
}
# Optional fields -- only pass when set (avoids overriding defaults)
if nc.client_facing:
node_kwargs["client_facing"] = nc.client_facing
if nc.max_node_visits != 1:
node_kwargs["max_node_visits"] = nc.max_node_visits
if nc.failure_criteria:
node_kwargs["failure_criteria"] = nc.failure_criteria
if nc.max_retries is not None:
node_kwargs["max_retries"] = nc.max_retries
nodes.append(NodeSpec(**node_kwargs))
# Build edges
edges = []
for i, ec in enumerate(config.edges):
edges.append(
EdgeSpec(
id=f"e-{i}-{ec.from_node}-{ec.to_node}",
source=ec.from_node,
target=ec.to_node,
condition=condition_map.get(ec.condition, EdgeCondition.ON_SUCCESS),
condition_expr=ec.condition_expr,
priority=ec.priority,
input_mapping=ec.input_mapping,
)
)
# Build entry_points dict for GraphSpec
entry_points_dict: dict = {}
if config.entry_points:
for ep in config.entry_points:
entry_points_dict[ep.id] = ep.entry_node or config.entry_node
else:
entry_points_dict = {"default": config.entry_node}
# Build GraphSpec
graph_kwargs: dict = {
"id": f"{config.name}-graph",
"goal_id": goal.id,
"version": config.version,
"entry_node": config.entry_node,
"entry_points": entry_points_dict,
"terminal_nodes": config.terminal_nodes,
"pause_nodes": config.pause_nodes,
"nodes": nodes,
"edges": edges,
"max_tokens": config.max_tokens,
"loop_config": dict(config.loop_config),
"conversation_mode": config.conversation_mode,
"identity_prompt": _resolve_template_vars(config.identity_prompt, tvars) or "",
}
graph = GraphSpec(**graph_kwargs)
return graph, goal
def load_agent_export(data: str | dict) -> tuple[GraphSpec, Goal]:
"""
Load GraphSpec and Goal from export_graph() output.
Args:
data: JSON string or dict from export_graph()
Returns:
Tuple of (GraphSpec, Goal)
"""
if isinstance(data, str):
data = json.loads(data)
# Extract graph and goal
graph_data = data.get("graph", {})
goal_data = data.get("goal", {})
# Build NodeSpec objects
nodes = []
for node_data in graph_data.get("nodes", []):
nodes.append(NodeSpec(**node_data))
# Build EdgeSpec objects
edges = []
for edge_data in graph_data.get("edges", []):
condition_str = edge_data.get("condition", "on_success")
condition_map = {
"always": EdgeCondition.ALWAYS,
"on_success": EdgeCondition.ON_SUCCESS,
"on_failure": EdgeCondition.ON_FAILURE,
"conditional": EdgeCondition.CONDITIONAL,
"llm_decide": EdgeCondition.LLM_DECIDE,
}
edge = EdgeSpec(
id=edge_data["id"],
source=edge_data["source"],
target=edge_data["target"],
condition=condition_map.get(condition_str, EdgeCondition.ON_SUCCESS),
condition_expr=edge_data.get("condition_expr"),
priority=edge_data.get("priority", 0),
input_mapping=edge_data.get("input_mapping", {}),
)
edges.append(edge)
# Build GraphSpec
graph = GraphSpec(
id=graph_data.get("id", "agent-graph"),
goal_id=graph_data.get("goal_id", ""),
version=graph_data.get("version", "1.0.0"),
entry_node=graph_data.get("entry_node", ""),
entry_points=graph_data.get("entry_points", {}), # Support pause/resume architecture
terminal_nodes=graph_data.get("terminal_nodes", []),
pause_nodes=graph_data.get("pause_nodes", []), # Support pause/resume architecture
nodes=nodes,
edges=edges,
max_steps=graph_data.get("max_steps", 100),
max_retries_per_node=graph_data.get("max_retries_per_node", 3),
description=graph_data.get("description", ""),
)
# Build Goal
from framework.orchestrator.goal import Constraint, SuccessCriterion
success_criteria = []
for sc_data in goal_data.get("success_criteria", []):
success_criteria.append(
SuccessCriterion(
id=sc_data["id"],
description=sc_data["description"],
metric=sc_data.get("metric", ""),
target=sc_data.get("target", ""),
weight=sc_data.get("weight", 1.0),
)
)
constraints = []
for c_data in goal_data.get("constraints", []):
constraints.append(
Constraint(
id=c_data["id"],
description=c_data["description"],
constraint_type=c_data.get("constraint_type", "hard"),
category=c_data.get("category", "safety"),
check=c_data.get("check", ""),
)
)
goal = Goal(
id=goal_data.get("id", ""),
name=goal_data.get("name", ""),
description=goal_data.get("description", ""),
success_criteria=success_criteria,
constraints=constraints,
)
return graph, goal
class AgentLoader:
"""
Loads and runs exported agents with minimal boilerplate.
Handles:
- Loading graph and goal from agent.json
- Auto-discovering tools from tools.py
- Setting up Runtime, LLM, and executor
- Executing with dynamic edge traversal
Usage:
# Simple usage
runner = AgentLoader.load("exports/outbound-sales-agent")
result = await runner.run({"lead_id": "123"})
# With context manager
async with AgentLoader.load("exports/outbound-sales-agent") as runner:
result = await runner.run({"lead_id": "123"})
# With custom tools
runner = AgentLoader.load("exports/outbound-sales-agent")
runner.register_tool("my_tool", my_tool_func)
result = await runner.run({"lead_id": "123"})
"""
@staticmethod
def _resolve_default_model() -> str:
"""Resolve the default model from ~/.hive/configuration.json."""
return get_preferred_model()
def __init__(
self,
agent_path: Path,
graph: GraphSpec,
goal: Goal,
mock_mode: bool = False,
storage_path: Path | None = None,
model: str | None = None,
intro_message: str = "",
runtime_config: "AgentRuntimeConfig | None" = None,
interactive: bool = True,
skip_credential_validation: bool = False,
requires_account_selection: bool = False,
configure_for_account: Callable | None = None,
list_accounts: Callable | None = None,
credential_store: Any | None = None,
):
"""
Initialize the runner (use AgentLoader.load() instead).
Args:
agent_path: Path to agent folder
graph: Loaded GraphSpec object
goal: Loaded Goal object
mock_mode: If True, use mock LLM responses
storage_path: Path for runtime storage (defaults to temp)
model: Model to use (reads from agent config or ~/.hive/configuration.json if None)
intro_message: Optional greeting shown to user on TUI load
runtime_config: Optional AgentRuntimeConfig (webhook settings, etc.)
interactive: If True (default), offer interactive credential setup on failure.
Set to False when called from the TUI (which handles setup via its own screen).
skip_credential_validation: If True, skip credential checks at load time.
requires_account_selection: If True, TUI shows account picker before starting.
configure_for_account: Callback(runner, account_dict) to scope tools after selection.
list_accounts: Callback() -> list[dict] to fetch available accounts.
credential_store: Optional shared CredentialStore (avoids creating redundant stores).
"""
self.agent_path = agent_path
self.graph = graph
self.goal = goal
self.mock_mode = mock_mode
self.model = model or self._resolve_default_model()
self.intro_message = intro_message
self.runtime_config = runtime_config
self._interactive = interactive
self.skip_credential_validation = skip_credential_validation
self.requires_account_selection = requires_account_selection
self._configure_for_account = configure_for_account
self._list_accounts = list_accounts
self._credential_store = credential_store
# Set up storage
if storage_path:
self._storage_path = storage_path
self._temp_dir = None
else:
home = Path.home()
default_storage = home / ".hive" / "agents" / agent_path.name
default_storage.mkdir(parents=True, exist_ok=True)
self._storage_path = default_storage
self._temp_dir = None
# Load HIVE_CREDENTIAL_KEY from shell config if not in env.
# Must happen before MCP subprocesses are spawned so they inherit it.
_ensure_credential_key_env()
# Initialize components
self._tool_registry = ToolRegistry()
self._llm: LLMProvider | None = None
self._approval_callback: Callable | None = None
# AgentRuntime — unified execution path for all agents
self._agent_runtime: AgentHost | None = None
# Pre-load validation: structural checks + credentials.
# Fails fast with actionable guidance — no MCP noise on screen.
run_preload_validation(
self.graph,
interactive=self._interactive,
skip_credential_validation=self.skip_credential_validation,
)
# Auto-discover tools from tools.py
tools_path = agent_path / "tools.py"
if tools_path.exists():
self._tool_registry.discover_from_module(tools_path)
# Per-agent env for MCP subprocesses. Stored on the registry so
# parallel workers in the same process don't clobber each other
# via the shared os.environ dict — the registry merges these
# into every MCPServerConfig.env at registration time.
self._tool_registry.set_mcp_extra_env(
{
"HIVE_AGENT_NAME": agent_path.name,
"HIVE_STORAGE_PATH": str(self._storage_path),
}
)
# MCP tools are loaded by McpRegistryStage in the pipeline during AgentHost.start()
@staticmethod
def _import_agent_module(agent_path: Path):
"""Import an agent package from its directory path.
Ensures the agent's parent directory is on sys.path so the package
can be imported normally (supports relative imports within the agent).
Always reloads the package and its submodules so that code changes
made since the last import (or since a previous session load in the
same server process) are picked up.
"""
import importlib
import sys
package_name = agent_path.name
parent_dir = str(agent_path.resolve().parent)
# Always place the correct parent directory first on sys.path.
# Multiple agent dirs can contain packages with the same name
# (e.g. exports/deep_research_agent and examples/deep_research_agent).
# Without this, a previously-added parent dir could shadow the
# agent we actually want to load.
if parent_dir in sys.path:
sys.path.remove(parent_dir)
sys.path.insert(0, parent_dir)
# Evict cached submodules first (e.g. deep_research_agent.nodes,
# deep_research_agent.agent) so the top-level reload picks up
# changes in the entire package — not just __init__.py.
stale = [name for name in sys.modules if name == package_name or name.startswith(f"{package_name}.")]
for name in stale:
del sys.modules[name]
return importlib.import_module(package_name)
@classmethod
def load(
cls,
agent_path: str | Path,
mock_mode: bool = False,
storage_path: Path | None = None,
model: str | None = None,
interactive: bool = True,
skip_credential_validation: bool | None = None,
credential_store: Any | None = None,
) -> "AgentLoader":
"""
Load a colony worker from its config directory.
Finds {worker_name}.json files in the directory and builds a
minimal GraphSpec from the first one found.
Args:
agent_path: Path to colony directory containing worker config JSONs
mock_mode: If True, use mock LLM responses
storage_path: Path for runtime storage
model: LLM model to use
interactive: If True (default), offer interactive credential setup.
skip_credential_validation: If True, skip credential checks.
credential_store: Optional shared CredentialStore.
Returns:
AgentLoader instance ready to run
"""
agent_path = Path(agent_path)
# Find {worker_name}.json worker config files in the colony directory
worker_jsons = sorted(
p
for p in agent_path.iterdir()
if p.is_file()
and p.suffix == ".json"
and p.stem not in ("agent", "flowchart", "triggers", "configuration", "metadata")
)
if not worker_jsons:
raise FileNotFoundError(f"No worker config found in {agent_path}")
from framework.orchestrator.edge import GraphSpec
from framework.orchestrator.goal import Constraint, Goal as GoalModel, SuccessCriterion
from framework.orchestrator.node import NodeSpec
# Load the first worker config
first_worker = json.loads(worker_jsons[0].read_text(encoding="utf-8"))
worker_name = first_worker.get("name", worker_jsons[0].stem)
system_prompt = first_worker.get("system_prompt", "")
tool_names = first_worker.get("tools", [])
goal_data = first_worker.get("goal", {})
loop_config = first_worker.get("loop_config", {})
success_criteria = [
SuccessCriterion(id=f"sc-{i}", description=sc, metric="llm_judge", target="")
for i, sc in enumerate(goal_data.get("success_criteria", []))
]
constraints = [
Constraint(id=f"c-{i}", description=c, constraint_type="hard", category="general")
for i, c in enumerate(goal_data.get("constraints", []))
]
goal = GoalModel(
id=f"{agent_path.name}-goal",
name=goal_data.get("description", worker_name),
description=goal_data.get("description", ""),
success_criteria=success_criteria,
constraints=constraints,
)
node = NodeSpec(
id=worker_name,
name=worker_name.replace("_", " ").title(),
description=first_worker.get("description", ""),
node_type="event_loop",
tools=tool_names,
system_prompt=system_prompt,
)
graph = GraphSpec(
id=f"{agent_path.name}-graph",
goal_id=goal.id,
entry_node=worker_name,
nodes=[node],
edges=[],
max_tokens=loop_config.get("max_tokens", 4096),
loop_config=loop_config,
identity_prompt=first_worker.get("identity_prompt", ""),
conversation_mode="continuous",
)
logger.info(
"Loaded colony worker config from %s (name=%s, tools=%d)",
worker_jsons[0].name,
worker_name,
len(tool_names),
)
if storage_path is None:
storage_path = _HIVE_HOME / "agents" / agent_path.name / worker_name
storage_path.mkdir(parents=True, exist_ok=True)
runner = cls(
agent_path=agent_path,
graph=graph,
goal=goal,
mock_mode=mock_mode,
storage_path=storage_path,
model=model,
interactive=interactive,
skip_credential_validation=skip_credential_validation or False,
credential_store=credential_store,
)
runner._agent_default_skills = None
# Colony workers attached to a SQLite task queue get the
# colony-progress-tracker skill pre-activated so its full
# claim / step / SOP-gate protocol lands in the system prompt
# on turn 0, bypassing the progressive-disclosure catalog
# lookup. Triggered by the presence of ``input_data.db_path``
# in worker.json (written by fork_session_into_colony and
# backfilled by ensure_progress_db for pre-existing colonies).
_preactivate: list[str] = []
_input_data = first_worker.get("input_data") or {}
if isinstance(_input_data, dict) and _input_data.get("db_path"):
_preactivate.append("hive.colony-progress-tracker")
runner._agent_skills = _preactivate or None
return runner
def register_tool(
self,
name: str,
tool_or_func: Tool | Callable,
executor: Callable | None = None,
) -> None:
"""
Register a tool for use by the agent.
Args:
name: Tool name
tool_or_func: Either a Tool object or a callable function
executor: Executor function (required if tool_or_func is a Tool)
"""
if isinstance(tool_or_func, Tool):
if executor is None:
raise ValueError("executor required when registering a Tool object")
self._tool_registry.register(name, tool_or_func, executor)
else:
# It's a function, auto-generate Tool
self._tool_registry.register_function(tool_or_func, name=name)
def register_mcp_server(
self,
name: str,
transport: str,
**config_kwargs,
) -> int:
"""
Register an MCP server and discover its tools.
Args:
name: Server name
transport: "stdio" or "http"
**config_kwargs: Additional configuration (command, args, url, etc.)
Returns:
Number of tools registered from this server
Example:
# Register STDIO MCP server
runner.register_mcp_server(
name="tools",
transport="stdio",
command="python",
args=["-m", "aden_tools.mcp_server", "--stdio"],
cwd="/path/to/tools"
)
# Register HTTP MCP server
runner.register_mcp_server(
name="tools",
transport="http",
url="http://localhost:4001"
)
"""
server_config = {
"name": name,
"transport": transport,
**config_kwargs,
}
return self._tool_registry.register_mcp_server(server_config)
def set_approval_callback(self, callback: Callable) -> None:
"""
Set a callback for human-in-the-loop approval during execution.
Args:
callback: Function to call for approval (receives node info, returns bool)
"""
self._approval_callback = callback
def _setup(self, event_bus=None) -> None:
"""Set up runtime via pipeline stages.
Builds a pipeline with the default stages (LLM, credentials, MCP,
skills) and passes it to AgentHost. The stages initialize during
``AgentHost.start()`` and inject tools/LLM/credentials/skills.
"""
from framework.observability import configure_logging
from framework.pipeline.stages.credential_resolver import CredentialResolverStage
from framework.pipeline.stages.llm_provider import LlmProviderStage
from framework.pipeline.stages.mcp_registry import McpRegistryStage
from framework.pipeline.stages.skill_registry import SkillRegistryStage
from framework.skills.config import SkillsConfig
from framework.skills.discovery import ExtraScope
configure_logging(level="INFO", format="auto")
# Set up session context for tools
agent_id = self.graph.id or "unknown"
self._tool_registry.set_session_context(agent_id=agent_id)
# Read MCP server refs from agent.json
mcp_refs = []
agent_json = self.agent_path / "agent.json"
if agent_json.exists():
try:
import json as _json
data = _json.loads(agent_json.read_text(encoding="utf-8"))
mcp_refs = data.get("mcp_servers", [])
except Exception:
pass
# Build default pipeline stages
# Default infrastructure stages (always present)
pipeline_stages = [
LlmProviderStage(
model=self.model,
mock_mode=self.mock_mode,
llm=self._llm,
),
CredentialResolverStage(
credential_store=self._credential_store,
),
McpRegistryStage(
server_refs=mcp_refs,
agent_path=self.agent_path,
tool_registry=self._tool_registry,
),
SkillRegistryStage(
project_root=self.agent_path,
interactive=self._interactive,
skills_config=SkillsConfig.from_agent_vars(
default_skills=getattr(self, "_agent_default_skills", None),
skills=getattr(self, "_agent_skills", None),
),
# Surface the colony's flat ``skills/`` directory as a
# ``colony_ui`` extra scope so SKILL.md files written there
# by ``create_colony`` (or the HTTP routes) are picked up
# with correct provenance. The legacy nested
# ``<colony>/.hive/skills/`` path is still picked up via
# project-scope auto-discovery (project_root above).
extra_scope_dirs=[
ExtraScope(
directory=self.agent_path / "skills",
label="colony_ui",
priority=3,
)
],
),
]
# Merge user-configured stages from ~/.hive/configuration.json
from framework.pipeline.registry import build_pipeline_from_config
hive_config = get_hive_config()
user_stages_config = hive_config.get("pipeline", {}).get("stages", [])
if user_stages_config:
user_pipeline = build_pipeline_from_config(user_stages_config)
pipeline_stages.extend(user_pipeline.stages)
# Merge agent-level overrides from agent.json pipeline field
if agent_json.exists():
try:
agent_pipeline = (
_json.loads(agent_json.read_text(encoding="utf-8")).get("pipeline", {}).get("stages", [])
)
if agent_pipeline:
agent_stages = build_pipeline_from_config(agent_pipeline)
pipeline_stages.extend(agent_stages.stages)
except Exception:
pass
# Create AgentHost directly (no wrapper)
from framework.host.execution_manager import EntryPointSpec
from framework.orchestrator.checkpoint_config import CheckpointConfig
from framework.tracker.runtime_log_store import RuntimeLogStore
self._agent_runtime = AgentHost(
graph=self.graph,
goal=self.goal,
storage_path=self._storage_path,
runtime_log_store=RuntimeLogStore(
base_path=self._storage_path / "runtime_logs",
),
checkpoint_config=CheckpointConfig(
enabled=True,
checkpoint_on_node_complete=True,
checkpoint_max_age_days=7,
async_checkpoint=True,
),
graph_id=self.graph.id or self.agent_path.name,
event_bus=event_bus,
pipeline_stages=pipeline_stages,
)
self._agent_runtime.register_entry_point(
EntryPointSpec(
id="default",
name="Default",
entry_node=self.graph.entry_node,
trigger_type="manual",
isolation_level="shared",
),
)
self._agent_runtime.intro_message = self.intro_message
def _get_api_key_env_var(self, model: str) -> str | None:
"""Get the environment variable name for the API key based on model name."""
model_lower = model.lower()
# Map model prefixes to API key environment variables
# LiteLLM uses these conventions
if model_lower.startswith("cerebras/"):
return "CEREBRAS_API_KEY"
elif model_lower.startswith("openai/") or model_lower.startswith("gpt-"):
return "OPENAI_API_KEY"
elif model_lower.startswith("anthropic/") or model_lower.startswith("claude"):
return "ANTHROPIC_API_KEY"
elif model_lower.startswith("gemini/") or model_lower.startswith("google/"):
return "GEMINI_API_KEY"
elif model_lower.startswith("mistral/"):
return "MISTRAL_API_KEY"
elif model_lower.startswith("groq/"):
return "GROQ_API_KEY"
elif model_lower.startswith("openrouter/"):
return "OPENROUTER_API_KEY"
elif self._is_local_model(model_lower):
return None # Local models don't need an API key
elif model_lower.startswith("azure/"):
return "AZURE_API_KEY"
elif model_lower.startswith("cohere/"):
return "COHERE_API_KEY"
elif model_lower.startswith("replicate/"):
return "REPLICATE_API_KEY"
elif model_lower.startswith("together/"):
return "TOGETHER_API_KEY"
elif model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
return "MINIMAX_API_KEY"
elif model_lower.startswith("kimi/"):
return "KIMI_API_KEY"
elif model_lower.startswith("hive/"):
return "HIVE_API_KEY"
else:
# Default: assume OpenAI-compatible
return "OPENAI_API_KEY"
def _get_api_key_from_credential_store(self) -> str | None:
"""Get the LLM API key from the encrypted credential store.
Maps model name to credential store ID (e.g. "anthropic/..." -> "anthropic")
and retrieves the key via CredentialStore.get().
"""
if not os.environ.get("HIVE_CREDENTIAL_KEY"):
return None
# Map model prefix to credential store ID
model_lower = self.model.lower()
cred_id = None
if model_lower.startswith("anthropic/") or model_lower.startswith("claude"):
cred_id = "anthropic"
elif model_lower.startswith("openai/") or model_lower.startswith("gpt"):
cred_id = "openai"
elif model_lower.startswith("gemini/") or model_lower.startswith("gemini"):
cred_id = "gemini"
elif model_lower.startswith("minimax/") or model_lower.startswith("minimax-"):
cred_id = "minimax"
elif model_lower.startswith("groq/"):
cred_id = "groq"
elif model_lower.startswith("cerebras/"):
cred_id = "cerebras"
elif model_lower.startswith("openrouter/"):
cred_id = "openrouter"
elif model_lower.startswith("mistral/"):
cred_id = "mistral"
elif model_lower.startswith("together_ai/") or model_lower.startswith("together/"):
cred_id = "together"
elif model_lower.startswith("deepseek/"):
cred_id = "deepseek"
elif model_lower.startswith("kimi/"):
cred_id = "kimi"
elif model_lower.startswith("hive/"):
cred_id = "hive"
if cred_id is None:
return None
try:
store = self._credential_store
if store is None:
from framework.credentials import CredentialStore
store = CredentialStore.with_encrypted_storage()
return store.get(cred_id)
except Exception:
return None
@staticmethod
def _is_local_model(model: str) -> bool:
"""Check if a model is a local model that doesn't require an API key.
Local providers like Ollama run on the user's machine and do not
need any authentication credentials.
"""
LOCAL_PREFIXES = (
"ollama/",
"ollama_chat/",
"vllm/",
"lm_studio/",
"llamacpp/",
)
return model.lower().startswith(LOCAL_PREFIXES)
# ------------------------------------------------------------------
# Execution modes
#
# run() One-shot, blocking execution for worker agents
# (headless CLI via ``hive run``). Validates, runs
# the graph to completion, and returns the result.
#
# start() / trigger() Long-lived runtime for the frontend (queen).
# start() boots the runtime; trigger() sends
# non-blocking execution requests. Used by the
# server session manager and API routes.
# ------------------------------------------------------------------
async def run(
self,
input_data: dict | None = None,
session_state: dict | None = None,
entry_point_id: str | None = None,
) -> ExecutionResult:
"""One-shot execution for worker agents (headless CLI).
Validates credentials, runs the graph to completion, and returns
the result. Used by ``hive run`` and programmatic callers.
For the frontend (queen), use start() + trigger() instead.
Args:
input_data: Input data for the agent (e.g., {"lead_id": "123"})
session_state: Optional session state to resume from
entry_point_id: For multi-entry-point agents, which entry point to trigger
(defaults to first entry point or "default")
Returns:
ExecutionResult with output, path, and metrics
"""
# Validate credentials before execution (fail-fast)
validation = self.validate()
if validation.missing_credentials:
error_lines = ["Cannot run agent: missing required credentials\n"]
for warning in validation.warnings:
if "Missing " in warning:
error_lines.append(f" {warning}")
error_lines.append("\nSet the required environment variables and re-run the agent.")
error_msg = "\n".join(error_lines)
return ExecutionResult(
success=False,
error=error_msg,
)
return await self._run_with_agent_runtime(
input_data=input_data or {},
entry_point_id=entry_point_id,
session_state=session_state,
)
async def _run_with_agent_runtime(
self,
input_data: dict,
entry_point_id: str | None = None,
session_state: dict | None = None,
) -> ExecutionResult:
"""Run using AgentRuntime."""
import sys
if self._agent_runtime is None:
self._setup()
# Start runtime if not running
if not self._agent_runtime.is_running:
await self._agent_runtime.start()
# Set up stdin-based I/O for the queen in headless mode.
# When the queen calls ask_user(), it emits
# CLIENT_INPUT_REQUESTED on the event bus and blocks. We subscribe
# a handler that prints the prompt and reads from stdin, then injects
# the user's response back into the node to unblock it.
has_queen = any(n.is_queen_node() for n in self.graph.nodes)
sub_ids: list[str] = []
if has_queen and sys.stdin.isatty():
from framework.host.event_bus import EventType
runtime = self._agent_runtime
async def _handle_client_output(event):
"""Print agent output to stdout as it streams."""
content = event.data.get("content", "")
if content:
print(content, end="", flush=True)
async def _handle_input_requested(event):
"""Read user input from stdin and inject it into the node."""
import asyncio
node_id = event.node_id
try:
loop = asyncio.get_event_loop()
user_input = await loop.run_in_executor(None, input, "\n>>> ")
except EOFError:
user_input = ""
# Inject into the waiting EventLoopNode via runtime
await runtime.inject_input(node_id, user_input)
sub_ids.append(
runtime.subscribe_to_events(
event_types=[EventType.CLIENT_OUTPUT_DELTA],
handler=_handle_client_output,
)
)
sub_ids.append(
runtime.subscribe_to_events(
event_types=[EventType.CLIENT_INPUT_REQUESTED],
handler=_handle_input_requested,
)
)
# Determine entry point
if entry_point_id is None:
# Use first entry point or "default" if no entry points defined
entry_points = self._agent_runtime.get_entry_points()
if entry_points:
entry_point_id = entry_points[0].id
else:
entry_point_id = "default"
try:
# Trigger and wait for result
result = await self._agent_runtime.trigger_and_wait(
entry_point_id=entry_point_id,
input_data=input_data,
session_state=session_state,
)
# Return result or create error result
if result is not None:
return result
else:
return ExecutionResult(
success=False,
error="Execution timed out or failed to complete",
)
finally:
# Clean up subscriptions
for sub_id in sub_ids:
self._agent_runtime.unsubscribe_from_events(sub_id)
# === Runtime API ===
async def start(self) -> None:
"""Boot the agent runtime for the frontend (queen).
Pair with trigger() to send execution requests. Used by the
server session manager. For headless worker agents, use run()
instead.
"""
if self._agent_runtime is None:
self._setup()
await self._agent_runtime.start()
async def stop(self) -> None:
"""Stop the agent runtime."""
if self._agent_runtime is not None:
await self._agent_runtime.stop()
async def trigger(
self,
entry_point_id: str,
input_data: dict[str, Any],
correlation_id: str | None = None,
) -> str:
"""Send a non-blocking execution request to a running runtime.
Used by the server API routes after start(). For headless
worker agents, use run() instead.
Args:
entry_point_id: Which entry point to trigger
input_data: Input data for the execution
correlation_id: Optional ID to correlate related executions
Returns:
Execution ID for tracking
"""
if self._agent_runtime is None:
self._setup()
if not self._agent_runtime.is_running:
await self._agent_runtime.start()
return await self._agent_runtime.trigger(
entry_point_id=entry_point_id,
input_data=input_data,
correlation_id=correlation_id,
)
def get_entry_points(self) -> list[EntryPointSpec]:
"""
Get all registered entry points.
Returns:
List of EntryPointSpec objects
"""
if self._agent_runtime is None:
self._setup()
return self._agent_runtime.get_entry_points()
@property
def is_running(self) -> bool:
"""Check if the agent runtime is running (for multi-entry-point agents)."""
if self._agent_runtime is None:
return False
return self._agent_runtime.is_running
def info(self) -> AgentInfo:
"""Return agent metadata (nodes, edges, goal, required tools)."""
# Extract required tools from nodes
required_tools = set()
nodes_info = []
for node in self.graph.nodes:
node_info = {
"id": node.id,
"name": node.name,
"description": node.description,
"type": node.node_type,
"input_keys": node.input_keys,
"output_keys": node.output_keys,
}
if node.tools:
required_tools.update(node.tools)
node_info["tools"] = node.tools
nodes_info.append(node_info)
edges_info = [
{
"id": edge.id,
"source": edge.source,
"target": edge.target,
"condition": edge.condition.value,
}
for edge in self.graph.edges
]
return AgentInfo(
name=self.graph.id,
description=self.graph.description,
goal_name=self.goal.name,
goal_description=self.goal.description,
node_count=len(self.graph.nodes),
edge_count=len(self.graph.edges),
nodes=nodes_info,
edges=edges_info,
entry_node=self.graph.entry_node,
terminal_nodes=self.graph.terminal_nodes,
success_criteria=[
{
"id": sc.id,
"description": sc.description,
"metric": sc.metric,
"target": sc.target,
}
for sc in self.goal.success_criteria
],
constraints=[
{"id": c.id, "description": c.description, "type": c.constraint_type} for c in self.goal.constraints
],
required_tools=sorted(required_tools),
has_tools_module=(self.agent_path / "tools.py").exists(),
)
def validate(self) -> ValidationResult:
"""
Check agent is valid and all required tools are registered.
Returns:
ValidationResult with errors, warnings, and missing tools
"""
errors = []
warnings = []
missing_tools = []
# Validate graph structure
graph_result = self.graph.validate()
errors.extend(graph_result["errors"])
warnings.extend(graph_result["warnings"])
# Check goal has success criteria
if not self.goal.success_criteria:
warnings.append("Goal has no success criteria defined")
# Check required tools are registered
info = self.info()
for tool_name in info.required_tools:
if not self._tool_registry.has_tool(tool_name):
missing_tools.append(tool_name)
if missing_tools:
warnings.append(f"Missing tool implementations: {', '.join(missing_tools)}")
# Check credentials for required tools and node types
# Uses CredentialStoreAdapter.default() which includes Aden sync support
missing_credentials = []
try:
from aden_tools.credentials.store_adapter import CredentialStoreAdapter
adapter = CredentialStoreAdapter.default()
# Check tool credentials
for _cred_name, spec in adapter.get_missing_for_tools(list(info.required_tools)):
missing_credentials.append(spec.env_var)
affected_tools = [t for t in info.required_tools if t in spec.tools]
tools_str = ", ".join(affected_tools)
warning_msg = f"Missing {spec.env_var} for {tools_str}"
if spec.help_url:
warning_msg += f"\n Get it at: {spec.help_url}"
warnings.append(warning_msg)
# Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
node_types = list({node.node_type for node in self.graph.nodes})
for _cred_name, spec in adapter.get_missing_for_node_types(node_types):
missing_credentials.append(spec.env_var)
affected_types = [t for t in node_types if t in spec.node_types]
types_str = ", ".join(affected_types)
warning_msg = f"Missing {spec.env_var} for {types_str} nodes"
if spec.help_url:
warning_msg += f"\n Get it at: {spec.help_url}"
warnings.append(warning_msg)
except ImportError:
# aden_tools not installed - fall back to direct check
has_llm_nodes = any(node.node_type == "event_loop" for node in self.graph.nodes)
if has_llm_nodes:
api_key_env = self._get_api_key_env_var(self.model)
if api_key_env and not os.environ.get(api_key_env):
if api_key_env not in missing_credentials:
missing_credentials.append(api_key_env)
warnings.append(f"Agent has LLM nodes but {api_key_env} not set (model: {self.model})")
return ValidationResult(
valid=len(errors) == 0,
errors=errors,
warnings=warnings,
missing_tools=missing_tools,
missing_credentials=missing_credentials,
)
def cleanup(self) -> None:
"""Clean up resources (synchronous)."""
if hasattr(self, "_tool_registry"):
self._tool_registry.cleanup()
if self._temp_dir:
self._temp_dir.cleanup()
self._temp_dir = None
async def cleanup_async(self) -> None:
"""Clean up resources (asynchronous)."""
# Stop agent runtime if running
if self._agent_runtime is not None and self._agent_runtime.is_running:
await self._agent_runtime.stop()
# Run synchronous cleanup
self.cleanup()
async def __aenter__(self) -> "AgentLoader":
"""Context manager entry."""
self._setup()
if self._agent_runtime is not None:
await self._agent_runtime.start()
return self
async def __aexit__(self, *args) -> None:
"""Context manager exit."""
await self.cleanup_async()
def __del__(self) -> None:
"""Destructor - cleanup temp dir."""
self.cleanup()