feat(cli): basic session support, mcp integration issues
This commit is contained in:
+3
-1
@@ -63,4 +63,6 @@ __pycache__/
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
exports/*
|
||||
exports/*
|
||||
|
||||
core/.agent-builder-sessions/*
|
||||
@@ -1,9 +1,9 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"agent-builder": {
|
||||
"command": "python",
|
||||
"args": ["-m", "framework.mcp.agent_builder_server"],
|
||||
"cwd": "/home/timothy/oss/hive/core"
|
||||
"command": "bash",
|
||||
"args": ["-c", "cd core && exec python -m framework.mcp.agent_builder_server"],
|
||||
"cwd": "core"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,19 @@ Environment Variables:
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Suppress FastMCP banner in STDIO mode
|
||||
if "--stdio" in sys.argv:
|
||||
# Monkey-patch rich Console to redirect to stderr
|
||||
import rich.console
|
||||
_original_console_init = rich.console.Console.__init__
|
||||
|
||||
def _patched_console_init(self, *args, **kwargs):
|
||||
kwargs['file'] = sys.stderr # Force all rich output to stderr
|
||||
_original_console_init(self, *args, **kwargs)
|
||||
|
||||
rich.console.Console.__init__ = _patched_console_init
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from starlette.requests import Request
|
||||
@@ -31,7 +44,9 @@ mcp = FastMCP("aden-tools")
|
||||
from aden_tools.tools import register_all_tools
|
||||
|
||||
tools = register_all_tools(mcp)
|
||||
print(f"[MCP] Registered {len(tools)} tools: {tools}")
|
||||
# Only print to stdout in HTTP mode (STDIO mode requires clean stdout for JSON-RPC)
|
||||
if "--stdio" not in sys.argv:
|
||||
print(f"[MCP] Registered {len(tools)} tools: {tools}")
|
||||
|
||||
|
||||
@mcp.custom_route("/health", methods=["GET"])
|
||||
@@ -68,7 +83,7 @@ def main() -> None:
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.stdio:
|
||||
print("[MCP] Starting with STDIO transport")
|
||||
# STDIO mode: only JSON-RPC messages go to stdout
|
||||
mcp.run(transport="stdio")
|
||||
else:
|
||||
print(f"[MCP] Starting HTTP server on {args.host}:{args.port}")
|
||||
|
||||
@@ -165,12 +165,7 @@ class GraphExecutor:
|
||||
|
||||
path.append(current_node_id)
|
||||
|
||||
# Check if terminal
|
||||
if current_node_id in graph.terminal_nodes:
|
||||
self.logger.info(f"✓ Reached terminal node: {node_spec.name}")
|
||||
break
|
||||
|
||||
# Check if pause (HITL)
|
||||
# Check if pause (HITL) before execution
|
||||
if current_node_id in graph.pause_nodes:
|
||||
self.logger.info(f"⏸ Paused at HITL node: {node_spec.name}")
|
||||
# Execute this node, then pause
|
||||
@@ -279,6 +274,11 @@ class GraphExecutor:
|
||||
session_state=session_state_out,
|
||||
)
|
||||
|
||||
# Check if this is a terminal node - if so, we're done
|
||||
if node_spec.id in graph.terminal_nodes:
|
||||
self.logger.info(f"✓ Reached terminal node: {node_spec.name}")
|
||||
break
|
||||
|
||||
# Determine next node
|
||||
if result.next_node:
|
||||
# Router explicitly set next node
|
||||
|
||||
@@ -431,22 +431,13 @@ class LLMNode(NodeProtocol):
|
||||
# Write to output keys
|
||||
output = self._parse_output(response.content, ctx.node_spec)
|
||||
|
||||
# For llm_generate nodes, try to parse JSON and extract fields
|
||||
if ctx.node_spec.node_type == "llm_generate" and len(ctx.node_spec.output_keys) > 1:
|
||||
# For llm_generate and llm_tool_use nodes, try to parse JSON and extract fields
|
||||
if ctx.node_spec.node_type in ("llm_generate", "llm_tool_use") and len(ctx.node_spec.output_keys) > 1:
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
import json
|
||||
import re
|
||||
|
||||
# Remove markdown code blocks if present
|
||||
content = response.content.strip()
|
||||
if content.startswith("```"):
|
||||
# Extract JSON from code block
|
||||
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
|
||||
parsed = json.loads(content)
|
||||
# Try direct JSON parse first
|
||||
parsed = self._extract_json_with_haiku(response.content, ctx.node_spec.output_keys)
|
||||
|
||||
# If parsed successfully, write each field to its corresponding output key
|
||||
if isinstance(parsed, dict):
|
||||
@@ -454,8 +445,12 @@ class LLMNode(NodeProtocol):
|
||||
if key in parsed:
|
||||
ctx.memory.write(key, parsed[key])
|
||||
output[key] = parsed[key]
|
||||
elif key in ctx.input_data:
|
||||
# Key not in parsed JSON but exists in input - pass through input value
|
||||
ctx.memory.write(key, ctx.input_data[key])
|
||||
output[key] = ctx.input_data[key]
|
||||
else:
|
||||
# Key not in parsed JSON, write the whole response
|
||||
# Key not in parsed JSON or input, write the whole response
|
||||
ctx.memory.write(key, response.content)
|
||||
output[key] = response.content
|
||||
else:
|
||||
@@ -465,8 +460,8 @@ class LLMNode(NodeProtocol):
|
||||
output[key] = response.content
|
||||
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
# JSON parsing failed, fall back to writing entire response
|
||||
logger.warning(f" ⚠ Failed to parse JSON output, using raw response: {e}")
|
||||
# JSON extraction failed completely
|
||||
logger.warning(f" ⚠ Failed to extract JSON output: {e}")
|
||||
for key in ctx.node_spec.output_keys:
|
||||
ctx.memory.write(key, response.content)
|
||||
output[key] = response.content
|
||||
@@ -503,6 +498,80 @@ class LLMNode(NodeProtocol):
|
||||
# Default output
|
||||
return {"result": content}
|
||||
|
||||
def _extract_json_with_haiku(self, raw_response: str, output_keys: list[str]) -> dict[str, Any]:
|
||||
"""Use Haiku to extract clean JSON from potentially verbose LLM response."""
|
||||
import json
|
||||
import re
|
||||
|
||||
# Try direct JSON parse first (fast path)
|
||||
try:
|
||||
content = raw_response.strip()
|
||||
# Remove markdown code blocks if present
|
||||
if content.startswith("```"):
|
||||
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', content, re.DOTALL)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# JSON parse failed - use Haiku to extract clean JSON
|
||||
import os
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
# No API key, try one more simple extraction
|
||||
try:
|
||||
# Find first { and last }
|
||||
start = raw_response.find('{')
|
||||
end = raw_response.rfind('}')
|
||||
if start != -1 and end != -1:
|
||||
json_str = raw_response[start:end+1]
|
||||
return json.loads(json_str)
|
||||
except:
|
||||
pass
|
||||
raise ValueError("Cannot parse JSON and no API key for Haiku cleanup")
|
||||
|
||||
# Use Haiku to clean the response
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
haiku = AnthropicProvider(model="claude-3-5-haiku-20241022")
|
||||
|
||||
prompt = f"""Extract the JSON object from this LLM response. Extract ONLY the values that the LLM actually generated.
|
||||
|
||||
Expected output keys: {output_keys}
|
||||
|
||||
LLM Response:
|
||||
{raw_response}
|
||||
|
||||
IMPORTANT:
|
||||
- Only extract keys that the LLM explicitly output in its response
|
||||
- Do NOT include keys that were just mentioned or passed through from input
|
||||
- If the LLM output multiple pieces of text/JSON, extract the LAST JSON object only
|
||||
- Output ONLY valid JSON with no extra text, no markdown, no explanations"""
|
||||
|
||||
try:
|
||||
result = haiku.complete(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
system="You extract clean JSON from messy responses. Output only valid JSON, nothing else.",
|
||||
)
|
||||
|
||||
cleaned = result.content.strip()
|
||||
# Remove markdown if Haiku added it
|
||||
if cleaned.startswith("```"):
|
||||
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, re.DOTALL)
|
||||
if match:
|
||||
cleaned = match.group(1).strip()
|
||||
|
||||
parsed = json.loads(cleaned)
|
||||
logger.info(f" ✓ Haiku cleaned JSON output")
|
||||
return parsed
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f" ⚠ Haiku JSON extraction failed: {e}")
|
||||
raise
|
||||
|
||||
def _build_messages(self, ctx: NodeContext) -> list[dict]:
|
||||
"""Build the message list for the LLM."""
|
||||
# Use Haiku to intelligently format inputs from memory
|
||||
|
||||
@@ -6,8 +6,6 @@ import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from framework.graph import ExecutionStatus
|
||||
|
||||
|
||||
def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
"""Register runner commands with the main CLI."""
|
||||
@@ -48,6 +46,11 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
action="store_true",
|
||||
help="Only output the final result JSON",
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--verbose", "-v",
|
||||
action="store_true",
|
||||
help="Show detailed execution logs (steps, LLM calls, etc.)",
|
||||
)
|
||||
run_parser.set_defaults(func=cmd_run)
|
||||
|
||||
# info command
|
||||
@@ -166,8 +169,17 @@ def register_commands(subparsers: argparse._SubParsersAction) -> None:
|
||||
|
||||
def cmd_run(args: argparse.Namespace) -> int:
|
||||
"""Run an exported agent."""
|
||||
import logging
|
||||
from framework.runner import AgentRunner
|
||||
|
||||
# Set logging level (quiet by default for cleaner output)
|
||||
if args.quiet:
|
||||
logging.basicConfig(level=logging.ERROR, format='%(message)s')
|
||||
elif getattr(args, 'verbose', False):
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
else:
|
||||
logging.basicConfig(level=logging.WARNING, format='%(message)s')
|
||||
|
||||
# Load input context
|
||||
context = {}
|
||||
if args.input:
|
||||
@@ -195,6 +207,12 @@ def cmd_run(args: argparse.Namespace) -> int:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
# Auto-inject user_id if the agent expects it but it's not provided
|
||||
entry_input_keys = runner.graph.nodes[0].input_keys if runner.graph.nodes else []
|
||||
if "user_id" in entry_input_keys and context.get("user_id") is None:
|
||||
import os
|
||||
context["user_id"] = os.environ.get("USER", "default_user")
|
||||
|
||||
if not args.quiet:
|
||||
info = runner.info()
|
||||
print(f"Agent: {info.name}")
|
||||
@@ -212,12 +230,14 @@ def cmd_run(args: argparse.Namespace) -> int:
|
||||
|
||||
# Format output
|
||||
output = {
|
||||
"status": result.status.value if hasattr(result.status, "value") else str(result.status),
|
||||
"completed_steps": result.completed_steps,
|
||||
"results": result.results,
|
||||
"success": result.success,
|
||||
"steps_executed": result.steps_executed,
|
||||
"output": result.output,
|
||||
}
|
||||
if result.feedback:
|
||||
output["feedback"] = result.feedback
|
||||
if result.error:
|
||||
output["error"] = result.error
|
||||
if result.paused_at:
|
||||
output["paused_at"] = result.paused_at
|
||||
|
||||
# Output results
|
||||
if args.output:
|
||||
@@ -231,27 +251,51 @@ def cmd_run(args: argparse.Namespace) -> int:
|
||||
else:
|
||||
print()
|
||||
print("=" * 60)
|
||||
status_str = result.status.value if hasattr(result.status, "value") else str(result.status)
|
||||
status_str = "SUCCESS" if result.success else "FAILED"
|
||||
print(f"Status: {status_str}")
|
||||
print(f"Completed steps: {len(result.completed_steps)}")
|
||||
print(f"Steps executed: {result.steps_executed}")
|
||||
print(f"Path: {' → '.join(result.path)}")
|
||||
print("=" * 60)
|
||||
|
||||
if result.status == ExecutionStatus.COMPLETED:
|
||||
if result.success:
|
||||
print("\n--- Results ---")
|
||||
for key, value in result.results.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
print(f"\n{key}:")
|
||||
value_str = json.dumps(value, indent=2, default=str)
|
||||
if len(value_str) > 500:
|
||||
value_str = value_str[:500] + "..."
|
||||
print(value_str)
|
||||
else:
|
||||
print(f"{key}: {str(value)[:200]}")
|
||||
elif result.feedback:
|
||||
print(f"\nFeedback: {result.feedback}")
|
||||
# Show only meaningful output keys (skip internal/intermediate values)
|
||||
meaningful_keys = ["final_response", "response", "result", "answer", "output"]
|
||||
|
||||
# Try to find the most relevant output
|
||||
shown = False
|
||||
for key in meaningful_keys:
|
||||
if key in result.output:
|
||||
value = result.output[key]
|
||||
if isinstance(value, str) and len(value) > 10:
|
||||
print(value)
|
||||
shown = True
|
||||
break
|
||||
elif isinstance(value, (dict, list)):
|
||||
print(json.dumps(value, indent=2, default=str))
|
||||
shown = True
|
||||
break
|
||||
|
||||
# If no meaningful key found, show all non-internal keys
|
||||
if not shown:
|
||||
for key, value in result.output.items():
|
||||
if not key.startswith("_") and key not in ["user_id", "request", "memory_loaded", "user_profile", "recent_context"]:
|
||||
if isinstance(value, (dict, list)):
|
||||
print(f"\n{key}:")
|
||||
value_str = json.dumps(value, indent=2, default=str)
|
||||
if len(value_str) > 300:
|
||||
value_str = value_str[:300] + "..."
|
||||
print(value_str)
|
||||
else:
|
||||
val_str = str(value)
|
||||
if len(val_str) > 200:
|
||||
val_str = val_str[:200] + "..."
|
||||
print(f"{key}: {val_str}")
|
||||
elif result.error:
|
||||
print(f"\nError: {result.error}")
|
||||
|
||||
runner.cleanup()
|
||||
return 0 if result.status == ExecutionStatus.COMPLETED else 1
|
||||
return 0 if result.success else 1
|
||||
|
||||
|
||||
def cmd_info(args: argparse.Namespace) -> int:
|
||||
@@ -760,6 +804,11 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
# STARTING FRESH: Merge new input with accumulated session memory
|
||||
run_context = {**session_memory, **context}
|
||||
|
||||
# Auto-inject user_id if missing (for personal assistant agents)
|
||||
if "user_id" in entry_input_keys and run_context.get("user_id") is None:
|
||||
import os
|
||||
run_context["user_id"] = os.environ.get("USER", "default_user")
|
||||
|
||||
# Add conversation history to context if agent expects it
|
||||
if conversation_history:
|
||||
run_context["_conversation_history"] = conversation_history.copy()
|
||||
@@ -778,16 +827,25 @@ def cmd_shell(args: argparse.Namespace) -> int:
|
||||
print(f"Steps executed: {result.steps_executed}")
|
||||
print(f"Path: {' → '.join(result.path)}")
|
||||
|
||||
# Show clean output - prioritize meaningful keys
|
||||
if result.output:
|
||||
print("\nOutput:")
|
||||
for key, value in result.output.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
value_str = json.dumps(value, indent=2, default=str)
|
||||
if len(value_str) > 300:
|
||||
value_str = value_str[:300] + "..."
|
||||
print(f" {key}: {value_str}")
|
||||
else:
|
||||
print(f" {key}: {str(value)[:200]}")
|
||||
meaningful_keys = ["final_response", "response", "result", "answer", "output"]
|
||||
shown = False
|
||||
|
||||
for key in meaningful_keys:
|
||||
if key in result.output:
|
||||
value = result.output[key]
|
||||
if isinstance(value, str) and len(value) > 10:
|
||||
print(f"\n{value}\n")
|
||||
shown = True
|
||||
break
|
||||
|
||||
if not shown:
|
||||
print("\nOutput:")
|
||||
for key, value in result.output.items():
|
||||
if not key.startswith("_"):
|
||||
val_str = str(value)[:200]
|
||||
print(f" {key}: {val_str}")
|
||||
|
||||
if result.error:
|
||||
print(f"\nError: {result.error}")
|
||||
|
||||
@@ -65,10 +65,15 @@ class MCPClient:
|
||||
self._session = None
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._stdio_context = None # Context manager for stdio_client
|
||||
self._http_client: httpx.Client | None = None
|
||||
self._tools: dict[str, MCPTool] = {}
|
||||
self._connected = False
|
||||
|
||||
# Background event loop for persistent STDIO connection
|
||||
self._loop = None
|
||||
self._loop_thread = None
|
||||
|
||||
def _run_async(self, coro):
|
||||
"""
|
||||
Run an async coroutine, handling both sync and async contexts.
|
||||
@@ -79,6 +84,13 @@ class MCPClient:
|
||||
Returns:
|
||||
Result of the coroutine
|
||||
"""
|
||||
# If we have a persistent loop (for STDIO), use it
|
||||
if self._loop is not None:
|
||||
import concurrent.futures
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result()
|
||||
|
||||
# Otherwise, use the standard approach
|
||||
try:
|
||||
# Try to get the current event loop
|
||||
asyncio.get_running_loop()
|
||||
@@ -129,12 +141,12 @@ class MCPClient:
|
||||
self._connected = True
|
||||
|
||||
def _connect_stdio(self) -> None:
|
||||
"""Connect to MCP server via STDIO transport using MCP SDK."""
|
||||
"""Connect to MCP server via STDIO transport using MCP SDK with persistent connection."""
|
||||
if not self.config.command:
|
||||
raise ValueError("command is required for STDIO transport")
|
||||
|
||||
try:
|
||||
# Import MCP SDK
|
||||
import threading
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
# Create server parameters
|
||||
@@ -145,10 +157,62 @@ class MCPClient:
|
||||
cwd=self.config.cwd,
|
||||
)
|
||||
|
||||
# Store for later use in async context
|
||||
# Store for later use
|
||||
self._server_params = server_params
|
||||
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via STDIO")
|
||||
# Start background event loop for persistent connection
|
||||
loop_started = threading.Event()
|
||||
connection_ready = threading.Event()
|
||||
connection_error = []
|
||||
|
||||
def run_event_loop():
|
||||
"""Run event loop in background thread."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
loop_started.set()
|
||||
|
||||
# Initialize persistent connection
|
||||
async def init_connection():
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
# Create persistent stdio client context
|
||||
self._stdio_context = stdio_client(server_params)
|
||||
self._read_stream, self._write_stream = await self._stdio_context.__aenter__()
|
||||
|
||||
# Create persistent session
|
||||
self._session = ClientSession(self._read_stream, self._write_stream)
|
||||
await self._session.__aenter__()
|
||||
|
||||
# Initialize session
|
||||
await self._session.initialize()
|
||||
|
||||
connection_ready.set()
|
||||
except Exception as e:
|
||||
connection_error.append(e)
|
||||
connection_ready.set()
|
||||
|
||||
# Schedule connection initialization
|
||||
self._loop.create_task(init_connection())
|
||||
|
||||
# Run loop forever
|
||||
self._loop.run_forever()
|
||||
|
||||
self._loop_thread = threading.Thread(target=run_event_loop, daemon=True)
|
||||
self._loop_thread.start()
|
||||
|
||||
# Wait for loop to start
|
||||
loop_started.wait(timeout=5)
|
||||
if not loop_started.is_set():
|
||||
raise RuntimeError("Event loop failed to start")
|
||||
|
||||
# Wait for connection to be ready
|
||||
connection_ready.wait(timeout=10)
|
||||
if connection_error:
|
||||
raise connection_error[0]
|
||||
|
||||
logger.info(f"Connected to MCP server '{self.config.name}' via STDIO (persistent)")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
@@ -196,28 +260,23 @@ class MCPClient:
|
||||
raise
|
||||
|
||||
async def _list_tools_stdio_async(self) -> list[dict]:
|
||||
"""List tools via STDIO protocol using MCP SDK."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import stdio_client
|
||||
"""List tools via STDIO protocol using persistent session."""
|
||||
if not self._session:
|
||||
raise RuntimeError("STDIO session not initialized")
|
||||
|
||||
async with stdio_client(self._server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the session
|
||||
await session.initialize()
|
||||
# List tools using persistent session
|
||||
response = await self._session.list_tools()
|
||||
|
||||
# List tools
|
||||
response = await session.list_tools()
|
||||
# Convert tools to dict format
|
||||
tools_list = []
|
||||
for tool in response.tools:
|
||||
tools_list.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema,
|
||||
})
|
||||
|
||||
# Convert tools to dict format
|
||||
tools_list = []
|
||||
for tool in response.tools:
|
||||
tools_list.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": tool.inputSchema,
|
||||
})
|
||||
|
||||
return tools_list
|
||||
return tools_list
|
||||
|
||||
def _list_tools_http(self) -> list[dict]:
|
||||
"""List tools via HTTP protocol."""
|
||||
@@ -280,31 +339,26 @@ class MCPClient:
|
||||
return self._call_tool_http(tool_name, arguments)
|
||||
|
||||
async def _call_tool_stdio_async(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call tool via STDIO protocol using MCP SDK."""
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import stdio_client
|
||||
"""Call tool via STDIO protocol using persistent session."""
|
||||
if not self._session:
|
||||
raise RuntimeError("STDIO session not initialized")
|
||||
|
||||
async with stdio_client(self._server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the session
|
||||
await session.initialize()
|
||||
# Call tool using persistent session
|
||||
result = await self._session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
# Call tool
|
||||
result = await session.call_tool(tool_name, arguments=arguments)
|
||||
# Extract content
|
||||
if result.content:
|
||||
# MCP returns content as a list of content items
|
||||
if len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
# Check if it's a text content item
|
||||
if hasattr(content_item, 'text'):
|
||||
return content_item.text
|
||||
elif hasattr(content_item, 'data'):
|
||||
return content_item.data
|
||||
return result.content
|
||||
|
||||
# Extract content
|
||||
if result.content:
|
||||
# MCP returns content as a list of content items
|
||||
if len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
# Check if it's a text content item
|
||||
if hasattr(content_item, 'text'):
|
||||
return content_item.text
|
||||
elif hasattr(content_item, 'data'):
|
||||
return content_item.data
|
||||
return result.content
|
||||
|
||||
return None
|
||||
return None
|
||||
|
||||
def _call_tool_http(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call tool via HTTP protocol."""
|
||||
@@ -336,6 +390,25 @@ class MCPClient:
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the MCP server."""
|
||||
# Clean up persistent STDIO connection
|
||||
if self._loop is not None:
|
||||
# Stop event loop - this will cause context managers to clean up naturally
|
||||
if self._loop and self._loop.is_running():
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
|
||||
# Wait for thread to finish
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
self._loop_thread.join(timeout=2)
|
||||
|
||||
# Clear references
|
||||
self._session = None
|
||||
self._stdio_context = None
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
self._loop = None
|
||||
self._loop_thread = None
|
||||
|
||||
# Clean up HTTP client
|
||||
if self._http_client:
|
||||
self._http_client.close()
|
||||
self._http_client = None
|
||||
|
||||
@@ -195,8 +195,12 @@ class AgentRunner:
|
||||
self._storage_path = storage_path
|
||||
self._temp_dir = None
|
||||
else:
|
||||
self._temp_dir = tempfile.TemporaryDirectory()
|
||||
self._storage_path = Path(self._temp_dir.name) / "runtime"
|
||||
# Use persistent storage in ~/.hive by default
|
||||
home = Path.home()
|
||||
default_storage = home / ".hive" / "storage" / agent_path.name
|
||||
default_storage.mkdir(parents=True, exist_ok=True)
|
||||
self._storage_path = default_storage
|
||||
self._temp_dir = None
|
||||
|
||||
# Initialize components
|
||||
self._tool_registry = ToolRegistry()
|
||||
@@ -366,6 +370,18 @@ class AgentRunner:
|
||||
# Create runtime
|
||||
self._runtime = Runtime(storage_path=self._storage_path)
|
||||
|
||||
# Set up session context for tools (workspace_id, agent_id, session_id)
|
||||
workspace_id = "default" # Could be derived from storage path
|
||||
agent_id = self.graph.id or "unknown"
|
||||
# Use "current" as a stable session_id for persistent memory
|
||||
session_id = "current"
|
||||
|
||||
self._tool_registry.set_session_context(
|
||||
workspace_id=workspace_id,
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Create LLM provider (if not mock mode and API key available)
|
||||
if not self.mock_mode and os.environ.get("ANTHROPIC_API_KEY"):
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
|
||||
@@ -35,6 +35,7 @@ class ToolRegistry:
|
||||
def __init__(self):
|
||||
self._tools: dict[str, RegisteredTool] = {}
|
||||
self._mcp_clients: list[Any] = [] # List of MCPClient instances
|
||||
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -227,6 +228,15 @@ class ToolRegistry:
|
||||
"""Check if a tool is registered."""
|
||||
return name in self._tools
|
||||
|
||||
def set_session_context(self, **context) -> None:
|
||||
"""
|
||||
Set session context to auto-inject into tool calls.
|
||||
|
||||
Args:
|
||||
**context: Key-value pairs to inject (e.g., workspace_id, agent_id, session_id)
|
||||
"""
|
||||
self._session_context.update(context)
|
||||
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
@@ -279,10 +289,12 @@ class ToolRegistry:
|
||||
tool = self._convert_mcp_tool_to_framework_tool(mcp_tool)
|
||||
|
||||
# Create executor that calls the MCP server
|
||||
def make_mcp_executor(client_ref: MCPClient, tool_name: str):
|
||||
def make_mcp_executor(client_ref: MCPClient, tool_name: str, registry_ref):
|
||||
def executor(inputs: dict) -> Any:
|
||||
try:
|
||||
result = client_ref.call_tool(tool_name, inputs)
|
||||
# Inject session context for tools that need it
|
||||
merged_inputs = {**registry_ref._session_context, **inputs}
|
||||
result = client_ref.call_tool(tool_name, merged_inputs)
|
||||
# MCP tools return content array, extract the result
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
if isinstance(result[0], dict) and "text" in result[0]:
|
||||
@@ -298,7 +310,7 @@ class ToolRegistry:
|
||||
self.register(
|
||||
mcp_tool.name,
|
||||
tool,
|
||||
make_mcp_executor(client, mcp_tool.name),
|
||||
make_mcp_executor(client, mcp_tool.name, self),
|
||||
)
|
||||
count += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user