Merge pull request #44 from TimothyZhang7/feature/persistent-sessions

Feature/persistent sessions
This commit is contained in:
Timothy @aden
2026-01-21 09:10:45 -08:00
committed by GitHub
10 changed files with 644 additions and 112 deletions
+3 -1
View File
@@ -66,4 +66,6 @@ core/data/
tmp/
temp/
exports/*
exports/*
core/.agent-builder-sessions/*
+3 -3
View File
@@ -1,9 +1,9 @@
{
"mcpServers": {
"agent-builder": {
"command": "python",
"args": ["-m", "framework.mcp.agent_builder_server"],
"cwd": "/home/timothy/oss/hive/core"
"command": "bash",
"args": ["-c", "cd core && exec python -m framework.mcp.agent_builder_server"],
"cwd": "core"
}
}
}
+17 -2
View File
@@ -20,6 +20,19 @@ Environment Variables:
"""
import argparse
import os
import sys
# Suppress FastMCP banner in STDIO mode
if "--stdio" in sys.argv:
# Monkey-patch rich Console to redirect to stderr
import rich.console
_original_console_init = rich.console.Console.__init__
def _patched_console_init(self, *args, **kwargs):
kwargs['file'] = sys.stderr # Force all rich output to stderr
_original_console_init(self, *args, **kwargs)
rich.console.Console.__init__ = _patched_console_init
from fastmcp import FastMCP
from starlette.requests import Request
@@ -31,7 +44,9 @@ mcp = FastMCP("aden-tools")
from aden_tools.tools import register_all_tools
tools = register_all_tools(mcp)
print(f"[MCP] Registered {len(tools)} tools: {tools}")
# Only print to stdout in HTTP mode (STDIO mode requires clean stdout for JSON-RPC)
if "--stdio" not in sys.argv:
print(f"[MCP] Registered {len(tools)} tools: {tools}")
@mcp.custom_route("/health", methods=["GET"])
@@ -68,7 +83,7 @@ def main() -> None:
args = parser.parse_args()
if args.stdio:
print("[MCP] Starting with STDIO transport")
# STDIO mode: only JSON-RPC messages go to stdout
mcp.run(transport="stdio")
else:
print(f"[MCP] Starting HTTP server on {args.host}:{args.port}")
+6 -6
View File
@@ -165,12 +165,7 @@ class GraphExecutor:
path.append(current_node_id)
# Check if terminal
if current_node_id in graph.terminal_nodes:
self.logger.info(f"✓ Reached terminal node: {node_spec.name}")
break
# Check if pause (HITL)
# Check if pause (HITL) before execution
if current_node_id in graph.pause_nodes:
self.logger.info(f"⏸ Paused at HITL node: {node_spec.name}")
# Execute this node, then pause
@@ -279,6 +274,11 @@ class GraphExecutor:
session_state=session_state_out,
)
# Check if this is a terminal node - if so, we're done
if node_spec.id in graph.terminal_nodes:
self.logger.info(f"✓ Reached terminal node: {node_spec.name}")
break
# Determine next node
if result.next_node:
# Router explicitly set next node
+85 -16
View File
@@ -431,22 +431,13 @@ class LLMNode(NodeProtocol):
# Write to output keys
output = self._parse_output(response.content, ctx.node_spec)
# For llm_generate nodes, try to parse JSON and extract fields
if ctx.node_spec.node_type == "llm_generate" and len(ctx.node_spec.output_keys) > 1:
# For llm_generate and llm_tool_use nodes, try to parse JSON and extract fields
if ctx.node_spec.node_type in ("llm_generate", "llm_tool_use") and len(ctx.node_spec.output_keys) > 1:
try:
# Try to parse as JSON
import json
import re
# Remove markdown code blocks if present
content = response.content.strip()
if content.startswith("```"):
# Extract JSON from code block
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
if match:
content = match.group(1).strip()
parsed = json.loads(content)
# Try direct JSON parse first
parsed = self._extract_json_with_haiku(response.content, ctx.node_spec.output_keys)
# If parsed successfully, write each field to its corresponding output key
if isinstance(parsed, dict):
@@ -454,8 +445,12 @@ class LLMNode(NodeProtocol):
if key in parsed:
ctx.memory.write(key, parsed[key])
output[key] = parsed[key]
elif key in ctx.input_data:
# Key not in parsed JSON but exists in input - pass through input value
ctx.memory.write(key, ctx.input_data[key])
output[key] = ctx.input_data[key]
else:
# Key not in parsed JSON, write the whole response
# Key not in parsed JSON or input, write the whole response
ctx.memory.write(key, response.content)
output[key] = response.content
else:
@@ -465,8 +460,8 @@ class LLMNode(NodeProtocol):
output[key] = response.content
except (json.JSONDecodeError, Exception) as e:
# JSON parsing failed, fall back to writing entire response
logger.warning(f" ⚠ Failed to parse JSON output, using raw response: {e}")
# JSON extraction failed completely
logger.warning(f" ⚠ Failed to extract JSON output: {e}")
for key in ctx.node_spec.output_keys:
ctx.memory.write(key, response.content)
output[key] = response.content
@@ -503,6 +498,80 @@ class LLMNode(NodeProtocol):
# Default output
return {"result": content}
def _extract_json_with_haiku(self, raw_response: str, output_keys: list[str]) -> dict[str, Any]:
    """Use Haiku to extract clean JSON from a potentially verbose LLM response.

    Resolution order:
      1. Fast path: strip optional ```json fences and parse directly.
      2. No ANTHROPIC_API_KEY: slice from the first '{' to the last '}'
         and try parsing that span.
      3. Otherwise: ask claude-3-5-haiku to distill the response into a
         JSON object limited to *output_keys*.

    Args:
        raw_response: The raw LLM output that should contain JSON.
        output_keys: Keys the caller expects in the extracted object.

    Returns:
        The parsed JSON object (a dict on the fast/Haiku paths; the
        no-key fallback returns whatever the brace span parses to).

    Raises:
        ValueError: No API key and the simple extraction failed.
        Exception: The Haiku call failed or its output was not valid JSON.
    """
    import json
    import re

    # Try direct JSON parse first (fast path)
    try:
        content = raw_response.strip()
        # Remove markdown code blocks if present
        if content.startswith("```"):
            match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
            if match:
                content = match.group(1).strip()
        parsed = json.loads(content)
        if isinstance(parsed, dict):
            return parsed
        # Parsed but not a dict (e.g. a list): fall through to cleanup.
    except json.JSONDecodeError:
        pass

    # JSON parse failed - use Haiku to extract clean JSON
    import os
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        # No API key: one last cheap attempt - take the outermost
        # brace-delimited span and hope it is a single JSON object.
        try:
            start = raw_response.find('{')
            end = raw_response.rfind('}')
            if start != -1 and end != -1:
                json_str = raw_response[start:end + 1]
                return json.loads(json_str)
        except json.JSONDecodeError:
            # Was a bare `except:`; narrowed so KeyboardInterrupt /
            # SystemExit and genuine bugs are no longer swallowed here.
            pass
        raise ValueError("Cannot parse JSON and no API key for Haiku cleanup")

    # Use Haiku to clean the response
    from framework.llm.anthropic import AnthropicProvider
    haiku = AnthropicProvider(model="claude-3-5-haiku-20241022")
    prompt = f"""Extract the JSON object from this LLM response. Extract ONLY the values that the LLM actually generated.
Expected output keys: {output_keys}
LLM Response:
{raw_response}
IMPORTANT:
- Only extract keys that the LLM explicitly output in its response
- Do NOT include keys that were just mentioned or passed through from input
- If the LLM output multiple pieces of text/JSON, extract the LAST JSON object only
- Output ONLY valid JSON with no extra text, no markdown, no explanations"""
    try:
        result = haiku.complete(
            messages=[{"role": "user", "content": prompt}],
            system="You extract clean JSON from messy responses. Output only valid JSON, nothing else.",
        )
        cleaned = result.content.strip()
        # Remove markdown if Haiku added it
        if cleaned.startswith("```"):
            match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, re.DOTALL)
            if match:
                cleaned = match.group(1).strip()
        parsed = json.loads(cleaned)
        logger.info(f" ✓ Haiku cleaned JSON output")
        return parsed
    except Exception as e:
        logger.warning(f" ⚠ Haiku JSON extraction failed: {e}")
        raise
def _build_messages(self, ctx: NodeContext) -> list[dict]:
"""Build the message list for the LLM."""
# Use Haiku to intelligently format inputs from memory
+290 -3
View File
@@ -31,27 +31,149 @@ from framework.testing.parallel import AgentFactory
mcp = FastMCP("agent-builder")
# Session persistence directory
SESSIONS_DIR = Path(".agent-builder-sessions")
ACTIVE_SESSION_FILE = SESSIONS_DIR / ".active"
# Session storage
class BuildSession:
"""In-memory build session."""
"""Build session with persistence support."""
def __init__(self, name: str):
self.id = f"build_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
def __init__(self, name: str, session_id: str | None = None):
self.id = session_id or f"build_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.name = name
self.goal: Goal | None = None
self.nodes: list[NodeSpec] = []
self.edges: list[EdgeSpec] = []
self.mcp_servers: list[dict] = [] # MCP server configurations
self.created_at = datetime.now().isoformat()
self.last_modified = datetime.now().isoformat()
def to_dict(self) -> dict:
"""Serialize session to dictionary."""
return {
"session_id": self.id,
"name": self.name,
"goal": self.goal.model_dump() if self.goal else None,
"nodes": [n.model_dump() for n in self.nodes],
"edges": [e.model_dump() for e in self.edges],
"mcp_servers": self.mcp_servers,
"created_at": self.created_at,
"last_modified": self.last_modified,
}
@classmethod
def from_dict(cls, data: dict) -> "BuildSession":
"""Deserialize session from dictionary."""
session = cls(name=data["name"], session_id=data["session_id"])
session.created_at = data.get("created_at", session.created_at)
session.last_modified = data.get("last_modified", session.last_modified)
# Restore goal
if data.get("goal"):
goal_data = data["goal"]
session.goal = Goal(
id=goal_data["id"],
name=goal_data["name"],
description=goal_data["description"],
success_criteria=[
SuccessCriterion(**sc) for sc in goal_data.get("success_criteria", [])
],
constraints=[
Constraint(**c) for c in goal_data.get("constraints", [])
],
)
# Restore nodes
session.nodes = [NodeSpec(**n) for n in data.get("nodes", [])]
# Restore edges
edges_data = data.get("edges", [])
for e in edges_data:
# Convert condition string back to enum
condition_str = e.get("condition")
if isinstance(condition_str, str):
condition_map = {
"always": EdgeCondition.ALWAYS,
"on_success": EdgeCondition.ON_SUCCESS,
"on_failure": EdgeCondition.ON_FAILURE,
"conditional": EdgeCondition.CONDITIONAL,
}
e["condition"] = condition_map.get(condition_str, EdgeCondition.ON_SUCCESS)
session.edges.append(EdgeSpec(**e))
# Restore MCP servers
session.mcp_servers = data.get("mcp_servers", [])
return session
# Global session
_session: BuildSession | None = None
def _ensure_sessions_dir() -> None:
    """Create the on-disk sessions directory if it does not exist yet."""
    if not SESSIONS_DIR.exists():
        # exist_ok guards the check-then-create race with other processes.
        SESSIONS_DIR.mkdir(exist_ok=True)
def _save_session(session: "BuildSession") -> None:
    """Persist *session* to disk and mark it as the active session.

    The session JSON is written to a temporary file and moved into place
    with os.replace, so a crash mid-write can never corrupt an existing
    save (the original code wrote the target file directly, which could
    leave a truncated, unparseable session behind).

    Args:
        session: The build session to serialize; its ``last_modified``
            timestamp is refreshed as a side effect.
    """
    import os

    _ensure_sessions_dir()
    # Update last modified
    session.last_modified = datetime.now().isoformat()
    # Save session file atomically: write a sibling temp file, then rename.
    session_file = SESSIONS_DIR / f"{session.id}.json"
    tmp_file = session_file.with_name(f"{session.id}.json.tmp")
    with open(tmp_file, "w") as f:
        json.dump(session.to_dict(), f, indent=2, default=str)
    os.replace(tmp_file, session_file)  # atomic on POSIX and Windows
    # Update active session pointer
    with open(ACTIVE_SESSION_FILE, "w") as f:
        f.write(session.id)
def _load_session(session_id: str) -> "BuildSession":
    """Read a saved session file from SESSIONS_DIR and rebuild it.

    Raises:
        ValueError: No session file exists for *session_id*.
    """
    session_file = SESSIONS_DIR / f"{session_id}.json"
    if not session_file.exists():
        raise ValueError(f"Session '{session_id}' not found")
    data = json.loads(session_file.read_text())
    return BuildSession.from_dict(data)
def _load_active_session() -> "BuildSession | None":
    """Return the session named by the ``.active`` pointer file, if any.

    Best-effort by design: a missing pointer, an unreadable file, or a
    corrupt/absent session JSON all yield None instead of raising.
    """
    if not ACTIVE_SESSION_FILE.exists():
        return None
    try:
        pointer = ACTIVE_SESSION_FILE.read_text().strip()
        if pointer:
            return _load_session(pointer)
    except Exception:
        # Deliberately swallow: callers treat "no active session" and
        # "active session unrecoverable" the same way.
        pass
    return None
def get_session() -> "BuildSession":
    """Return the current build session, restoring the persisted active
    session from disk when none is held in memory.

    Raises:
        ValueError: No session exists in memory or on disk.
    """
    global _session
    if _session is None:
        # Lazy restore: pick up where a previous process left off.
        _session = _load_active_session()
    if _session is None:
        raise ValueError("No active session. Call create_session first.")
    return _session
@@ -64,13 +186,122 @@ def create_session(name: Annotated[str, "Name for the agent being built"]) -> st
"""Create a new agent building session. Call this first before building an agent."""
global _session
_session = BuildSession(name)
_save_session(_session) # Auto-save
return json.dumps({
"session_id": _session.id,
"name": name,
"status": "created",
"persisted": True,
})
@mcp.tool()
def list_sessions() -> str:
    """List all saved agent building sessions.

    Returns:
        JSON string with ``sessions`` (newest first), ``total``, and
        ``active_session_id`` (None when no pointer file exists).
    """
    _ensure_sessions_dir()
    sessions = []
    if SESSIONS_DIR.exists():
        for session_file in SESSIONS_DIR.glob("*.json"):
            try:
                with open(session_file, "r") as f:
                    data = json.load(f)
                sessions.append({
                    "session_id": data["session_id"],
                    "name": data["name"],
                    "created_at": data.get("created_at"),
                    "last_modified": data.get("last_modified"),
                    "node_count": len(data.get("nodes", [])),
                    "edge_count": len(data.get("edges", [])),
                    "has_goal": data.get("goal") is not None,
                })
            except Exception:
                pass  # Skip corrupted files
    # Check which session is currently active
    active_id = None
    if ACTIVE_SESSION_FILE.exists():
        try:
            with open(ACTIVE_SESSION_FILE, "r") as f:
                active_id = f.read().strip()
        except Exception:
            pass
    return json.dumps({
        # `or ""` guards files missing "last_modified" (stored as None
        # above): sorted() would otherwise raise TypeError comparing
        # None against str and break the whole listing.
        "sessions": sorted(sessions, key=lambda s: s["last_modified"] or "", reverse=True),
        "total": len(sessions),
        "active_session_id": active_id,
    }, indent=2)
@mcp.tool()
def load_session_by_id(session_id: Annotated[str, "ID of the session to load"]) -> str:
    """Load a previously saved agent building session by its ID."""
    global _session
    try:
        _session = _load_session(session_id)
        # Point the .active pointer file at the freshly loaded session.
        ACTIVE_SESSION_FILE.write_text(session_id)
        summary = {
            "success": True,
            "session_id": _session.id,
            "name": _session.name,
            "node_count": len(_session.nodes),
            "edge_count": len(_session.edges),
            "has_goal": _session.goal is not None,
            "created_at": _session.created_at,
            "last_modified": _session.last_modified,
            "message": f"Session '{_session.name}' loaded successfully"
        }
        return json.dumps(summary)
    except Exception as e:
        # Tool boundary: report the failure as JSON rather than raising.
        return json.dumps({"success": False, "error": str(e)})
@mcp.tool()
def delete_session(session_id: Annotated[str, "ID of the session to delete"]) -> str:
    """Delete a saved agent building session."""
    global _session
    session_file = SESSIONS_DIR / f"{session_id}.json"
    if not session_file.exists():
        return json.dumps({
            "success": False,
            "error": f"Session '{session_id}' not found"
        })
    try:
        session_file.unlink()
        # Drop the in-memory session when it is the one being removed.
        if _session and _session.id == session_id:
            _session = None
        # Remove the .active pointer only if it references this session.
        if ACTIVE_SESSION_FILE.exists():
            if ACTIVE_SESSION_FILE.read_text().strip() == session_id:
                ACTIVE_SESSION_FILE.unlink()
        return json.dumps({
            "success": True,
            "deleted_session_id": session_id,
            "message": f"Session '{session_id}' deleted successfully"
        })
    except Exception as e:
        # Tool boundary: surface the failure as JSON rather than raising.
        return json.dumps({"success": False, "error": str(e)})
@mcp.tool()
def set_goal(
goal_id: Annotated[str, "Unique identifier for the goal"],
@@ -132,6 +363,8 @@ def set_goal(
if not constraint_list:
warnings.append("Consider adding constraints")
_save_session(session) # Auto-save
return json.dumps({
"valid": len(errors) == 0,
"errors": errors,
@@ -215,6 +448,8 @@ def add_node(
if node_type in ("llm_generate", "llm_tool_use") and not system_prompt:
warnings.append(f"LLM node '{node_id}' should have a system_prompt")
_save_session(session) # Auto-save
return json.dumps({
"valid": len(errors) == 0,
"errors": errors,
@@ -291,6 +526,8 @@ def add_edge(
if edge_condition == EdgeCondition.CONDITIONAL and not condition_expr:
errors.append(f"Conditional edge '{edge_id}' needs condition_expr")
_save_session(session) # Auto-save
return json.dumps({
"valid": len(errors) == 0,
"errors": errors,
@@ -374,6 +611,8 @@ def update_node(
if node.node_type in ("llm_generate", "llm_tool_use") and not node.system_prompt:
warnings.append(f"LLM node '{node_id}' should have a system_prompt")
_save_session(session) # Auto-save
return json.dumps({
"valid": len(errors) == 0,
"errors": errors,
@@ -431,6 +670,8 @@ def delete_node(
if not (e.source == node_id or e.target == node_id)
]
_save_session(session) # Auto-save
return json.dumps({
"valid": True,
"deleted_node": removed_node.model_dump(),
@@ -461,6 +702,8 @@ def delete_edge(
# Remove the edge
removed_edge = session.edges.pop(edge_idx)
_save_session(session) # Auto-save
return json.dumps({
"valid": True,
"deleted_edge": removed_edge.model_dump(),
@@ -893,6 +1136,46 @@ def export_graph() -> str:
entry_node = validation["entry_node"]
terminal_nodes = validation["terminal_nodes"]
# Extract pause/resume configuration from validation
pause_nodes = validation.get("pause_nodes", [])
resume_entry_points = validation.get("resume_entry_points", [])
# Build entry_points dict for pause/resume architecture
entry_points = {}
if entry_node:
entry_points["start"] = entry_node
# Add resume entry points with {pause_node}_resume naming convention
if pause_nodes and resume_entry_points:
# Strategy 1: Try to match by checking which resume node uses the pause node's outputs
pause_to_resume = {}
for pause_node_id in pause_nodes:
pause_node = next((n for n in session.nodes if n.id == pause_node_id), None)
if not pause_node:
continue
# Find resume nodes that read the outputs of this pause node
for resume_node_id in resume_entry_points:
resume_node = next((n for n in session.nodes if n.id == resume_node_id), None)
if not resume_node:
continue
# Check if resume node reads pause node's outputs
shared_keys = set(pause_node.output_keys) & set(resume_node.input_keys)
if shared_keys:
pause_to_resume[pause_node_id] = resume_node_id
break
# Strategy 2: Fallback - pair sequentially if no match found
unmatched_pause = [p for p in pause_nodes if p not in pause_to_resume]
unmatched_resume = [r for r in resume_entry_points if r not in pause_to_resume.values()]
for pause_id, resume_id in zip(unmatched_pause, unmatched_resume):
pause_to_resume[pause_id] = resume_id
# Build entry_points dict
for pause_id, resume_id in pause_to_resume.items():
entry_points[f"{pause_id}_resume"] = resume_id
# Build edges list
edges_list = [
{
@@ -937,6 +1220,8 @@ def export_graph() -> str:
"goal_id": session.goal.id,
"version": "1.0.0",
"entry_node": entry_node,
"entry_points": entry_points,
"pause_nodes": pause_nodes,
"terminal_nodes": terminal_nodes,
"nodes": [node.model_dump() for node in session.nodes],
"edges": edges_list,
@@ -1171,6 +1456,7 @@ def add_mcp_server(
# Add to session
session.mcp_servers.append(server_config)
_save_session(session) # Auto-save
return json.dumps({
"success": True,
@@ -1290,6 +1576,7 @@ def remove_mcp_server(
for i, server in enumerate(session.mcp_servers):
if server["name"] == name:
session.mcp_servers.pop(i)
_save_session(session) # Auto-save
return json.dumps({
"success": True,
"removed": name,
+89 -31
View File
@@ -6,8 +6,6 @@ import json
import sys
from pathlib import Path
from framework.graph import ExecutionStatus
def register_commands(subparsers: argparse._SubParsersAction) -> None:
"""Register runner commands with the main CLI."""
@@ -48,6 +46,11 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
action="store_true",
help="Only output the final result JSON",
)
run_parser.add_argument(
"--verbose", "-v",
action="store_true",
help="Show detailed execution logs (steps, LLM calls, etc.)",
)
run_parser.set_defaults(func=cmd_run)
# info command
@@ -166,8 +169,17 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
def cmd_run(args: argparse.Namespace) -> int:
"""Run an exported agent."""
import logging
from framework.runner import AgentRunner
# Set logging level (quiet by default for cleaner output)
if args.quiet:
logging.basicConfig(level=logging.ERROR, format='%(message)s')
elif getattr(args, 'verbose', False):
logging.basicConfig(level=logging.INFO, format='%(message)s')
else:
logging.basicConfig(level=logging.WARNING, format='%(message)s')
# Load input context
context = {}
if args.input:
@@ -195,6 +207,12 @@ def cmd_run(args: argparse.Namespace) -> int:
print(f"Error: {e}", file=sys.stderr)
return 1
# Auto-inject user_id if the agent expects it but it's not provided
entry_input_keys = runner.graph.nodes[0].input_keys if runner.graph.nodes else []
if "user_id" in entry_input_keys and context.get("user_id") is None:
import os
context["user_id"] = os.environ.get("USER", "default_user")
if not args.quiet:
info = runner.info()
print(f"Agent: {info.name}")
@@ -212,12 +230,14 @@ def cmd_run(args: argparse.Namespace) -> int:
# Format output
output = {
"status": result.status.value if hasattr(result.status, "value") else str(result.status),
"completed_steps": result.completed_steps,
"results": result.results,
"success": result.success,
"steps_executed": result.steps_executed,
"output": result.output,
}
if result.feedback:
output["feedback"] = result.feedback
if result.error:
output["error"] = result.error
if result.paused_at:
output["paused_at"] = result.paused_at
# Output results
if args.output:
@@ -231,27 +251,51 @@ def cmd_run(args: argparse.Namespace) -> int:
else:
print()
print("=" * 60)
status_str = result.status.value if hasattr(result.status, "value") else str(result.status)
status_str = "SUCCESS" if result.success else "FAILED"
print(f"Status: {status_str}")
print(f"Completed steps: {len(result.completed_steps)}")
print(f"Steps executed: {result.steps_executed}")
print(f"Path: {''.join(result.path)}")
print("=" * 60)
if result.status == ExecutionStatus.COMPLETED:
if result.success:
print("\n--- Results ---")
for key, value in result.results.items():
if isinstance(value, (dict, list)):
print(f"\n{key}:")
value_str = json.dumps(value, indent=2, default=str)
if len(value_str) > 500:
value_str = value_str[:500] + "..."
print(value_str)
else:
print(f"{key}: {str(value)[:200]}")
elif result.feedback:
print(f"\nFeedback: {result.feedback}")
# Show only meaningful output keys (skip internal/intermediate values)
meaningful_keys = ["final_response", "response", "result", "answer", "output"]
# Try to find the most relevant output
shown = False
for key in meaningful_keys:
if key in result.output:
value = result.output[key]
if isinstance(value, str) and len(value) > 10:
print(value)
shown = True
break
elif isinstance(value, (dict, list)):
print(json.dumps(value, indent=2, default=str))
shown = True
break
# If no meaningful key found, show all non-internal keys
if not shown:
for key, value in result.output.items():
if not key.startswith("_") and key not in ["user_id", "request", "memory_loaded", "user_profile", "recent_context"]:
if isinstance(value, (dict, list)):
print(f"\n{key}:")
value_str = json.dumps(value, indent=2, default=str)
if len(value_str) > 300:
value_str = value_str[:300] + "..."
print(value_str)
else:
val_str = str(value)
if len(val_str) > 200:
val_str = val_str[:200] + "..."
print(f"{key}: {val_str}")
elif result.error:
print(f"\nError: {result.error}")
runner.cleanup()
return 0 if result.status == ExecutionStatus.COMPLETED else 1
return 0 if result.success else 1
def cmd_info(args: argparse.Namespace) -> int:
@@ -760,6 +804,11 @@ def cmd_shell(args: argparse.Namespace) -> int:
# STARTING FRESH: Merge new input with accumulated session memory
run_context = {**session_memory, **context}
# Auto-inject user_id if missing (for personal assistant agents)
if "user_id" in entry_input_keys and run_context.get("user_id") is None:
import os
run_context["user_id"] = os.environ.get("USER", "default_user")
# Add conversation history to context if agent expects it
if conversation_history:
run_context["_conversation_history"] = conversation_history.copy()
@@ -778,16 +827,25 @@ def cmd_shell(args: argparse.Namespace) -> int:
print(f"Steps executed: {result.steps_executed}")
print(f"Path: {''.join(result.path)}")
# Show clean output - prioritize meaningful keys
if result.output:
print("\nOutput:")
for key, value in result.output.items():
if isinstance(value, (dict, list)):
value_str = json.dumps(value, indent=2, default=str)
if len(value_str) > 300:
value_str = value_str[:300] + "..."
print(f" {key}: {value_str}")
else:
print(f" {key}: {str(value)[:200]}")
meaningful_keys = ["final_response", "response", "result", "answer", "output"]
shown = False
for key in meaningful_keys:
if key in result.output:
value = result.output[key]
if isinstance(value, str) and len(value) > 10:
print(f"\n{value}\n")
shown = True
break
if not shown:
print("\nOutput:")
for key, value in result.output.items():
if not key.startswith("_"):
val_str = str(value)[:200]
print(f" {key}: {val_str}")
if result.error:
print(f"\nError: {result.error}")
+118 -45
View File
@@ -65,10 +65,15 @@ class MCPClient:
self._session = None
self._read_stream = None
self._write_stream = None
self._stdio_context = None # Context manager for stdio_client
self._http_client: httpx.Client | None = None
self._tools: dict[str, MCPTool] = {}
self._connected = False
# Background event loop for persistent STDIO connection
self._loop = None
self._loop_thread = None
def _run_async(self, coro):
"""
Run an async coroutine, handling both sync and async contexts.
@@ -79,6 +84,13 @@ class MCPClient:
Returns:
Result of the coroutine
"""
# If we have a persistent loop (for STDIO), use it
if self._loop is not None:
import concurrent.futures
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result()
# Otherwise, use the standard approach
try:
# Try to get the current event loop
asyncio.get_running_loop()
@@ -129,12 +141,12 @@ class MCPClient:
self._connected = True
def _connect_stdio(self) -> None:
"""Connect to MCP server via STDIO transport using MCP SDK."""
"""Connect to MCP server via STDIO transport using MCP SDK with persistent connection."""
if not self.config.command:
raise ValueError("command is required for STDIO transport")
try:
# Import MCP SDK
import threading
from mcp import StdioServerParameters
# Create server parameters
@@ -145,10 +157,62 @@ class MCPClient:
cwd=self.config.cwd,
)
# Store for later use in async context
# Store for later use
self._server_params = server_params
logger.info(f"Connected to MCP server '{self.config.name}' via STDIO")
# Start background event loop for persistent connection
loop_started = threading.Event()
connection_ready = threading.Event()
connection_error = []
def run_event_loop():
"""Run event loop in background thread."""
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
loop_started.set()
# Initialize persistent connection
async def init_connection():
try:
from mcp import ClientSession
from mcp.client.stdio import stdio_client
# Create persistent stdio client context
self._stdio_context = stdio_client(server_params)
self._read_stream, self._write_stream = await self._stdio_context.__aenter__()
# Create persistent session
self._session = ClientSession(self._read_stream, self._write_stream)
await self._session.__aenter__()
# Initialize session
await self._session.initialize()
connection_ready.set()
except Exception as e:
connection_error.append(e)
connection_ready.set()
# Schedule connection initialization
self._loop.create_task(init_connection())
# Run loop forever
self._loop.run_forever()
self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
self._loop_thread.start()
# Wait for loop to start
loop_started.wait(timeout=5)
if not loop_started.is_set():
raise RuntimeError("Event loop failed to start")
# Wait for connection to be ready
connection_ready.wait(timeout=10)
if connection_error:
raise connection_error[0]
logger.info(f"Connected to MCP server '{self.config.name}' via STDIO (persistent)")
except Exception as e:
raise RuntimeError(f"Failed to connect to MCP server: {e}")
@@ -196,28 +260,23 @@ class MCPClient:
raise
async def _list_tools_stdio_async(self) -> list[dict]:
"""List tools via STDIO protocol using MCP SDK."""
from mcp import ClientSession
from mcp.client.stdio import stdio_client
"""List tools via STDIO protocol using persistent session."""
if not self._session:
raise RuntimeError("STDIO session not initialized")
async with stdio_client(self._server_params) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the session
await session.initialize()
# List tools using persistent session
response = await self._session.list_tools()
# List tools
response = await session.list_tools()
# Convert tools to dict format
tools_list = []
for tool in response.tools:
tools_list.append({
"name": tool.name,
"description": tool.description,
"inputSchema": tool.inputSchema,
})
# Convert tools to dict format
tools_list = []
for tool in response.tools:
tools_list.append({
"name": tool.name,
"description": tool.description,
"inputSchema": tool.inputSchema,
})
return tools_list
return tools_list
def _list_tools_http(self) -> list[dict]:
"""List tools via HTTP protocol."""
@@ -280,31 +339,26 @@ class MCPClient:
return self._call_tool_http(tool_name, arguments)
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call tool via STDIO protocol using MCP SDK."""
from mcp import ClientSession
from mcp.client.stdio import stdio_client
"""Call tool via STDIO protocol using persistent session."""
if not self._session:
raise RuntimeError("STDIO session not initialized")
async with stdio_client(self._server_params) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the session
await session.initialize()
# Call tool using persistent session
result = await self._session.call_tool(tool_name, arguments=arguments)
# Call tool
result = await session.call_tool(tool_name, arguments=arguments)
# Extract content
if result.content:
# MCP returns content as a list of content items
if len(result.content) > 0:
content_item = result.content[0]
# Check if it's a text content item
if hasattr(content_item, 'text'):
return content_item.text
elif hasattr(content_item, 'data'):
return content_item.data
return result.content
# Extract content
if result.content:
# MCP returns content as a list of content items
if len(result.content) > 0:
content_item = result.content[0]
# Check if it's a text content item
if hasattr(content_item, 'text'):
return content_item.text
elif hasattr(content_item, 'data'):
return content_item.data
return result.content
return None
return None
def _call_tool_http(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call tool via HTTP protocol."""
@@ -336,6 +390,25 @@ class MCPClient:
def disconnect(self) -> None:
"""Disconnect from the MCP server."""
# Clean up persistent STDIO connection
if self._loop is not None:
# Stop event loop - this will cause context managers to clean up naturally
if self._loop and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
# Wait for thread to finish
if self._loop_thread and self._loop_thread.is_alive():
self._loop_thread.join(timeout=2)
# Clear references
self._session = None
self._stdio_context = None
self._read_stream = None
self._write_stream = None
self._loop = None
self._loop_thread = None
# Clean up HTTP client
if self._http_client:
self._http_client.close()
self._http_client = None
+18 -2
View File
@@ -195,8 +195,12 @@ class AgentRunner:
self._storage_path = storage_path
self._temp_dir = None
else:
self._temp_dir = tempfile.TemporaryDirectory()
self._storage_path = Path(self._temp_dir.name) / "runtime"
# Use persistent storage in ~/.hive by default
home = Path.home()
default_storage = home / ".hive" / "storage" / agent_path.name
default_storage.mkdir(parents=True, exist_ok=True)
self._storage_path = default_storage
self._temp_dir = None
# Initialize components
self._tool_registry = ToolRegistry()
@@ -366,6 +370,18 @@ class AgentRunner:
# Create runtime
self._runtime = Runtime(storage_path=self._storage_path)
# Set up session context for tools (workspace_id, agent_id, session_id)
workspace_id = "default" # Could be derived from storage path
agent_id = self.graph.id or "unknown"
# Use "current" as a stable session_id for persistent memory
session_id = "current"
self._tool_registry.set_session_context(
workspace_id=workspace_id,
agent_id=agent_id,
session_id=session_id,
)
# Create LLM provider (if not mock mode and API key available)
if not self.mock_mode and os.environ.get("ANTHROPIC_API_KEY"):
from framework.llm.anthropic import AnthropicProvider
+15 -3
View File
@@ -35,6 +35,7 @@ class ToolRegistry:
def __init__(self):
self._tools: dict[str, RegisteredTool] = {}
self._mcp_clients: list[Any] = [] # List of MCPClient instances
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
def register(
self,
@@ -227,6 +228,15 @@ class ToolRegistry:
"""Check if a tool is registered."""
return name in self._tools
def set_session_context(self, **context) -> None:
"""
Set session context to auto-inject into tool calls.
Args:
**context: Key-value pairs to inject (e.g., workspace_id, agent_id, session_id)
"""
self._session_context.update(context)
def register_mcp_server(
self,
server_config: dict[str, Any],
@@ -279,10 +289,12 @@ class ToolRegistry:
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
# Create executor that calls the MCP server
def make_mcp_executor(client_ref: MCPClient, tool_name: str):
def make_mcp_executor(client_ref: MCPClient, tool_name: str, registry_ref):
def executor(inputs: dict) -> Any:
try:
result = client_ref.call_tool(tool_name, inputs)
# Inject session context for tools that need it
merged_inputs = {**registry_ref._session_context, **inputs}
result = client_ref.call_tool(tool_name, merged_inputs)
# MCP tools return content array, extract the result
if isinstance(result, list) and len(result) > 0:
if isinstance(result[0], dict) and "text" in result[0]:
@@ -298,7 +310,7 @@ class ToolRegistry:
self.register(
mcp_tool.name,
tool,
make_mcp_executor(client, mcp_tool.name),
make_mcp_executor(client, mcp_tool.name, self),
)
count += 1