feat: credential v2 with provider loading and test agent
This commit is contained in:
@@ -5,13 +5,44 @@ Interactive agent that lists connected accounts, lets the user pick one,
|
||||
loads the provider's tools, and runs a chat session to test the credential.
|
||||
"""
|
||||
|
||||
from .agent import CredentialTesterAgent, edges, goal, nodes
|
||||
from .agent import (
|
||||
CredentialTesterAgent,
|
||||
configure_for_account,
|
||||
conversation_mode,
|
||||
edges,
|
||||
entry_node,
|
||||
entry_points,
|
||||
goal,
|
||||
identity_prompt,
|
||||
list_connected_accounts,
|
||||
loop_config,
|
||||
nodes,
|
||||
pause_nodes,
|
||||
requires_account_selection,
|
||||
skip_credential_validation,
|
||||
skip_guardian,
|
||||
terminal_nodes,
|
||||
)
|
||||
from .config import default_config
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
__all__ = [
|
||||
"CredentialTesterAgent",
|
||||
"goal",
|
||||
"nodes",
|
||||
"configure_for_account",
|
||||
"conversation_mode",
|
||||
"default_config",
|
||||
"edges",
|
||||
"entry_node",
|
||||
"entry_points",
|
||||
"goal",
|
||||
"identity_prompt",
|
||||
"list_connected_accounts",
|
||||
"loop_config",
|
||||
"nodes",
|
||||
"pause_nodes",
|
||||
"requires_account_selection",
|
||||
"skip_credential_validation",
|
||||
"skip_guardian",
|
||||
"terminal_nodes",
|
||||
]
|
||||
|
||||
@@ -4,9 +4,10 @@ A framework agent that lets the user pick a connected account and test it
|
||||
by making real API calls via the provider's tools.
|
||||
|
||||
When loaded via AgentRunner.load() (TUI picker, ``hive run``), the module-level
|
||||
``nodes`` / ``edges`` variables provide a static graph with a single
|
||||
client-facing node. The system prompt lists connected accounts so the LLM
|
||||
can guide the user through account selection and testing.
|
||||
``nodes`` / ``edges`` variables provide a static graph. The TUI detects
|
||||
``requires_account_selection`` and shows an account picker *before* starting
|
||||
the agent. ``configure_for_account()`` then scopes the node's tools to the
|
||||
selected provider.
|
||||
|
||||
When used directly (``CredentialTesterAgent``), the graph is built dynamically
|
||||
after the user picks an account programmatically.
|
||||
@@ -15,6 +16,7 @@ after the user picks an account programmatically.
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from framework.graph import Goal, NodeSpec, SuccessCriterion
|
||||
from framework.graph.checkpoint_config import CheckpointConfig
|
||||
@@ -28,6 +30,9 @@ from framework.runtime.execution_stream import EntryPointSpec
|
||||
from .config import default_config
|
||||
from .nodes import build_tester_node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from framework.runner import AgentRunner
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Goal
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -48,80 +53,23 @@ goal = Goal(
|
||||
constraints=[],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level graph variables (read by AgentRunner.load)
|
||||
# ---------------------------------------------------------------------------
|
||||
# All tools are provided by the hive-tools MCP server (mcp_servers.json).
|
||||
# The system prompt lists connected accounts so the LLM can help the user
|
||||
# pick one and test it interactively.
|
||||
|
||||
nodes = [
|
||||
NodeSpec(
|
||||
id="tester",
|
||||
name="Credential Tester",
|
||||
description=(
|
||||
"Interactive credential testing — lets the user pick an account "
|
||||
"and verify it via API calls."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
max_node_visits=0,
|
||||
input_keys=[],
|
||||
output_keys=[],
|
||||
tools=[], # All tools come from MCP; no explicit filter
|
||||
system_prompt="""\
|
||||
You are a credential tester. Your job is to help the user verify that their \
|
||||
connected accounts can make real API calls.
|
||||
|
||||
# Startup
|
||||
|
||||
1. Call ``get_account_info`` to list the user's connected accounts.
|
||||
2. Present the list and ask the user which account to test.
|
||||
3. Once they pick one, suggest a simple read-only API call to verify \
|
||||
the credential works (e.g. list messages, list channels, list contacts).
|
||||
4. Execute the call when the user agrees.
|
||||
5. Report the result: success (with sample data) or failure (with error).
|
||||
6. Let the user request additional API calls to further test the credential.
|
||||
|
||||
# Rules
|
||||
|
||||
- Start with read-only operations (list, get) before write operations.
|
||||
- Always confirm with the user before performing write operations.
|
||||
- If a call fails, report the exact error — this helps diagnose credential issues.
|
||||
- Be concise. No emojis.
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
||||
edges = []
|
||||
|
||||
entry_node = "tester"
|
||||
entry_points = {"start": "tester"}
|
||||
pause_nodes = []
|
||||
terminal_nodes = [] # Forever-alive: loops until user exits
|
||||
|
||||
conversation_mode = "continuous"
|
||||
identity_prompt = (
|
||||
"You are a credential tester that verifies connected accounts can make real API calls."
|
||||
)
|
||||
loop_config = {
|
||||
"max_iterations": 50,
|
||||
"max_tool_calls_per_turn": 10,
|
||||
"max_history_tokens": 32000,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_tools_for_provider(provider_name: str) -> list[str]:
|
||||
"""Collect tool names for an Aden provider from CREDENTIAL_SPECS."""
|
||||
"""Collect tool names for a specific Aden credential by credential_id.
|
||||
|
||||
Matches on ``credential_id`` (e.g. "google" → Gmail tools only),
|
||||
NOT ``aden_provider_name`` which can be shared across products
|
||||
(e.g. both google and google_docs have aden_provider_name="google").
|
||||
"""
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
|
||||
tools: list[str] = []
|
||||
for spec in CREDENTIAL_SPECS.values():
|
||||
if spec.aden_provider_name == provider_name:
|
||||
if spec.credential_id == provider_name:
|
||||
tools.extend(spec.tools)
|
||||
return sorted(set(tools))
|
||||
|
||||
@@ -158,6 +106,135 @@ def list_connected_accounts() -> list[dict]:
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level hooks (read by AgentRunner.load / TUI)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
skip_credential_validation = True
|
||||
"""Don't validate credentials at load time — we don't know which provider yet."""
|
||||
|
||||
skip_guardian = True
|
||||
"""Don't attach the Hive Coder guardian — this is a standalone utility agent."""
|
||||
|
||||
requires_account_selection = True
|
||||
"""Signal TUI to show account picker before starting the agent."""
|
||||
|
||||
|
||||
def configure_for_account(runner: AgentRunner, account: dict) -> None:
|
||||
"""Scope the tester node's tools to the selected provider.
|
||||
|
||||
Called by the TUI after the user picks an account from the picker.
|
||||
After scoping, re-enables credential validation so the selected
|
||||
provider's credentials are checked before the agent starts.
|
||||
"""
|
||||
provider = account["provider"]
|
||||
tools = get_tools_for_provider(provider)
|
||||
tools.append("get_account_info")
|
||||
|
||||
alias = account.get("alias", "unknown")
|
||||
email = account.get("identity", {}).get("email", "")
|
||||
detail = f" (email: {email})" if email else ""
|
||||
|
||||
for node in runner.graph.nodes:
|
||||
if node.id == "tester":
|
||||
node.tools = sorted(set(tools))
|
||||
# Update system prompt to be provider-specific
|
||||
node.system_prompt = f"""\
|
||||
You are a credential tester for the account: {provider}/{alias}{detail}
|
||||
|
||||
# Instructions
|
||||
|
||||
1. Suggest a simple read-only API call to verify the credential works \
|
||||
(e.g. list messages, list channels, list contacts).
|
||||
2. Execute the call when the user agrees.
|
||||
3. Report the result: success (with sample data) or failure (with error).
|
||||
4. Let the user request additional API calls to further test the credential.
|
||||
|
||||
# Account routing
|
||||
|
||||
IMPORTANT: Always pass `account="{alias}"` when calling any tool. \
|
||||
This routes the API call to the correct credential. Never use the email \
|
||||
or any other identifier — always use the alias exactly as shown.
|
||||
|
||||
# Rules
|
||||
|
||||
- Start with read-only operations (list, get) before write operations.
|
||||
- Always confirm with the user before performing write operations.
|
||||
- If a call fails, report the exact error — this helps diagnose credential issues.
|
||||
- Be concise. No emojis.
|
||||
"""
|
||||
break
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level graph variables (read by AgentRunner.load)
|
||||
# ---------------------------------------------------------------------------
|
||||
# The static node starts with minimal tools. configure_for_account() scopes
|
||||
# it to the selected provider's tools before execution.
|
||||
|
||||
nodes = [
|
||||
NodeSpec(
|
||||
id="tester",
|
||||
name="Credential Tester",
|
||||
description=(
|
||||
"Interactive credential testing — lets the user pick an account "
|
||||
"and verify it via API calls."
|
||||
),
|
||||
node_type="event_loop",
|
||||
client_facing=True,
|
||||
max_node_visits=0,
|
||||
input_keys=[],
|
||||
output_keys=[],
|
||||
tools=["get_account_info"],
|
||||
system_prompt="""\
|
||||
You are a credential tester. Your job is to help the user verify that their \
|
||||
connected accounts can make real API calls.
|
||||
|
||||
# Startup
|
||||
|
||||
1. Call ``get_account_info`` to list the user's connected accounts.
|
||||
2. Present the list and ask the user which account to test.
|
||||
3. Once they pick one, note the account's **alias** (e.g. "Timothy", "work-slack").
|
||||
4. Suggest a simple read-only API call to verify the credential works \
|
||||
(e.g. list messages, list channels, list contacts).
|
||||
5. Execute the call when the user agrees.
|
||||
6. Report the result: success (with sample data) or failure (with error).
|
||||
7. Let the user request additional API calls to further test the credential.
|
||||
|
||||
# Account routing
|
||||
|
||||
IMPORTANT: Always pass the account's **alias** as the ``account`` parameter \
|
||||
when calling any tool. The alias is the routing key — never use the email or \
|
||||
any other identifier. For example, if the alias is "Timothy", call \
|
||||
``gmail_list_messages(account="Timothy", ...)``.
|
||||
|
||||
# Rules
|
||||
|
||||
- Start with read-only operations (list, get) before write operations.
|
||||
- Always confirm with the user before performing write operations.
|
||||
- If a call fails, report the exact error — this helps diagnose credential issues.
|
||||
- Be concise. No emojis.
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
||||
edges = []
|
||||
|
||||
entry_node = "tester"
|
||||
entry_points = {"start": "tester"}
|
||||
pause_nodes = []
|
||||
terminal_nodes = [] # Forever-alive: loops until user exits
|
||||
|
||||
conversation_mode = "continuous"
|
||||
identity_prompt = (
|
||||
"You are a credential tester that verifies connected accounts can make real API calls."
|
||||
)
|
||||
loop_config = {
|
||||
"max_iterations": 50,
|
||||
"max_tool_calls_per_turn": 10,
|
||||
"max_history_tokens": 32000,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Programmatic agent class (used by __main__.py CLI)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -257,10 +334,12 @@ class CredentialTesterAgent:
|
||||
if mcp_config_path.exists():
|
||||
self._tool_registry.load_mcp_config(mcp_config_path)
|
||||
|
||||
extra_kwargs = getattr(self.config, "extra_kwargs", {}) or {}
|
||||
llm = LiteLLMProvider(
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key,
|
||||
api_base=self.config.api_base,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
tool_executor = self._tool_registry.get_executor()
|
||||
|
||||
@@ -1,32 +1,19 @@
|
||||
"""Runtime configuration for Credential Tester agent."""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def _load_preferred_model() -> str:
|
||||
"""Load preferred model from ~/.hive/configuration.json."""
|
||||
config_path = Path.home() / ".hive" / "configuration.json"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
llm = config.get("llm", {})
|
||||
if llm.get("provider") and llm.get("model"):
|
||||
return f"{llm['provider']}/{llm['model']}"
|
||||
except Exception:
|
||||
pass
|
||||
return "anthropic/claude-sonnet-4-20250514"
|
||||
from framework.config import RuntimeConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeConfig:
|
||||
model: str = field(default_factory=_load_preferred_model)
|
||||
temperature: float = 0.3
|
||||
max_tokens: int = 16000
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
class AgentMetadata:
|
||||
name: str = "Credential Tester"
|
||||
version: str = "1.0.0"
|
||||
description: str = (
|
||||
"Test connected accounts by making real API calls. "
|
||||
"Pick an account, verify credentials work, and explore available tools."
|
||||
)
|
||||
|
||||
|
||||
default_config = RuntimeConfig()
|
||||
metadata = AgentMetadata()
|
||||
default_config = RuntimeConfig(temperature=0.3)
|
||||
|
||||
@@ -39,6 +39,12 @@ You are a credential tester for the account: {provider}/{alias}{detail}
|
||||
Your job is to help the user verify that this credential works by making \
|
||||
real API calls using the available tools.
|
||||
|
||||
# Account routing
|
||||
|
||||
IMPORTANT: Always pass `account="{alias}"` when calling any tool. \
|
||||
This routes the API call to the correct credential. Never use the email \
|
||||
or any other identifier — always use the alias exactly as shown.
|
||||
|
||||
# Instructions
|
||||
|
||||
1. Start by greeting the user and confirming which account you're testing.
|
||||
|
||||
@@ -298,6 +298,27 @@ class EventLoopNode(NodeProtocol):
|
||||
if ctx.accounts_prompt:
|
||||
system_prompt = f"{system_prompt}\n\n{ctx.accounts_prompt}"
|
||||
|
||||
# Inject agent working memory (adapt.md).
|
||||
# If it doesn't exist yet, seed it with available context.
|
||||
if self._config.spillover_dir:
|
||||
_adapt_path = Path(self._config.spillover_dir) / "adapt.md"
|
||||
if not _adapt_path.exists() and ctx.accounts_prompt:
|
||||
_adapt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_adapt_path.write_text(
|
||||
f"## Identity\n{ctx.accounts_prompt}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
if _adapt_path.exists():
|
||||
_adapt_text = _adapt_path.read_text(encoding="utf-8").strip()
|
||||
if _adapt_text:
|
||||
system_prompt = (
|
||||
f"{system_prompt}\n\n"
|
||||
f"--- Your Memory ---\n{_adapt_text}\n--- End Memory ---\n\n"
|
||||
'Maintain your memory by calling save_data("adapt.md", ...) '
|
||||
'or edit_data("adapt.md", ...) as you work. '
|
||||
"Record identity, session history, decisions, and working notes."
|
||||
)
|
||||
|
||||
conversation = NodeConversation(
|
||||
system_prompt=system_prompt,
|
||||
max_history_tokens=self._config.max_history_tokens,
|
||||
@@ -1298,7 +1319,7 @@ class EventLoopNode(NodeProtocol):
|
||||
node_id=node_id,
|
||||
reason=tc.tool_input.get("reason", ""),
|
||||
context=tc.tool_input.get("context", ""),
|
||||
execution_id=ctx.execution_id,
|
||||
execution_id=stream_id,
|
||||
)
|
||||
# Block like ask_user — the TUI loads the coder,
|
||||
# and /back injects a message to unblock us.
|
||||
@@ -1672,6 +1693,12 @@ class EventLoopNode(NodeProtocol):
|
||||
),
|
||||
)
|
||||
|
||||
# Client-facing nodes with no output keys are meant for
|
||||
# continuous interaction — they should not auto-accept.
|
||||
# Only exit via shutdown, max_iterations, or max_node_visits.
|
||||
if not output_keys and ctx.node_spec.client_facing:
|
||||
return JudgeVerdict(action="RETRY", feedback="")
|
||||
|
||||
# Level 2: conversation-aware quality check (if success_criteria set)
|
||||
if ctx.node_spec.success_criteria and ctx.llm:
|
||||
from framework.graph.conversation_judge import evaluate_phase_completion
|
||||
|
||||
@@ -1174,6 +1174,23 @@ class GraphExecutor:
|
||||
# Build Layer 2 (narrative) from current state
|
||||
narrative = build_narrative(memory, path, graph)
|
||||
|
||||
# Read agent working memory (adapt.md) once for both
|
||||
# system prompt and transition marker.
|
||||
_adapt_text: str | None = None
|
||||
if self._storage_path:
|
||||
_adapt_path = self._storage_path / "data" / "adapt.md"
|
||||
if _adapt_path.exists():
|
||||
_raw = _adapt_path.read_text(encoding="utf-8").strip()
|
||||
_adapt_text = _raw or None
|
||||
|
||||
# Merge adapt.md into narrative for system prompt
|
||||
if _adapt_text:
|
||||
narrative = (
|
||||
f"{narrative}\n\n--- Agent Memory ---\n{_adapt_text}"
|
||||
if narrative
|
||||
else _adapt_text
|
||||
)
|
||||
|
||||
# Compose new system prompt (Layer 1 + 2 + 3 + accounts)
|
||||
new_system = compose_system_prompt(
|
||||
identity_prompt=getattr(graph, "identity_prompt", None),
|
||||
@@ -1203,6 +1220,7 @@ class GraphExecutor:
|
||||
memory=memory,
|
||||
cumulative_tool_names=sorted(cumulative_tool_names),
|
||||
data_dir=data_dir,
|
||||
adapt_content=_adapt_text,
|
||||
)
|
||||
await continuous_conversation.add_user_message(
|
||||
marker,
|
||||
|
||||
@@ -151,6 +151,7 @@ def build_transition_marker(
|
||||
memory: SharedMemory,
|
||||
cumulative_tool_names: list[str],
|
||||
data_dir: Path | str | None = None,
|
||||
adapt_content: str | None = None,
|
||||
) -> str:
|
||||
"""Build a 'State of the World' transition marker.
|
||||
|
||||
@@ -164,6 +165,7 @@ def build_transition_marker(
|
||||
memory: Current shared memory state.
|
||||
cumulative_tool_names: All tools available (cumulative set).
|
||||
data_dir: Path to spillover data directory.
|
||||
adapt_content: Agent working memory (adapt.md) content.
|
||||
|
||||
Returns:
|
||||
Transition marker message text.
|
||||
@@ -205,6 +207,10 @@ def build_transition_marker(
|
||||
"\nData files (use load_data to access):\n" + "\n".join(file_lines)
|
||||
)
|
||||
|
||||
# Agent working memory
|
||||
if adapt_content:
|
||||
sections.append(f"\n--- Agent Memory ---\n{adapt_content}")
|
||||
|
||||
# Available tools
|
||||
if cumulative_tool_names:
|
||||
sections.append("\nAvailable tools: " + ", ".join(sorted(cumulative_tool_names)))
|
||||
|
||||
@@ -351,6 +351,11 @@ class AgentRunner:
|
||||
intro_message: str = "",
|
||||
runtime_config: "AgentRuntimeConfig | None" = None,
|
||||
interactive: bool = True,
|
||||
skip_credential_validation: bool = False,
|
||||
skip_guardian: bool = False,
|
||||
requires_account_selection: bool = False,
|
||||
configure_for_account: Callable | None = None,
|
||||
list_accounts: Callable | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the runner (use AgentRunner.load() instead).
|
||||
@@ -366,6 +371,11 @@ class AgentRunner:
|
||||
runtime_config: Optional AgentRuntimeConfig (webhook settings, etc.)
|
||||
interactive: If True (default), offer interactive credential setup on failure.
|
||||
Set to False when called from the TUI (which handles setup via its own screen).
|
||||
skip_credential_validation: If True, skip credential checks at load time.
|
||||
skip_guardian: If True, don't attach the Hive Coder guardian.
|
||||
requires_account_selection: If True, TUI shows account picker before starting.
|
||||
configure_for_account: Callback(runner, account_dict) to scope tools after selection.
|
||||
list_accounts: Callback() -> list[dict] to fetch available accounts.
|
||||
"""
|
||||
self.agent_path = agent_path
|
||||
self.graph = graph
|
||||
@@ -375,6 +385,11 @@ class AgentRunner:
|
||||
self.intro_message = intro_message
|
||||
self.runtime_config = runtime_config
|
||||
self._interactive = interactive
|
||||
self.skip_credential_validation = skip_credential_validation
|
||||
self.skip_guardian = skip_guardian
|
||||
self.requires_account_selection = requires_account_selection
|
||||
self._configure_for_account = configure_for_account
|
||||
self._list_accounts = list_accounts
|
||||
|
||||
# Set up storage
|
||||
if storage_path:
|
||||
@@ -425,6 +440,9 @@ class AgentRunner:
|
||||
When ``interactive`` is False (e.g. TUI callers), the CredentialError
|
||||
propagates immediately so the caller can handle it with its own UI.
|
||||
"""
|
||||
if self.skip_credential_validation:
|
||||
return
|
||||
|
||||
if not self._interactive:
|
||||
# Let the CredentialError propagate — caller handles UI.
|
||||
validate_agent_credentials(self.graph.nodes)
|
||||
@@ -586,6 +604,13 @@ class AgentRunner:
|
||||
# Read runtime config (webhook settings, etc.) if defined
|
||||
agent_runtime_config = getattr(agent_module, "runtime_config", None)
|
||||
|
||||
# Read pre-run hooks (e.g., credential_tester needs account selection)
|
||||
skip_cred = getattr(agent_module, "skip_credential_validation", False)
|
||||
no_guardian = getattr(agent_module, "skip_guardian", False)
|
||||
needs_acct = getattr(agent_module, "requires_account_selection", False)
|
||||
configure_fn = getattr(agent_module, "configure_for_account", None)
|
||||
list_accts_fn = getattr(agent_module, "list_connected_accounts", None)
|
||||
|
||||
return cls(
|
||||
agent_path=agent_path,
|
||||
graph=graph,
|
||||
@@ -596,6 +621,11 @@ class AgentRunner:
|
||||
intro_message=intro_message,
|
||||
runtime_config=agent_runtime_config,
|
||||
interactive=interactive,
|
||||
skip_credential_validation=skip_cred,
|
||||
skip_guardian=no_guardian,
|
||||
requires_account_selection=needs_acct,
|
||||
configure_for_account=configure_fn,
|
||||
list_accounts=list_accts_fn,
|
||||
)
|
||||
|
||||
# Fallback: load from agent.json (legacy JSON-based agents)
|
||||
|
||||
@@ -51,6 +51,7 @@ class ToolRegistry:
|
||||
self._tools: dict[str, RegisteredTool] = {}
|
||||
self._mcp_clients: list[Any] = [] # List of MCPClient instances
|
||||
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
|
||||
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -480,6 +481,56 @@ class ToolRegistry:
|
||||
|
||||
return tool
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider-based tool filtering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def build_provider_index(self) -> None:
|
||||
"""Build provider -> tool-name mapping from CREDENTIAL_SPECS.
|
||||
|
||||
Populates ``_provider_index`` so :meth:`get_by_provider` works.
|
||||
Safe to call even if ``aden_tools`` is not installed (silently no-ops).
|
||||
"""
|
||||
try:
|
||||
from aden_tools.credentials import CREDENTIAL_SPECS
|
||||
except ImportError:
|
||||
logger.debug("aden_tools not available, skipping provider index")
|
||||
return
|
||||
|
||||
self._provider_index.clear()
|
||||
for spec in CREDENTIAL_SPECS.values():
|
||||
provider = spec.aden_provider_name
|
||||
if provider:
|
||||
if provider not in self._provider_index:
|
||||
self._provider_index[provider] = set()
|
||||
self._provider_index[provider].update(spec.tools)
|
||||
|
||||
def get_by_provider(self, provider: str) -> dict[str, Tool]:
|
||||
"""Return registered tools that belong to *provider*.
|
||||
|
||||
Lazily builds the provider index on first call.
|
||||
"""
|
||||
if not self._provider_index:
|
||||
self.build_provider_index()
|
||||
tool_names = self._provider_index.get(provider, set())
|
||||
return {name: rt.tool for name, rt in self._tools.items() if name in tool_names}
|
||||
|
||||
def get_tool_names_by_provider(self, provider: str) -> list[str]:
|
||||
"""Return sorted registered tool names for *provider*."""
|
||||
if not self._provider_index:
|
||||
self.build_provider_index()
|
||||
tool_names = self._provider_index.get(provider, set())
|
||||
return sorted(name for name in self._tools if name in tool_names)
|
||||
|
||||
def get_all_provider_tool_names(self) -> list[str]:
|
||||
"""Return sorted names of all registered tools that belong to any provider."""
|
||||
if not self._provider_index:
|
||||
self.build_provider_index()
|
||||
all_names: set[str] = set()
|
||||
for names in self._provider_index.values():
|
||||
all_names.update(names)
|
||||
return sorted(name for name in self._tools if name in all_names)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all MCP client connections."""
|
||||
for client in self._mcp_clients:
|
||||
|
||||
+83
-14
@@ -366,19 +366,6 @@ class AdenTUI(App):
|
||||
interactive=False,
|
||||
)
|
||||
runner = await loop.run_in_executor(None, load_fn)
|
||||
if runner._agent_runtime is None:
|
||||
await loop.run_in_executor(None, runner._setup)
|
||||
|
||||
if not self._no_guardian and runner._agent_runtime:
|
||||
from framework.agents.hive_coder.guardian import attach_guardian
|
||||
|
||||
attach_guardian(runner._agent_runtime, runner._tool_registry)
|
||||
|
||||
if runner._agent_runtime and not runner._agent_runtime.is_running:
|
||||
await runner._agent_runtime.start()
|
||||
|
||||
self._runner = runner
|
||||
self.runtime = runner._agent_runtime
|
||||
except CredentialError as e:
|
||||
self.status_bar.set_graph_id("")
|
||||
self._show_credential_setup(
|
||||
@@ -391,7 +378,47 @@ class AdenTUI(App):
|
||||
self.notify(f"Failed to load agent: {e}", severity="error", timeout=10)
|
||||
return
|
||||
|
||||
# 4. Mount new widgets and subscribe to events
|
||||
# 4. Pre-run account selection (if agent requires it)
|
||||
if runner.requires_account_selection and runner._configure_for_account:
|
||||
try:
|
||||
if runner._list_accounts:
|
||||
accounts = await loop.run_in_executor(None, runner._list_accounts)
|
||||
else:
|
||||
accounts = []
|
||||
except Exception as e:
|
||||
self.notify(f"Failed to list accounts: {e}", severity="error", timeout=10)
|
||||
accounts = []
|
||||
if accounts:
|
||||
self._show_account_selection(runner, accounts)
|
||||
return # Continuation via callback
|
||||
|
||||
# 5. Complete the load
|
||||
await self._finish_agent_load(runner)
|
||||
|
||||
async def _finish_agent_load(self, runner) -> None:
|
||||
"""Complete agent setup, guardian attach, and widget mount."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
if runner._agent_runtime is None:
|
||||
await loop.run_in_executor(None, runner._setup)
|
||||
|
||||
if not self._no_guardian and not runner.skip_guardian and runner._agent_runtime:
|
||||
from framework.agents.hive_coder.guardian import attach_guardian
|
||||
|
||||
attach_guardian(runner._agent_runtime, runner._tool_registry)
|
||||
|
||||
if runner._agent_runtime and not runner._agent_runtime.is_running:
|
||||
await runner._agent_runtime.start()
|
||||
|
||||
self._runner = runner
|
||||
self.runtime = runner._agent_runtime
|
||||
except Exception as e:
|
||||
self.status_bar.set_graph_id("")
|
||||
self.notify(f"Failed to load agent: {e}", severity="error", timeout=10)
|
||||
return
|
||||
|
||||
self._mount_agent_widgets()
|
||||
await self._init_runtime_connection()
|
||||
|
||||
@@ -399,8 +426,50 @@ class AdenTUI(App):
|
||||
self._resume_session = None
|
||||
self._resume_checkpoint = None
|
||||
|
||||
agent_name = runner.agent_path.name
|
||||
self.notify(f"Agent loaded: {agent_name}", severity="information", timeout=3)
|
||||
|
||||
def _show_account_selection(self, runner, accounts: list[dict]) -> None:
|
||||
"""Show the account selection screen and continue loading on selection."""
|
||||
from framework.tui.screens.account_selection import AccountSelectionScreen
|
||||
|
||||
def _on_selection(selected: dict | None) -> None:
|
||||
if selected is None:
|
||||
self.status_bar.set_graph_id("")
|
||||
self.notify(
|
||||
"Account selection cancelled. Agent not loaded.",
|
||||
severity="warning",
|
||||
timeout=5,
|
||||
)
|
||||
return
|
||||
|
||||
# Scope tools to the selected provider
|
||||
if runner._configure_for_account:
|
||||
runner._configure_for_account(runner, selected)
|
||||
|
||||
# Validate credentials for the now-scoped provider
|
||||
from framework.credentials.models import CredentialError as CredError
|
||||
from framework.credentials.validation import validate_agent_credentials
|
||||
|
||||
try:
|
||||
validate_agent_credentials(runner.graph.nodes)
|
||||
except CredError as e:
|
||||
self._show_credential_setup(
|
||||
str(runner.agent_path),
|
||||
credential_error=e,
|
||||
)
|
||||
return
|
||||
|
||||
# Continue with the rest of agent loading
|
||||
self._do_finish_agent_load(runner)
|
||||
|
||||
self.push_screen(AccountSelectionScreen(accounts), callback=_on_selection)
|
||||
|
||||
@work(exclusive=True)
|
||||
async def _do_finish_agent_load(self, runner) -> None:
|
||||
"""Worker wrapper for _finish_agent_load (used by account selection callback)."""
|
||||
await self._finish_agent_load(runner)
|
||||
|
||||
def _show_credential_setup(
|
||||
self,
|
||||
agent_path: str,
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Account selection ModalScreen for picking a connected account before agent start."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.text import Text
|
||||
from textual.app import ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Label, OptionList
|
||||
from textual.widgets._option_list import Option
|
||||
|
||||
|
||||
class AccountSelectionScreen(ModalScreen[dict | None]):
|
||||
"""Modal screen showing connected accounts for pre-run selection.
|
||||
|
||||
Returns the selected account dict, or None if dismissed.
|
||||
"""
|
||||
|
||||
SCOPED_CSS = False
|
||||
|
||||
BINDINGS = [
|
||||
Binding("escape", "dismiss_picker", "Cancel"),
|
||||
]
|
||||
|
||||
DEFAULT_CSS = """
|
||||
AccountSelectionScreen {
|
||||
align: center middle;
|
||||
}
|
||||
#acct-container {
|
||||
width: 70%;
|
||||
max-width: 80;
|
||||
height: 60%;
|
||||
background: $surface;
|
||||
border: heavy $primary;
|
||||
padding: 1 2;
|
||||
}
|
||||
#acct-title {
|
||||
text-align: center;
|
||||
text-style: bold;
|
||||
width: 100%;
|
||||
color: $text;
|
||||
}
|
||||
#acct-subtitle {
|
||||
text-align: center;
|
||||
width: 100%;
|
||||
margin-bottom: 1;
|
||||
}
|
||||
#acct-footer {
|
||||
text-align: center;
|
||||
width: 100%;
|
||||
margin-top: 1;
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, accounts: list[dict]) -> None:
|
||||
super().__init__()
|
||||
self._accounts = accounts
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
n = len(self._accounts)
|
||||
with Vertical(id="acct-container"):
|
||||
yield Label("Select Account to Test", id="acct-title")
|
||||
yield Label(
|
||||
f"[dim]{n} connected account{'s' if n != 1 else ''}[/dim]",
|
||||
id="acct-subtitle",
|
||||
)
|
||||
option_list = OptionList(id="acct-list")
|
||||
for i, acct in enumerate(self._accounts):
|
||||
provider = acct.get("provider", "unknown")
|
||||
alias = acct.get("alias", "unknown")
|
||||
email = acct.get("identity", {}).get("email", "")
|
||||
label = Text()
|
||||
label.append(f"{provider}/", style="bold")
|
||||
label.append(alias, style="bold cyan")
|
||||
if email:
|
||||
label.append(f" ({email})", style="dim")
|
||||
option_list.add_option(Option(label, id=f"acct-{i}"))
|
||||
yield option_list
|
||||
yield Label(
|
||||
"[dim]Enter[/dim] Select [dim]Esc[/dim] Cancel",
|
||||
id="acct-footer",
|
||||
)
|
||||
|
||||
def on_mount(self) -> None:
|
||||
ol = self.query_one("#acct-list", OptionList)
|
||||
ol.styles.height = "1fr"
|
||||
|
||||
def on_option_list_option_selected(self, event: OptionList.OptionSelected) -> None:
|
||||
idx = event.option_index
|
||||
if 0 <= idx < len(self._accounts):
|
||||
self.dismiss(self._accounts[idx])
|
||||
|
||||
def action_dismiss_picker(self) -> None:
|
||||
self.dismiss(None)
|
||||
@@ -115,8 +115,8 @@ def register_tools(
|
||||
query: Gmail search query (default: "is:unread").
|
||||
max_results: Maximum messages to return (1-500, default 100).
|
||||
page_token: Token for fetching the next page of results.
|
||||
account: Account to use when multiple accounts are connected
|
||||
(e.g. "alice@gmail.com"). Leave empty for default.
|
||||
account: Account alias to target a specific account
|
||||
(e.g. "Timothy"). Leave empty for default.
|
||||
|
||||
Returns:
|
||||
Dict with "messages" list (each has "id" and "threadId"),
|
||||
|
||||
Reference in New Issue
Block a user