feat(cli): basic session support, fix MCP integration issues

Timothy
2026-01-21 07:45:41 -08:00
parent ebafd90b9f
commit 26d0ab4419
9 changed files with 354 additions and 109 deletions
+3 -1
@@ -63,4 +63,6 @@ __pycache__/
 tmp/
 temp/
 exports/*
+core/.agent-builder-sessions/*
+3 -3
@@ -1,9 +1,9 @@
 {
   "mcpServers": {
     "agent-builder": {
-      "command": "python",
-      "args": ["-m", "framework.mcp.agent_builder_server"],
-      "cwd": "/home/timothy/oss/hive/core"
+      "command": "bash",
+      "args": ["-c", "cd core && exec python -m framework.mcp.agent_builder_server"],
+      "cwd": "core"
     }
   }
 }
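
Note on the config change above: a relative `cwd` only works if the MCP client resolves it consistently, so the launch is wrapped in `bash -c` and the directory change happens inside the shell itself. A minimal sketch of what a client effectively executes (module path taken from the config, everything else illustrative):

```python
import subprocess

# Spawn the STDIO server the way the new config describes: bash owns the
# directory change, then exec replaces the shell with the Python server.
proc = subprocess.Popen(
    ["bash", "-c", "cd core && exec python -m framework.mcp.agent_builder_server"],
    stdin=subprocess.PIPE,   # JSON-RPC requests are written here
    stdout=subprocess.PIPE,  # JSON-RPC responses are read here
)
```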
+17 -2
@@ -20,6 +20,19 @@ Environment Variables:
"""
import argparse
import os
import sys
# Suppress FastMCP banner in STDIO mode
if "--stdio" in sys.argv:
# Monkey-patch rich Console to redirect to stderr
import rich.console
_original_console_init = rich.console.Console.__init__
def _patched_console_init(self, *args, **kwargs):
kwargs['file'] = sys.stderr # Force all rich output to stderr
_original_console_init(self, *args, **kwargs)
rich.console.Console.__init__ = _patched_console_init
from fastmcp import FastMCP
from starlette.requests import Request
@@ -31,7 +44,9 @@ mcp = FastMCP("aden-tools")
 from aden_tools.tools import register_all_tools
 tools = register_all_tools(mcp)
-print(f"[MCP] Registered {len(tools)} tools: {tools}")
+# Only print to stdout in HTTP mode (STDIO mode requires clean stdout for JSON-RPC)
+if "--stdio" not in sys.argv:
+    print(f"[MCP] Registered {len(tools)} tools: {tools}")
 
 @mcp.custom_route("/health", methods=["GET"])
@@ -68,7 +83,7 @@ def main() -> None:
     args = parser.parse_args()
 
     if args.stdio:
-        print("[MCP] Starting with STDIO transport")
+        # STDIO mode: only JSON-RPC messages go to stdout
         mcp.run(transport="stdio")
     else:
         print(f"[MCP] Starting HTTP server on {args.host}:{args.port}")
+6 -6
@@ -165,12 +165,7 @@ class GraphExecutor:
             path.append(current_node_id)
 
-            # Check if terminal
-            if current_node_id in graph.terminal_nodes:
-                self.logger.info(f"✓ Reached terminal node: {node_spec.name}")
-                break
-
-            # Check if pause (HITL)
+            # Check if pause (HITL) before execution
             if current_node_id in graph.pause_nodes:
                 self.logger.info(f"⏸ Paused at HITL node: {node_spec.name}")
                 # Execute this node, then pause
@@ -279,6 +274,11 @@ class GraphExecutor:
                 session_state=session_state_out,
             )
 
+            # Check if this is a terminal node - if so, we're done
+            if node_spec.id in graph.terminal_nodes:
+                self.logger.info(f"✓ Reached terminal node: {node_spec.name}")
+                break
+
             # Determine next node
             if result.next_node:
                 # Router explicitly set next node
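
The two hunks above move the terminal check from before node execution to after it, so a terminal node's own logic now runs before the loop exits. A condensed sketch of the new ordering (names illustrative, not the executor's real API):

```python
def run_graph(nodes, terminal_nodes, start):
    """Walk the graph, executing every node, including terminal ones."""
    current = start
    path = []
    while current is not None:
        path.append(current)
        result = nodes[current]()          # terminal nodes execute too
        if current in terminal_nodes:      # checked only after execution
            break
        current = result                   # next node id, or None
    return path

# Toy graph: "a" -> "b", where "b" is terminal but still runs.
print(run_graph({"a": lambda: "b", "b": lambda: None}, {"b"}, "a"))  # ['a', 'b']
```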
+85 -16
@@ -431,22 +431,13 @@ class LLMNode(NodeProtocol):
         # Write to output keys
         output = self._parse_output(response.content, ctx.node_spec)
 
-        # For llm_generate nodes, try to parse JSON and extract fields
-        if ctx.node_spec.node_type == "llm_generate" and len(ctx.node_spec.output_keys) > 1:
+        # For llm_generate and llm_tool_use nodes, try to parse JSON and extract fields
+        if ctx.node_spec.node_type in ("llm_generate", "llm_tool_use") and len(ctx.node_spec.output_keys) > 1:
             try:
-                # Try to parse as JSON
                 import json
                 import re
-                # Remove markdown code blocks if present
-                content = response.content.strip()
-                if content.startswith("```"):
-                    # Extract JSON from code block
-                    match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
-                    if match:
-                        content = match.group(1).strip()
-                parsed = json.loads(content)
+                # Try direct JSON parse first
+                parsed = self._extract_json_with_haiku(response.content, ctx.node_spec.output_keys)
 
                 # If parsed successfully, write each field to its corresponding output key
                 if isinstance(parsed, dict):
@@ -454,8 +445,12 @@ class LLMNode(NodeProtocol):
                     if key in parsed:
                         ctx.memory.write(key, parsed[key])
                         output[key] = parsed[key]
+                    elif key in ctx.input_data:
+                        # Key not in parsed JSON but exists in input - pass through input value
+                        ctx.memory.write(key, ctx.input_data[key])
+                        output[key] = ctx.input_data[key]
                     else:
-                        # Key not in parsed JSON, write the whole response
+                        # Key not in parsed JSON or input, write the whole response
                         ctx.memory.write(key, response.content)
                         output[key] = response.content
             else:
@@ -465,8 +460,8 @@ class LLMNode(NodeProtocol):
                         output[key] = response.content
             except (json.JSONDecodeError, Exception) as e:
-                # JSON parsing failed, fall back to writing entire response
-                logger.warning(f"  ⚠ Failed to parse JSON output, using raw response: {e}")
+                # JSON extraction failed completely
+                logger.warning(f"  ⚠ Failed to extract JSON output: {e}")
                 for key in ctx.node_spec.output_keys:
                     ctx.memory.write(key, response.content)
                     output[key] = response.content
@@ -503,6 +498,80 @@ class LLMNode(NodeProtocol):
         # Default output
         return {"result": content}
 
+    def _extract_json_with_haiku(self, raw_response: str, output_keys: list[str]) -> dict[str, Any]:
+        """Use Haiku to extract clean JSON from potentially verbose LLM response."""
+        import json
+        import re
+
+        # Try direct JSON parse first (fast path)
+        try:
+            content = raw_response.strip()
+            # Remove markdown code blocks if present
+            if content.startswith("```"):
+                match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
+                if match:
+                    content = match.group(1).strip()
+            parsed = json.loads(content)
+            if isinstance(parsed, dict):
+                return parsed
+        except json.JSONDecodeError:
+            pass
+
+        # JSON parse failed - use Haiku to extract clean JSON
+        import os
+        api_key = os.environ.get("ANTHROPIC_API_KEY")
+        if not api_key:
+            # No API key, try one more simple extraction
+            try:
+                # Find first { and last }
+                start = raw_response.find('{')
+                end = raw_response.rfind('}')
+                if start != -1 and end != -1:
+                    json_str = raw_response[start:end+1]
+                    return json.loads(json_str)
+            except:
+                pass
+            raise ValueError("Cannot parse JSON and no API key for Haiku cleanup")
+
+        # Use Haiku to clean the response
+        from framework.llm.anthropic import AnthropicProvider
+        haiku = AnthropicProvider(model="claude-3-5-haiku-20241022")
+
+        prompt = f"""Extract the JSON object from this LLM response. Extract ONLY the values that the LLM actually generated.
+
+Expected output keys: {output_keys}
+
+LLM Response:
+{raw_response}
+
+IMPORTANT:
+- Only extract keys that the LLM explicitly output in its response
+- Do NOT include keys that were just mentioned or passed through from input
+- If the LLM output multiple pieces of text/JSON, extract the LAST JSON object only
+- Output ONLY valid JSON with no extra text, no markdown, no explanations"""
+
+        try:
+            result = haiku.complete(
+                messages=[{"role": "user", "content": prompt}],
+                system="You extract clean JSON from messy responses. Output only valid JSON, nothing else.",
+            )
+            cleaned = result.content.strip()
+            # Remove markdown if Haiku added it
+            if cleaned.startswith("```"):
+                match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, re.DOTALL)
+                if match:
+                    cleaned = match.group(1).strip()
+            parsed = json.loads(cleaned)
+            logger.info(f"  ✓ Haiku cleaned JSON output")
+            return parsed
+        except Exception as e:
+            logger.warning(f"  ⚠ Haiku JSON extraction failed: {e}")
+            raise
+
     def _build_messages(self, ctx: NodeContext) -> list[dict]:
         """Build the message list for the LLM."""
+        # Use Haiku to intelligently format inputs from memory
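
The `_extract_json_with_haiku` helper above layers three strategies: a direct parse with fence stripping, a brace-slice fallback when no API key is available, and finally a Haiku cleanup call. A self-contained sketch of the first two tiers, assuming only the standard library (the fence string is assembled to keep the example readable):

```python
import json
import re

FENCE = "`" * 3  # a literal markdown code fence

def extract_json(raw: str) -> dict:
    # Tier 1: strip a markdown code fence, then try a direct parse.
    content = raw.strip()
    if content.startswith(FENCE):
        match = re.search(FENCE + r"(?:json)?\s*\n?(.*?)\n?" + FENCE, content, re.DOTALL)
        if match:
            content = match.group(1).strip()
    try:
        parsed = json.loads(content)
        if isinstance(parsed, dict):
            return parsed
    except json.JSONDecodeError:
        pass
    # Tier 2: slice from the first '{' to the last '}' and retry.
    start, end = raw.find("{"), raw.rfind("}")
    if start != -1 and end > start:
        return json.loads(raw[start:end + 1])
    raise ValueError("no JSON object found")

print(extract_json('Sure! Here is the result: {"answer": 42}'))  # {'answer': 42}
```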
+89 -31
@@ -6,8 +6,6 @@ import json
 import sys
 from pathlib import Path
 
-from framework.graph import ExecutionStatus
-
 def register_commands(subparsers: argparse._SubParsersAction) -> None:
     """Register runner commands with the main CLI."""
@@ -48,6 +46,11 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
action="store_true",
help="Only output the final result JSON",
)
run_parser.add_argument(
"--verbose", "-v",
action="store_true",
help="Show detailed execution logs (steps, LLM calls, etc.)",
)
run_parser.set_defaults(func=cmd_run)
# info command
@@ -166,8 +169,17 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
 def cmd_run(args: argparse.Namespace) -> int:
     """Run an exported agent."""
+    import logging
+
     from framework.runner import AgentRunner
 
+    # Set logging level (quiet by default for cleaner output)
+    if args.quiet:
+        logging.basicConfig(level=logging.ERROR, format='%(message)s')
+    elif getattr(args, 'verbose', False):
+        logging.basicConfig(level=logging.INFO, format='%(message)s')
+    else:
+        logging.basicConfig(level=logging.WARNING, format='%(message)s')
+
     # Load input context
     context = {}
     if args.input:
@@ -195,6 +207,12 @@ def cmd_run(args: argparse.Namespace) -> int:
print(f"Error: {e}", file=sys.stderr)
return 1
# Auto-inject user_id if the agent expects it but it's not provided
entry_input_keys = runner.graph.nodes[0].input_keys if runner.graph.nodes else []
if "user_id" in entry_input_keys and context.get("user_id") is None:
import os
context["user_id"] = os.environ.get("USER", "default_user")
if not args.quiet:
info = runner.info()
print(f"Agent: {info.name}")
@@ -212,12 +230,14 @@ def cmd_run(args: argparse.Namespace) -> int:
     # Format output
     output = {
-        "status": result.status.value if hasattr(result.status, "value") else str(result.status),
-        "completed_steps": result.completed_steps,
-        "results": result.results,
+        "success": result.success,
+        "steps_executed": result.steps_executed,
+        "output": result.output,
     }
-    if result.feedback:
-        output["feedback"] = result.feedback
     if result.error:
         output["error"] = result.error
+    if result.paused_at:
+        output["paused_at"] = result.paused_at
 
     # Output results
     if args.output:
@@ -231,27 +251,51 @@ def cmd_run(args: argparse.Namespace) -> int:
     else:
         print()
         print("=" * 60)
-        status_str = result.status.value if hasattr(result.status, "value") else str(result.status)
+        status_str = "SUCCESS" if result.success else "FAILED"
         print(f"Status: {status_str}")
-        print(f"Completed steps: {len(result.completed_steps)}")
+        print(f"Steps executed: {result.steps_executed}")
         print(f"Path: {' → '.join(result.path)}")
         print("=" * 60)
 
-        if result.status == ExecutionStatus.COMPLETED:
+        if result.success:
             print("\n--- Results ---")
-            for key, value in result.results.items():
-                if isinstance(value, (dict, list)):
-                    print(f"\n{key}:")
-                    value_str = json.dumps(value, indent=2, default=str)
-                    if len(value_str) > 500:
-                        value_str = value_str[:500] + "..."
-                    print(value_str)
-                else:
-                    print(f"{key}: {str(value)[:200]}")
-        elif result.feedback:
-            print(f"\nFeedback: {result.feedback}")
+            # Show only meaningful output keys (skip internal/intermediate values)
+            meaningful_keys = ["final_response", "response", "result", "answer", "output"]
+
+            # Try to find the most relevant output
+            shown = False
+            for key in meaningful_keys:
+                if key in result.output:
+                    value = result.output[key]
+                    if isinstance(value, str) and len(value) > 10:
+                        print(value)
+                        shown = True
+                        break
+                    elif isinstance(value, (dict, list)):
+                        print(json.dumps(value, indent=2, default=str))
+                        shown = True
+                        break
+
+            # If no meaningful key found, show all non-internal keys
+            if not shown:
+                for key, value in result.output.items():
+                    if not key.startswith("_") and key not in ["user_id", "request", "memory_loaded", "user_profile", "recent_context"]:
+                        if isinstance(value, (dict, list)):
+                            print(f"\n{key}:")
+                            value_str = json.dumps(value, indent=2, default=str)
+                            if len(value_str) > 300:
+                                value_str = value_str[:300] + "..."
+                            print(value_str)
+                        else:
+                            val_str = str(value)
+                            if len(val_str) > 200:
+                                val_str = val_str[:200] + "..."
+                            print(f"{key}: {val_str}")
+        elif result.error:
+            print(f"\nError: {result.error}")
 
     runner.cleanup()
-    return 0 if result.status == ExecutionStatus.COMPLETED else 1
+    return 0 if result.success else 1
def cmd_info(args: argparse.Namespace) -> int:
@@ -760,6 +804,11 @@ def cmd_shell(args: argparse.Namespace) -> int:
             # STARTING FRESH: Merge new input with accumulated session memory
             run_context = {**session_memory, **context}
 
+            # Auto-inject user_id if missing (for personal assistant agents)
+            if "user_id" in entry_input_keys and run_context.get("user_id") is None:
+                import os
+                run_context["user_id"] = os.environ.get("USER", "default_user")
+
             # Add conversation history to context if agent expects it
             if conversation_history:
                 run_context["_conversation_history"] = conversation_history.copy()
@@ -778,16 +827,25 @@ def cmd_shell(args: argparse.Namespace) -> int:
print(f"Steps executed: {result.steps_executed}")
print(f"Path: {''.join(result.path)}")
# Show clean output - prioritize meaningful keys
if result.output:
print("\nOutput:")
for key, value in result.output.items():
if isinstance(value, (dict, list)):
value_str = json.dumps(value, indent=2, default=str)
if len(value_str) > 300:
value_str = value_str[:300] + "..."
print(f" {key}: {value_str}")
else:
print(f" {key}: {str(value)[:200]}")
meaningful_keys = ["final_response", "response", "result", "answer", "output"]
shown = False
for key in meaningful_keys:
if key in result.output:
value = result.output[key]
if isinstance(value, str) and len(value) > 10:
print(f"\n{value}\n")
shown = True
break
if not shown:
print("\nOutput:")
for key, value in result.output.items():
if not key.startswith("_"):
val_str = str(value)[:200]
print(f" {key}: {val_str}")
if result.error:
print(f"\nError: {result.error}")
+118 -45
@@ -65,10 +65,15 @@ class MCPClient:
         self._session = None
         self._read_stream = None
         self._write_stream = None
+        self._stdio_context = None  # Context manager for stdio_client
         self._http_client: httpx.Client | None = None
         self._tools: dict[str, MCPTool] = {}
         self._connected = False
 
+        # Background event loop for persistent STDIO connection
+        self._loop = None
+        self._loop_thread = None
+
     def _run_async(self, coro):
         """
         Run an async coroutine, handling both sync and async contexts.
@@ -79,6 +84,13 @@ class MCPClient:
         Returns:
             Result of the coroutine
         """
+        # If we have a persistent loop (for STDIO), use it
+        if self._loop is not None:
+            import concurrent.futures
+            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+            return future.result()
+
+        # Otherwise, use the standard approach
         try:
             # Try to get the current event loop
             asyncio.get_running_loop()
@@ -129,12 +141,12 @@ class MCPClient:
         self._connected = True
 
     def _connect_stdio(self) -> None:
-        """Connect to MCP server via STDIO transport using MCP SDK."""
+        """Connect to MCP server via STDIO transport using MCP SDK with persistent connection."""
         if not self.config.command:
             raise ValueError("command is required for STDIO transport")
 
         try:
-            # Import MCP SDK
+            import threading
             from mcp import StdioServerParameters
 
             # Create server parameters
@@ -145,10 +157,62 @@ class MCPClient:
                 cwd=self.config.cwd,
             )
 
-            # Store for later use in async context
+            # Store for later use
             self._server_params = server_params
 
-            logger.info(f"Connected to MCP server '{self.config.name}' via STDIO")
+            # Start background event loop for persistent connection
+            loop_started = threading.Event()
+            connection_ready = threading.Event()
+            connection_error = []
+
+            def run_event_loop():
+                """Run event loop in background thread."""
+                self._loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(self._loop)
+                loop_started.set()
+
+                # Initialize persistent connection
+                async def init_connection():
+                    try:
+                        from mcp import ClientSession
+                        from mcp.client.stdio import stdio_client
+
+                        # Create persistent stdio client context
+                        self._stdio_context = stdio_client(server_params)
+                        self._read_stream, self._write_stream = await self._stdio_context.__aenter__()
+
+                        # Create persistent session
+                        self._session = ClientSession(self._read_stream, self._write_stream)
+                        await self._session.__aenter__()
+
+                        # Initialize session
+                        await self._session.initialize()
+
+                        connection_ready.set()
+                    except Exception as e:
+                        connection_error.append(e)
+                        connection_ready.set()
+
+                # Schedule connection initialization
+                self._loop.create_task(init_connection())
+
+                # Run loop forever
+                self._loop.run_forever()
+
+            self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
+            self._loop_thread.start()
+
+            # Wait for loop to start
+            loop_started.wait(timeout=5)
+            if not loop_started.is_set():
+                raise RuntimeError("Event loop failed to start")
+
+            # Wait for connection to be ready
+            connection_ready.wait(timeout=10)
+
+            if connection_error:
+                raise connection_error[0]
+
+            logger.info(f"Connected to MCP server '{self.config.name}' via STDIO (persistent)")
 
         except Exception as e:
             raise RuntimeError(f"Failed to connect to MCP server: {e}")
@@ -196,28 +260,23 @@ class MCPClient:
             raise
 
     async def _list_tools_stdio_async(self) -> list[dict]:
-        """List tools via STDIO protocol using MCP SDK."""
-        from mcp import ClientSession
-        from mcp.client.stdio import stdio_client
+        """List tools via STDIO protocol using persistent session."""
+        if not self._session:
+            raise RuntimeError("STDIO session not initialized")
 
-        async with stdio_client(self._server_params) as (read, write):
-            async with ClientSession(read, write) as session:
-                # Initialize the session
-                await session.initialize()
-
-                # List tools
-                response = await session.list_tools()
-
-                # Convert tools to dict format
-                tools_list = []
-                for tool in response.tools:
-                    tools_list.append({
-                        "name": tool.name,
-                        "description": tool.description,
-                        "inputSchema": tool.inputSchema,
-                    })
-
-                return tools_list
+        # List tools using persistent session
+        response = await self._session.list_tools()
+
+        # Convert tools to dict format
+        tools_list = []
+        for tool in response.tools:
+            tools_list.append({
+                "name": tool.name,
+                "description": tool.description,
+                "inputSchema": tool.inputSchema,
+            })
+
+        return tools_list
 
     def _list_tools_http(self) -> list[dict]:
         """List tools via HTTP protocol."""
@@ -280,31 +339,26 @@ class MCPClient:
         return self._call_tool_http(tool_name, arguments)
 
     async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
-        """Call tool via STDIO protocol using MCP SDK."""
-        from mcp import ClientSession
-        from mcp.client.stdio import stdio_client
+        """Call tool via STDIO protocol using persistent session."""
+        if not self._session:
+            raise RuntimeError("STDIO session not initialized")
 
-        async with stdio_client(self._server_params) as (read, write):
-            async with ClientSession(read, write) as session:
-                # Initialize the session
-                await session.initialize()
-
-                # Call tool
-                result = await session.call_tool(tool_name, arguments=arguments)
-
-                # Extract content
-                if result.content:
-                    # MCP returns content as a list of content items
-                    if len(result.content) > 0:
-                        content_item = result.content[0]
-                        # Check if it's a text content item
-                        if hasattr(content_item, 'text'):
-                            return content_item.text
-                        elif hasattr(content_item, 'data'):
-                            return content_item.data
-                        return result.content
-
-                return None
+        # Call tool using persistent session
+        result = await self._session.call_tool(tool_name, arguments=arguments)
+
+        # Extract content
+        if result.content:
+            # MCP returns content as a list of content items
+            if len(result.content) > 0:
+                content_item = result.content[0]
+                # Check if it's a text content item
+                if hasattr(content_item, 'text'):
+                    return content_item.text
+                elif hasattr(content_item, 'data'):
+                    return content_item.data
+                return result.content
 
+        return None
 
     def _call_tool_http(self, tool_name: str, arguments: dict[str, Any]) -> Any:
         """Call tool via HTTP protocol."""
@@ -336,6 +390,25 @@ class MCPClient:
     def disconnect(self) -> None:
         """Disconnect from the MCP server."""
+        # Clean up persistent STDIO connection
+        if self._loop is not None:
+            # Stop event loop - this will cause context managers to clean up naturally
+            if self._loop and self._loop.is_running():
+                self._loop.call_soon_threadsafe(self._loop.stop)
+
+            # Wait for thread to finish
+            if self._loop_thread and self._loop_thread.is_alive():
+                self._loop_thread.join(timeout=2)
+
+            # Clear references
+            self._session = None
+            self._stdio_context = None
+            self._read_stream = None
+            self._write_stream = None
+            self._loop = None
+            self._loop_thread = None
+
+        # Clean up HTTP client
         if self._http_client:
             self._http_client.close()
             self._http_client = None
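
The rewritten client trades per-call `stdio_client` contexts, which spawned a fresh server process for every request, for one long-lived connection owned by a background event loop. The core pattern, as a minimal self-contained sketch:

```python
import asyncio
import threading

# One long-lived loop in a daemon thread; synchronous callers submit
# coroutines to it instead of opening a new connection per call.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def ping() -> str:
    await asyncio.sleep(0.01)  # stands in for a JSON-RPC round trip
    return "pong"

future = asyncio.run_coroutine_threadsafe(ping(), loop)
print(future.result(timeout=5))  # the sync caller blocks until the coroutine finishes
loop.call_soon_threadsafe(loop.stop)
```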
+18 -2
@@ -195,8 +195,12 @@ class AgentRunner:
             self._storage_path = storage_path
             self._temp_dir = None
         else:
-            self._temp_dir = tempfile.TemporaryDirectory()
-            self._storage_path = Path(self._temp_dir.name) / "runtime"
+            # Use persistent storage in ~/.hive by default
+            home = Path.home()
+            default_storage = home / ".hive" / "storage" / agent_path.name
+            default_storage.mkdir(parents=True, exist_ok=True)
+            self._storage_path = default_storage
+            self._temp_dir = None
 
         # Initialize components
         self._tool_registry = ToolRegistry()
@@ -366,6 +370,18 @@ class AgentRunner:
         # Create runtime
         self._runtime = Runtime(storage_path=self._storage_path)
 
+        # Set up session context for tools (workspace_id, agent_id, session_id)
+        workspace_id = "default"  # Could be derived from storage path
+        agent_id = self.graph.id or "unknown"
+        # Use "current" as a stable session_id for persistent memory
+        session_id = "current"
+
+        self._tool_registry.set_session_context(
+            workspace_id=workspace_id,
+            agent_id=agent_id,
+            session_id=session_id,
+        )
+
         # Create LLM provider (if not mock mode and API key available)
         if not self.mock_mode and os.environ.get("ANTHROPIC_API_KEY"):
             from framework.llm.anthropic import AnthropicProvider
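
With this change an agent's runtime state lands in a stable per-agent directory instead of a `TemporaryDirectory` that vanished on exit, which is what lets the `session_id="current"` memory persist across runs. A sketch of the path rule (agent name illustrative):

```python
from pathlib import Path

# Per-agent persistent storage, following the default introduced above.
agent_name = "my-agent"  # illustrative; the runner uses agent_path.name
storage = Path.home() / ".hive" / "storage" / agent_name
storage.mkdir(parents=True, exist_ok=True)
print(storage)  # e.g. /home/user/.hive/storage/my-agent
```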
+15 -3
@@ -35,6 +35,7 @@ class ToolRegistry:
     def __init__(self):
         self._tools: dict[str, RegisteredTool] = {}
         self._mcp_clients: list[Any] = []  # List of MCPClient instances
+        self._session_context: dict[str, Any] = {}  # Auto-injected context for tools
 
     def register(
         self,
@@ -227,6 +228,15 @@ class ToolRegistry:
"""Check if a tool is registered."""
return name in self._tools
def set_session_context(self, **context) -> None:
"""
Set session context to auto-inject into tool calls.
Args:
**context: Key-value pairs to inject (e.g., workspace_id, agent_id, session_id)
"""
self._session_context.update(context)
def register_mcp_server(
self,
server_config: dict[str, Any],
@@ -279,10 +289,12 @@ class ToolRegistry:
                 tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
 
                 # Create executor that calls the MCP server
-                def make_mcp_executor(client_ref: MCPClient, tool_name: str):
+                def make_mcp_executor(client_ref: MCPClient, tool_name: str, registry_ref):
                     def executor(inputs: dict) -> Any:
                         try:
-                            result = client_ref.call_tool(tool_name, inputs)
+                            # Inject session context for tools that need it
+                            merged_inputs = {**registry_ref._session_context, **inputs}
+                            result = client_ref.call_tool(tool_name, merged_inputs)
                             # MCP tools return content array, extract the result
                             if isinstance(result, list) and len(result) > 0:
                                 if isinstance(result[0], dict) and "text" in result[0]:
@@ -298,7 +310,7 @@ class ToolRegistry:
                 self.register(
                     mcp_tool.name,
                     tool,
-                    make_mcp_executor(client, mcp_tool.name),
+                    make_mcp_executor(client, mcp_tool.name, self),
                 )
                 count += 1
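
The executor change means every MCP tool call now sees the ambient session context, while explicit arguments still win because they are merged last. A standalone sketch of that precedence rule (names illustrative, not the registry's API):

```python
from typing import Any, Callable

def make_executor(call_tool: Callable[[str, dict], Any], name: str, session_context: dict):
    def executor(inputs: dict) -> Any:
        # Session context first, caller inputs last: explicit arguments override.
        merged = {**session_context, **inputs}
        return call_tool(name, merged)
    return executor

echo = make_executor(lambda tool, args: args, "echo",
                     {"workspace_id": "default", "session_id": "current"})
print(echo({"query": "hi", "session_id": "override-wins"}))
# {'workspace_id': 'default', 'session_id': 'override-wins', 'query': 'hi'}
```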