Merge branch 'main' into feat/open-hive-colony

This commit is contained in:
Bryan
2026-04-08 11:50:07 -07:00
13 changed files with 1916 additions and 76 deletions
+1
View File
@@ -79,3 +79,4 @@ core/tests/*dumps/*
screenshots/*
.gemini/*
.coverage
+27 -23
View File
@@ -342,34 +342,38 @@ def _dump_failed_request(
attempt: int,
) -> str:
"""Dump failed request to a file for debugging. Returns the file path."""
FAILED_REQUESTS_DIR.mkdir(parents=True, exist_ok=True)
try:
FAILED_REQUESTS_DIR.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{error_type}_{model.replace('/', '_')}_{timestamp}.json"
filepath = FAILED_REQUESTS_DIR / filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{error_type}_{model.replace('/', '_')}_{timestamp}.json"
filepath = FAILED_REQUESTS_DIR / filename
# Build dump data
messages = kwargs.get("messages", [])
dump_data = {
"timestamp": datetime.now().isoformat(),
"model": model,
"error_type": error_type,
"attempt": attempt,
"estimated_tokens": _estimate_tokens(model, messages),
"num_messages": len(messages),
"messages": messages,
"tools": kwargs.get("tools"),
"max_tokens": kwargs.get("max_tokens"),
"temperature": kwargs.get("temperature"),
}
# Build dump data
messages = kwargs.get("messages", [])
dump_data = {
"timestamp": datetime.now().isoformat(),
"model": model,
"error_type": error_type,
"attempt": attempt,
"estimated_tokens": _estimate_tokens(model, messages),
"num_messages": len(messages),
"messages": messages,
"tools": kwargs.get("tools"),
"max_tokens": kwargs.get("max_tokens"),
"temperature": kwargs.get("temperature"),
}
with open(filepath, "w", encoding="utf-8") as f:
json.dump(dump_data, f, indent=2, default=str)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(dump_data, f, indent=2, default=str)
# Prune old dumps to prevent unbounded disk growth
_prune_failed_request_dumps()
# Prune old dumps to prevent unbounded disk growth
_prune_failed_request_dumps()
return str(filepath)
return str(filepath)
except OSError as e:
logger.warning(f"Failed to dump request debug log to {FAILED_REQUESTS_DIR}: {e}")
return "log_write_failed"
def _compute_retry_delay(
+927
View File
@@ -0,0 +1,927 @@
"""Tests for EventBus pub/sub event system.
Validates subscription management, event publishing, filtering,
concurrency handling, history operations, and convenience publishers.
"""
from __future__ import annotations
import asyncio
from datetime import datetime
import pytest
from framework.runtime.event_bus import (
AgentEvent,
EventBus,
EventType,
)
# ---------------------------------------------------------------------------
# AgentEvent dataclass tests
# ---------------------------------------------------------------------------
class TestAgentEvent:
    """Behavioural checks for the AgentEvent dataclass."""

    def test_minimal_construction(self):
        """Only type and stream_id are required; everything else defaults."""
        evt = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test_stream")
        assert evt.type == EventType.EXECUTION_STARTED
        assert evt.stream_id == "test_stream"
        assert evt.node_id is None
        assert evt.execution_id is None
        assert evt.data == {}
        assert evt.correlation_id is None

    def test_full_construction(self):
        """Every explicitly supplied field is stored verbatim."""
        evt = AgentEvent(
            type=EventType.TOOL_CALL_COMPLETED,
            stream_id="stream_1",
            node_id="node_1",
            execution_id="exec_123",
            data={"result": "success"},
            correlation_id="corr_456",
        )
        assert evt.type == EventType.TOOL_CALL_COMPLETED
        assert evt.stream_id == "stream_1"
        assert evt.node_id == "node_1"
        assert evt.execution_id == "exec_123"
        assert evt.data == {"result": "success"}
        assert evt.correlation_id == "corr_456"

    def test_timestamp_auto_generated(self):
        """A timestamp is filled in automatically when omitted."""
        t0 = datetime.now()
        evt = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test")
        t1 = datetime.now()
        # The auto-generated timestamp must fall inside the construction window.
        assert t0 <= evt.timestamp <= t1

    def test_to_dict_serialization(self):
        """to_dict() produces a plain dictionary with the enum's string value."""
        evt = AgentEvent(
            type=EventType.EXECUTION_COMPLETED,
            stream_id="stream_1",
            node_id="node_1",
            execution_id="exec_1",
            data={"output": "result"},
            correlation_id="corr_1",
            graph_id="graph_1",
        )
        payload = evt.to_dict()
        assert payload["type"] == "execution_completed"
        assert payload["stream_id"] == "stream_1"

    def test_to_dict_includes_run_id(self):
        """A set run_id shows up in the serialized form."""
        evt = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="s1",
            run_id="run-abc",
        )
        payload = evt.to_dict()
        assert payload["run_id"] == "run-abc"

    def test_to_dict_omits_run_id_when_none(self):
        """An unset run_id is left out of the serialized form entirely."""
        evt = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="s1",
        )
        payload = evt.to_dict()
        assert "run_id" not in payload
# ---------------------------------------------------------------------------
# Subscription management tests
# ---------------------------------------------------------------------------
class TestSubscriptionManagement:
    """Tests for subscribe/unsubscribe operations."""
    def test_subscribe_returns_id(self):
        """subscribe() returns a subscription ID."""
        bus = EventBus()
        async def handler(event: AgentEvent) -> None:
            pass
        sub_id = bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        # Subscription IDs carry a "sub_" prefix; the numeric suffix is
        # pinned by the test below.
        assert sub_id.startswith("sub_")
    def test_subscribe_increments_id(self):
        """Each subscription gets a unique incremented ID."""
        bus = EventBus()
        async def handler(event: AgentEvent) -> None:
            pass
        id1 = bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        id2 = bus.subscribe(event_types=[EventType.EXECUTION_COMPLETED], handler=handler)
        id3 = bus.subscribe(event_types=[EventType.EXECUTION_FAILED], handler=handler)
        # The counter starts at 1 and advances once per subscribe() call,
        # independent of the event type subscribed to.
        assert id1 == "sub_1"
        assert id2 == "sub_2"
        assert id3 == "sub_3"
    def test_unsubscribe_removes_subscription(self):
        """unsubscribe() removes the subscription."""
        bus = EventBus()
        async def handler(event: AgentEvent) -> None:
            pass
        sub_id = bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        # NOTE(review): reaches into the private _subscriptions mapping;
        # brittle if EventBus internals change.
        assert sub_id in bus._subscriptions
        result = bus.unsubscribe(sub_id)
        assert result is True
        assert sub_id not in bus._subscriptions
    def test_unsubscribe_nonexistent_returns_false(self):
        """unsubscribe() returns False for non-existent subscription."""
        bus = EventBus()
        result = bus.unsubscribe("sub_nonexistent")
        assert result is False
    def test_multiple_subscriptions_same_event_type(self):
        """Multiple handlers can subscribe to the same event type."""
        bus = EventBus()
        received = []
        async def handler1(event: AgentEvent) -> None:
            received.append("handler1")
        async def handler2(event: AgentEvent) -> None:
            received.append("handler2")
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler1)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler2)
        # Both registrations are kept as distinct subscription entries.
        assert len(bus._subscriptions) == 2
# ---------------------------------------------------------------------------
# Event publishing tests
# ---------------------------------------------------------------------------
class TestEventPublishing:
    """Tests for publish() and event delivery."""
    @pytest.mark.asyncio
    async def test_publish_delivers_to_subscriber(self):
        """Published events are delivered to matching subscribers."""
        bus = EventBus()
        received_events = []
        async def handler(event: AgentEvent) -> None:
            received_events.append(event)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        event = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test")
        await bus.publish(event)
        # The handler receives the published event itself (compares equal).
        assert len(received_events) == 1
        assert received_events[0] == event
    @pytest.mark.asyncio
    async def test_publish_to_multiple_subscribers(self):
        """Event is delivered to all matching subscribers."""
        bus = EventBus()
        received = []
        async def handler1(event: AgentEvent) -> None:
            received.append("h1")
        async def handler2(event: AgentEvent) -> None:
            received.append("h2")
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler1)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler2)
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test"))
        # Membership, not order, is asserted: handlers may run concurrently,
        # so delivery order is not guaranteed.
        assert "h1" in received
        assert "h2" in received
    @pytest.mark.asyncio
    async def test_publish_non_matching_type_not_delivered(self):
        """Events with non-matching types are not delivered."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="test"))
        assert len(received) == 0
    @pytest.mark.asyncio
    async def test_publish_adds_to_history(self):
        """Published events are added to history."""
        bus = EventBus()
        event = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test")
        await bus.publish(event)
        # History records the event even with no subscribers attached.
        history = bus.get_history()
        assert len(history) == 1
        assert history[0] == event
# ---------------------------------------------------------------------------
# Filter tests
# ---------------------------------------------------------------------------
class TestEventFiltering:
    """Tests for subscription filters.

    Each filter keyword (stream/node/execution/graph) restricts delivery to
    events whose corresponding field matches; combined filters are AND-ed.
    """
    @pytest.mark.asyncio
    async def test_filter_by_stream(self):
        """filter_stream only receives events from that stream."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event.stream_id)
        bus.subscribe(
            event_types=[EventType.EXECUTION_STARTED],
            handler=handler,
            filter_stream="stream_a",
        )
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_a"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_b"))
        assert received == ["stream_a"]
    @pytest.mark.asyncio
    async def test_filter_by_node(self):
        """filter_node only receives events from that node."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event.node_id)
        bus.subscribe(
            event_types=[EventType.NODE_LOOP_STARTED],
            handler=handler,
            filter_node="node_x",
        )
        await bus.publish(
            AgentEvent(type=EventType.NODE_LOOP_STARTED, stream_id="s", node_id="node_x")
        )
        await bus.publish(
            AgentEvent(type=EventType.NODE_LOOP_STARTED, stream_id="s", node_id="node_y")
        )
        assert received == ["node_x"]
    @pytest.mark.asyncio
    async def test_filter_by_execution(self):
        """filter_execution only receives events from that execution."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event.execution_id)
        bus.subscribe(
            event_types=[EventType.EXECUTION_COMPLETED],
            handler=handler,
            filter_execution="exec_1",
        )
        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s", execution_id="exec_1")
        )
        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s", execution_id="exec_2")
        )
        assert received == ["exec_1"]
    @pytest.mark.asyncio
    async def test_combined_filters(self):
        """Multiple filters are AND-ed together."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(True)
        bus.subscribe(
            event_types=[EventType.TOOL_CALL_COMPLETED],
            handler=handler,
            filter_stream="stream_1",
            filter_node="node_1",
        )
        # Matches both filters
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_COMPLETED,
                stream_id="stream_1",
                node_id="node_1",
            )
        )
        # Matches stream but not node
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_COMPLETED,
                stream_id="stream_1",
                node_id="node_2",
            )
        )
        # Matches node but not stream
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_COMPLETED,
                stream_id="stream_2",
                node_id="node_1",
            )
        )
        # Only the first publish satisfies every filter simultaneously.
        assert len(received) == 1
    @pytest.mark.asyncio
    async def test_filter_by_graph(self):
        """filter_graph only receives events from that graph."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event.graph_id)
        bus.subscribe(
            event_types=[EventType.EXECUTION_STARTED],
            handler=handler,
            filter_graph="graph_a",
        )
        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s", graph_id="graph_a")
        )
        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s", graph_id="graph_b")
        )
        assert received == ["graph_a"]
# ---------------------------------------------------------------------------
# Concurrency tests
# ---------------------------------------------------------------------------
class TestConcurrency:
    """Tests for concurrent handler execution."""
    @pytest.mark.asyncio
    async def test_handler_error_doesnt_crash_others(self):
        """One handler's error doesn't prevent other handlers from running."""
        bus = EventBus()
        results = []
        async def failing_handler(event: AgentEvent) -> None:
            raise ValueError("Handler error!")
        async def working_handler(event: AgentEvent) -> None:
            results.append("success")
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=failing_handler)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=working_handler)
        # publish() must also not re-raise the failing handler's exception.
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test"))
        assert "success" in results
    @pytest.mark.asyncio
    async def test_max_concurrent_handlers_respected(self):
        """Semaphore limits concurrent handler executions."""
        bus = EventBus(max_concurrent_handlers=2)
        concurrent_count = 0
        max_concurrent = 0
        # The unlocked counters are safe here: handlers run on a single
        # event-loop thread and only yield control at the explicit await.
        async def slow_handler(event: AgentEvent) -> None:
            nonlocal concurrent_count, max_concurrent
            concurrent_count += 1
            max_concurrent = max(max_concurrent, concurrent_count)
            await asyncio.sleep(0.1)
            concurrent_count -= 1
        # Subscribe 5 handlers
        for _ in range(5):
            bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=slow_handler)
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test"))
        # Max concurrent should be limited to 2
        assert max_concurrent <= 2
# ---------------------------------------------------------------------------
# History and query tests
# ---------------------------------------------------------------------------
class TestHistoryAndQueries:
    """Tests for get_history() and get_stats()."""
    @pytest.mark.asyncio
    async def test_history_returns_events_most_recent_first(self):
        """get_history() returns events in reverse chronological order."""
        bus = EventBus()
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s1"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s2"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_FAILED, stream_id="s3"))
        history = bus.get_history()
        assert history[0].stream_id == "s3"  # Most recent
        assert history[1].stream_id == "s2"
        assert history[2].stream_id == "s1"  # Oldest
    @pytest.mark.asyncio
    async def test_history_filter_by_event_type(self):
        """get_history() can filter by event type."""
        bus = EventBus()
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))
        history = bus.get_history(event_type=EventType.EXECUTION_STARTED)
        assert len(history) == 2
        assert all(e.type == EventType.EXECUTION_STARTED for e in history)
    @pytest.mark.asyncio
    async def test_history_filter_by_stream_id(self):
        """get_history() can filter by stream_id."""
        bus = EventBus()
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_a"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_b"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_a"))
        history = bus.get_history(stream_id="stream_a")
        assert len(history) == 2
        assert all(e.stream_id == "stream_a" for e in history)
    @pytest.mark.asyncio
    async def test_history_limit(self):
        """get_history() respects limit parameter."""
        bus = EventBus()
        for i in range(10):
            await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id=f"s{i}"))
        history = bus.get_history(limit=3)
        assert len(history) == 3
    @pytest.mark.asyncio
    async def test_max_history_enforced(self):
        """EventBus enforces max_history limit."""
        bus = EventBus(max_history=5)
        for i in range(10):
            await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id=f"s{i}"))
        # NOTE(review): inspects the private _event_history buffer directly;
        # brittle if EventBus switches its internal storage.
        assert len(bus._event_history) == 5
        # Should have the 5 most recent
        assert bus._event_history[-1].stream_id == "s9"
        assert bus._event_history[0].stream_id == "s5"
    @pytest.mark.asyncio
    async def test_get_stats(self):
        """get_stats() returns accurate statistics."""
        bus = EventBus()
        async def handler(event: AgentEvent) -> None:
            pass
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        bus.subscribe(event_types=[EventType.EXECUTION_COMPLETED], handler=handler)
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s"))
        stats = bus.get_stats()
        # Per-type buckets are keyed by the enum's string value.
        assert stats["total_events"] == 3
        assert stats["subscriptions"] == 2
        assert stats["events_by_type"]["execution_started"] == 2
        assert stats["events_by_type"]["execution_completed"] == 1
# ---------------------------------------------------------------------------
# Wait operations tests
# ---------------------------------------------------------------------------
class TestWaitOperations:
    """Tests for wait_for() async waiting.

    Background publisher tasks are kept in a local variable and awaited:
    asyncio holds only weak references to tasks, so an unreferenced task can
    be garbage-collected before it runs, and awaiting it afterwards surfaces
    any exception raised inside the publisher instead of losing it.
    """
    @pytest.mark.asyncio
    async def test_wait_for_receives_event(self):
        """wait_for() returns when matching event is published."""
        bus = EventBus()
        async def publish_later():
            await asyncio.sleep(0.05)
            await bus.publish(
                AgentEvent(
                    type=EventType.EXECUTION_COMPLETED,
                    stream_id="test",
                    execution_id="exec_1",
                )
            )
        # Retain the task reference (see class docstring).
        publisher = asyncio.create_task(publish_later())
        event = await bus.wait_for(
            event_type=EventType.EXECUTION_COMPLETED,
            timeout=1.0,
        )
        await publisher
        assert event is not None
        assert event.type == EventType.EXECUTION_COMPLETED
    @pytest.mark.asyncio
    async def test_wait_for_timeout_returns_none(self):
        """wait_for() returns None on timeout."""
        bus = EventBus()
        event = await bus.wait_for(
            event_type=EventType.EXECUTION_COMPLETED,
            timeout=0.05,
        )
        assert event is None
    @pytest.mark.asyncio
    async def test_wait_for_with_filters(self):
        """wait_for() respects filters."""
        bus = EventBus()
        async def publish_events():
            await asyncio.sleep(0.02)
            # This one shouldn't match
            await bus.publish(
                AgentEvent(
                    type=EventType.EXECUTION_COMPLETED,
                    stream_id="wrong_stream",
                )
            )
            await asyncio.sleep(0.02)
            # This one should match
            await bus.publish(
                AgentEvent(
                    type=EventType.EXECUTION_COMPLETED,
                    stream_id="correct_stream",
                )
            )
        publisher = asyncio.create_task(publish_events())
        event = await bus.wait_for(
            event_type=EventType.EXECUTION_COMPLETED,
            stream_id="correct_stream",
            timeout=1.0,
        )
        await publisher
        assert event is not None
        assert event.stream_id == "correct_stream"
    @pytest.mark.asyncio
    async def test_wait_for_cleans_up_subscription(self):
        """wait_for() removes its subscription after completion."""
        bus = EventBus()
        initial_count = len(bus._subscriptions)
        async def publish_later():
            await asyncio.sleep(0.02)
            await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s"))
        publisher = asyncio.create_task(publish_later())
        await bus.wait_for(event_type=EventType.EXECUTION_COMPLETED, timeout=1.0)
        await publisher
        # The temporary subscription created by wait_for() is gone again.
        assert len(bus._subscriptions) == initial_count
# ---------------------------------------------------------------------------
# Convenience publisher tests
# ---------------------------------------------------------------------------
class TestConveniencePublishers:
    """Tests for emit_* convenience methods.

    Each test pins the exact event type and data-payload shape a given
    emit_* helper produces, so these assertions double as the payload
    contract for downstream consumers.
    """
    @pytest.mark.asyncio
    async def test_emit_execution_started(self):
        """emit_execution_started publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        await bus.emit_execution_started(
            stream_id="test_stream",
            execution_id="exec_1",
            input_data={"key": "value"},
            correlation_id="corr_1",
        )
        assert len(received) == 1
        assert received[0].type == EventType.EXECUTION_STARTED
        assert received[0].stream_id == "test_stream"
        assert received[0].execution_id == "exec_1"
        # input_data is wrapped under the "input" key.
        assert received[0].data == {"input": {"key": "value"}}
        assert received[0].correlation_id == "corr_1"
    @pytest.mark.asyncio
    async def test_emit_execution_completed(self):
        """emit_execution_completed publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.EXECUTION_COMPLETED], handler=handler)
        await bus.emit_execution_completed(
            stream_id="s",
            execution_id="e",
            output={"result": "success"},
        )
        assert len(received) == 1
        assert received[0].type == EventType.EXECUTION_COMPLETED
        # output is wrapped under the "output" key.
        assert received[0].data == {"output": {"result": "success"}}
    @pytest.mark.asyncio
    async def test_emit_execution_failed(self):
        """emit_execution_failed publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.EXECUTION_FAILED], handler=handler)
        await bus.emit_execution_failed(
            stream_id="s",
            execution_id="e",
            error="Something went wrong",
        )
        assert len(received) == 1
        assert received[0].type == EventType.EXECUTION_FAILED
        assert received[0].data == {"error": "Something went wrong"}
    @pytest.mark.asyncio
    async def test_emit_tool_call_started(self):
        """emit_tool_call_started publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler)
        await bus.emit_tool_call_started(
            stream_id="s",
            node_id="n",
            tool_use_id="tool_1",
            tool_name="web_search",
            tool_input={"query": "test"},
        )
        assert len(received) == 1
        assert received[0].data["tool_name"] == "web_search"
        assert received[0].data["tool_input"] == {"query": "test"}
    @pytest.mark.asyncio
    async def test_emit_tool_call_completed(self):
        """emit_tool_call_completed publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler)
        await bus.emit_tool_call_completed(
            stream_id="s",
            node_id="n",
            tool_use_id="tool_1",
            tool_name="web_search",
            result="search results",
            is_error=False,
        )
        assert len(received) == 1
        assert received[0].data["result"] == "search results"
        assert received[0].data["is_error"] is False
    @pytest.mark.asyncio
    async def test_emit_webhook_received(self):
        """emit_webhook_received publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.WEBHOOK_RECEIVED], handler=handler)
        await bus.emit_webhook_received(
            source_id="webhook_1",
            path="/api/webhook",
            method="POST",
            headers={"Content-Type": "application/json"},
            payload={"data": "test"},
            query_params={"token": "abc"},
        )
        assert len(received) == 1
        assert received[0].data["path"] == "/api/webhook"
        assert received[0].data["method"] == "POST"
        assert received[0].data["payload"] == {"data": "test"}
    @pytest.mark.asyncio
    async def test_emit_tool_doom_loop(self):
        """emit_tool_doom_loop publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.NODE_TOOL_DOOM_LOOP], handler=handler)
        await bus.emit_tool_doom_loop(
            stream_id="test_stream",
            node_id="node_1",
            description="Tool called same endpoint 5 times",
            execution_id="exec_1",
        )
        assert len(received) == 1
        assert received[0].type == EventType.NODE_TOOL_DOOM_LOOP
        assert received[0].stream_id == "test_stream"
        assert received[0].node_id == "node_1"
        assert received[0].data["description"] == "Tool called same endpoint 5 times"
    @pytest.mark.asyncio
    async def test_emit_escalation_requested(self):
        """emit_escalation_requested publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.ESCALATION_REQUESTED], handler=handler)
        await bus.emit_escalation_requested(
            stream_id="test_stream",
            node_id="node_1",
            reason="Need human intervention",
            context="Complex decision required",
            execution_id="exec_1",
        )
        assert len(received) == 1
        assert received[0].type == EventType.ESCALATION_REQUESTED
        assert received[0].stream_id == "test_stream"
        assert received[0].node_id == "node_1"
        assert received[0].data["reason"] == "Need human intervention"
        assert received[0].data["context"] == "Complex decision required"
    @pytest.mark.asyncio
    async def test_emit_llm_turn_complete(self):
        """emit_llm_turn_complete publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.LLM_TURN_COMPLETE], handler=handler)
        await bus.emit_llm_turn_complete(
            stream_id="test_stream",
            node_id="node_1",
            stop_reason="end_turn",
            model="claude-sonnet-4-20250514",
            input_tokens=100,
            output_tokens=50,
            execution_id="exec_1",
            iteration=3,
        )
        assert len(received) == 1
        assert received[0].type == EventType.LLM_TURN_COMPLETE
        assert received[0].data["stop_reason"] == "end_turn"
        assert received[0].data["model"] == "claude-sonnet-4-20250514"
        assert received[0].data["input_tokens"] == 100
        assert received[0].data["output_tokens"] == 50
        assert received[0].data["iteration"] == 3
    @pytest.mark.asyncio
    async def test_emit_node_action_plan(self):
        """emit_node_action_plan publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.NODE_ACTION_PLAN], handler=handler)
        await bus.emit_node_action_plan(
            stream_id="test_stream",
            node_id="node_1",
            plan="1. Search for data\n2. Analyze results\n3. Generate report",
            execution_id="exec_1",
        )
        assert len(received) == 1
        assert received[0].type == EventType.NODE_ACTION_PLAN
        assert (
            received[0].data["plan"] == "1. Search for data\n2. Analyze results\n3. Generate report"
        )
    @pytest.mark.asyncio
    async def test_emit_subagent_report(self):
        """emit_subagent_report publishes correct event."""
        bus = EventBus()
        received = []
        async def handler(event: AgentEvent) -> None:
            received.append(event)
        bus.subscribe(event_types=[EventType.SUBAGENT_REPORT], handler=handler)
        await bus.emit_subagent_report(
            stream_id="test_stream",
            node_id="queen",
            subagent_id="worker-1",
            message="Task 50% complete",
            data={"progress": 0.5},
        )
        assert len(received) == 1
        assert received[0].type == EventType.SUBAGENT_REPORT
        assert received[0].data["subagent_id"] == "worker-1"
        assert received[0].data["message"] == "Task 50% complete"
        assert received[0].data["data"]["progress"] == 0.5
# ---------------------------------------------------------------------------
# EventType enum tests
# ---------------------------------------------------------------------------
class TestEventType:
    """Tests for EventType enum."""
    def test_all_event_types_are_strings(self):
        """All EventType values are strings."""
        for event_type in EventType:
            assert isinstance(event_type.value, str)
    def test_event_types_are_unique(self):
        """All EventType values are unique."""
        values = [e.value for e in EventType]
        assert len(values) == len(set(values))
    def test_key_event_types_exist(self):
        """Key event types are defined."""
        # Attribute access alone raises AttributeError for a missing member;
        # the truthiness assert additionally guards against a falsy value.
        assert EventType.EXECUTION_STARTED
        assert EventType.EXECUTION_COMPLETED
        assert EventType.EXECUTION_FAILED
        assert EventType.EXECUTION_PAUSED
        assert EventType.EXECUTION_RESUMED
        assert EventType.TOOL_CALL_STARTED
        assert EventType.TOOL_CALL_COMPLETED
        assert EventType.WEBHOOK_RECEIVED
        assert EventType.NODE_TOOL_DOOM_LOOP
        assert EventType.ESCALATION_REQUESTED
        assert EventType.LLM_TURN_COMPLETE
        assert EventType.NODE_ACTION_PLAN
        assert EventType.WORKER_GRAPH_LOADED
        assert EventType.CREDENTIALS_REQUIRED
        assert EventType.EXECUTION_RESURRECTED
        assert EventType.DRAFT_GRAPH_UPDATED
        assert EventType.FLOWCHART_MAP_UPDATED
        assert EventType.QUEEN_PHASE_CHANGED
        assert EventType.QUEEN_PERSONA_SELECTED
        assert EventType.SUBAGENT_REPORT
        assert EventType.TRIGGER_AVAILABLE
        assert EventType.TRIGGER_FIRED
+59 -50
View File
@@ -18,6 +18,24 @@ from framework.agents.queen.recall_selector import (
from framework.graph.prompting import build_system_prompt_for_node_context
from framework.tools.queen_lifecycle_tools import QueenPhaseState
def _make_litellm_response(tool_calls: list[dict] | None = None, content: str = ""):
"""Build a mock that mirrors litellm ModelResponse structure."""
if tool_calls:
tc_objects = []
for tc in tool_calls:
fn = SimpleNamespace(
name=tc["name"],
arguments=json.dumps(tc.get("input", {})),
)
tc_objects.append(SimpleNamespace(id=tc["id"], function=fn))
message = SimpleNamespace(tool_calls=tc_objects)
else:
message = SimpleNamespace(tool_calls=None)
raw = SimpleNamespace(choices=[SimpleNamespace(message=message)])
return MagicMock(content=content, raw_response=raw)
# ---------------------------------------------------------------------------
# parse_frontmatter
# ---------------------------------------------------------------------------
@@ -244,28 +262,25 @@ async def test_short_reflection(tmp_path: Path):
llm = AsyncMock()
llm.acomplete.side_effect = [
# Turn 1: LLM writes a global memory file
MagicMock(
content="",
raw_response={
"tool_calls": [
{
"id": "tc_1",
"name": "write_memory_file",
"input": {
"filename": "user-likes-tests.md",
"content": (
"---\nname: user-likes-tests\n"
"type: preference\n"
"description: User values thorough testing\n"
"---\nObserved emphasis on test coverage."
),
},
}
]
},
_make_litellm_response(
tool_calls=[
{
"id": "tc_1",
"name": "write_memory_file",
"input": {
"filename": "user-likes-tests.md",
"content": (
"---\nname: user-likes-tests\n"
"type: preference\n"
"description: User values thorough testing\n"
"---\nObserved emphasis on test coverage."
),
},
}
]
),
# Turn 2: done
MagicMock(content="Done reflecting.", raw_response={}),
_make_litellm_response(content="Done reflecting."),
]
session_dir = tmp_path / "session"
@@ -312,39 +327,33 @@ async def test_long_reflection(tmp_path: Path):
llm = AsyncMock()
llm.acomplete.side_effect = [
MagicMock(
content="",
raw_response={
"tool_calls": [
{"id": "tc_1", "name": "list_memory_files", "input": {}},
]
},
_make_litellm_response(
tool_calls=[
{"id": "tc_1", "name": "list_memory_files", "input": {}},
]
),
MagicMock(
content="",
raw_response={
"tool_calls": [
{
"id": "tc_2",
"name": "write_memory_file",
"input": {
"filename": "dup-a.md",
"content": (
"---\nname: dup-a\ntype: profile\n"
"description: profile A (merged)\n"
"---\nProfile A details. Also same profile A."
),
},
_make_litellm_response(
tool_calls=[
{
"id": "tc_2",
"name": "write_memory_file",
"input": {
"filename": "dup-a.md",
"content": (
"---\nname: dup-a\ntype: profile\n"
"description: profile A (merged)\n"
"---\nProfile A details. Also same profile A."
),
},
{
"id": "tc_3",
"name": "delete_memory_file",
"input": {"filename": "dup-b.md"},
},
]
},
},
{
"id": "tc_3",
"name": "delete_memory_file",
"input": {"filename": "dup-b.md"},
},
]
),
MagicMock(content="Housekeeping complete.", raw_response={}),
_make_litellm_response(content="Housekeeping complete."),
]
await run_long_reflection(llm, memory_dir=mem_dir)
+2 -2
View File
@@ -137,7 +137,7 @@ export MOCK_MODE=1
# Fernet encryption key for credential store at ~/.hive/credentials
export HIVE_CREDENTIAL_KEY="your-fernet-key"
# Custom agent storage path (default: /tmp)
# Custom agent storage path (default: ~/.hive/agents/{agent_name}/)
export AGENT_STORAGE_PATH="/custom/storage"
```
@@ -152,7 +152,7 @@ CONFIG = {
"max_tokens": 8192, # default: DEFAULT_MAX_TOKENS from framework.graph
"temperature": 0.7,
"tools": ["web_search", "pdf_read"], # MCP tools to enable
"storage_path": "/tmp/my_agent", # Runtime data location
"storage_path": "~/.hive/agents/my_agent/", # Runtime data location (default)
}
```
+1 -1
View File
@@ -446,7 +446,7 @@ Quickstart also supports selecting OpenRouter and Hive LLM interactively. See [c
# Fernet encryption key for credential store at ~/.hive/credentials
export HIVE_CREDENTIAL_KEY="your-fernet-key"
# Agent storage location (default: /tmp)
# Agent storage location (default: ~/.hive/agents/{agent_name}/)
export AGENT_STORAGE_PATH="/custom/storage"
```
@@ -139,6 +139,7 @@ from .trello import TRELLO_CREDENTIALS
from .twilio import TWILIO_CREDENTIALS
from .twitter import TWITTER_CREDENTIALS
from .vercel import VERCEL_CREDENTIALS
from .wandb import WANDB_CREDENTIALS
from .youtube import YOUTUBE_CREDENTIALS
from .zendesk import ZENDESK_CREDENTIALS
from .zoho_crm import ZOHO_CRM_CREDENTIALS
@@ -219,6 +220,7 @@ CREDENTIAL_SPECS = {
**TWITTER_CREDENTIALS,
**VERCEL_CREDENTIALS,
**YOUTUBE_CREDENTIALS,
**WANDB_CREDENTIALS,
**ZENDESK_CREDENTIALS,
**ZOHO_CRM_CREDENTIALS,
**ZOOM_CREDENTIALS,
@@ -313,6 +315,7 @@ __all__ = [
"TWILIO_CREDENTIALS",
"TWITTER_CREDENTIALS",
"VERCEL_CREDENTIALS",
"WANDB_CREDENTIALS",
"YOUTUBE_CREDENTIALS",
"ZENDESK_CREDENTIALS",
"ZOHO_CRM_CREDENTIALS",
+38
View File
@@ -0,0 +1,38 @@
"""
Weights & Biases integration credentials.
Contains credentials for the W&B GraphQL API.
Requires WANDB_API_KEY only; no host configuration needed.
"""
from __future__ import annotations

from .base import CredentialSpec

# Single credential spec: one API key shared by every W&B tool.
WANDB_CREDENTIALS = {
    "wandb_api_key": CredentialSpec(
        env_var="WANDB_API_KEY",
        # Tools gated on this credential (registered in tools/wandb_tool).
        tools=[
            "wandb_list_projects",
            "wandb_list_runs",
            "wandb_get_run",
            "wandb_get_run_metrics",
            "wandb_list_artifacts",
            "wandb_get_summary",
        ],
        required=True,
        # Not needed at startup — presumably resolved lazily on first tool
        # call; confirm against CredentialSpec semantics.
        startup_required=False,
        help_url="https://wandb.ai/authorize",
        description="Weights & Biases API Key",
        direct_api_key_supported=True,
        api_key_instructions="""To set up W&B API access:
1. Create a W&B account at https://wandb.ai
2. Go to https://wandb.ai/authorize
3. Copy your API key
4. Set environment variable:
export WANDB_API_KEY=your-api-key""",
        # No health-check endpoint is configured for W&B.
        health_check_endpoint="",
        credential_id="wandb_api_key",
        credential_key="api_key",
    ),
}
+2
View File
@@ -133,6 +133,7 @@ from .twilio_tool import register_tools as register_twilio
from .twitter_tool import register_tools as register_twitter
from .vercel_tool import register_tools as register_vercel
from .vision_tool import register_tools as register_vision
from .wandb_tool import register_tools as register_wandb
try:
from .web_scrape_tool import register_tools as register_web_scrape
@@ -306,6 +307,7 @@ def _register_unverified(
register_zendesk(mcp, credentials=credentials)
register_zoho_crm(mcp, credentials=credentials)
register_zoom(mcp, credentials=credentials)
register_wandb(mcp, credentials=credentials)
register_freshdesk(mcp, credentials=credentials)
@@ -0,0 +1,94 @@
# Weights & Biases Tool
Query ML experiment runs, metrics, and artifacts from Weights & Biases using the W&B GraphQL API.
## Tools
| Tool | Description |
|------|-------------|
| `wandb_list_projects` | List all projects for a W&B entity (user or organization) |
| `wandb_list_runs` | List runs in a project with optional filters |
| `wandb_get_run` | Get full details of a specific run (config, state, summary) |
| `wandb_get_run_metrics` | Get sampled metric history for a run |
| `wandb_list_artifacts` | List output artifacts logged by a run |
| `wandb_get_summary` | Get final summary metrics for a run |
## Setup
Requires a W&B account and API key.
1. Create a W&B account at [wandb.ai](https://wandb.ai)
2. Get your API key at [wandb.ai/authorize](https://wandb.ai/authorize)
3. Set the environment variable:
```bash
export WANDB_API_KEY=your-api-key
```
Or configure via the Aden credential store as `wandb_api_key`.
## Usage Examples
### List projects for an entity
```python
wandb_list_projects(entity="my-team")
```
### List recent runs in a project
```python
wandb_list_runs(entity="my-team", project="my-project", per_page=10)
```
### Filter runs by state
```python
wandb_list_runs(
entity="my-team",
project="my-project",
filters='{"state": "finished"}',
)
```
### Get details of a specific run
```python
wandb_get_run(entity="my-team", project="my-project", run_id="abc123")
```
### Get training metrics for a run
```python
wandb_get_run_metrics(
entity="my-team",
project="my-project",
run_id="abc123",
metric_keys="loss,accuracy",
)
```
### Get final summary metrics
```python
wandb_get_summary(entity="my-team", project="my-project", run_id="abc123")
```
### List artifacts produced by a run
```python
wandb_list_artifacts(entity="my-team", project="my-project", run_id="abc123")
```
## Error Handling
All tools return error dicts on failure:
```python
{"error": "Weights & Biases credentials not configured", "help": "Set WANDB_API_KEY..."}
{"error": "Invalid Weights & Biases API key"}
{"error": "Weights & Biases resource not found"}
{"error": "Request timed out"}
{"error": "filters must be a valid JSON string"}
{"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}
```
@@ -0,0 +1,5 @@
"""Weights & Biases experiment tracking tool for Aden Tools."""
from .wandb_tool import register_tools

# Public API of this package: the MCP registration entry point.
__all__ = ["register_tools"]
@@ -0,0 +1,440 @@
"""
Weights & Biases ML experiment tracking tool.
Uses the W&B GraphQL API via httpx; no SDK dependency.
Authentication: Bearer token (WANDB_API_KEY)
GraphQL endpoint: https://api.wandb.ai/graphql
API Reference: https://github.com/wandb/wandb/blob/main/wandb/proto/wandb_internal.proto
"""
from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any

import httpx
from fastmcp import FastMCP

# Imported only for static typing to avoid a runtime import cycle.
if TYPE_CHECKING:
    from aden_tools.credentials import CredentialStoreAdapter

# Public W&B GraphQL endpoint used by every tool in this module.
GRAPHQL_URL = "https://api.wandb.ai/graphql"
def _get_creds(
credentials: CredentialStoreAdapter | None,
) -> tuple[str] | dict[str, Any]:
"""Return (api_key,) or an error dict."""
if credentials is not None:
api_key = credentials.get("wandb_api_key")
else:
api_key = os.getenv("WANDB_API_KEY")
if not api_key:
return {
"error": "Weights & Biases credentials not configured",
"help": (
"Set WANDB_API_KEY environment variable or configure via credential store. "
"Get your API key at https://wandb.ai/authorize"
),
}
return (api_key,)
def _graphql(
    api_key: str,
    query: str,
    variables: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Execute a GraphQL query against the W&B API.

    Args:
        api_key: W&B API key, sent as a Bearer token.
        query: GraphQL query text.
        variables: Optional GraphQL variables mapping.

    Returns:
        The ``data`` portion of the GraphQL response on success, or a dict
        containing an ``"error"`` key on any failure — callers can always
        check ``"error" in result`` instead of catching exceptions.
    """
    try:
        resp = httpx.post(
            GRAPHQL_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            json={"query": query, "variables": variables or {}},
            timeout=30.0,
        )
    except httpx.TimeoutException:
        return {"error": "Request timed out"}
    except httpx.RequestError as e:
        return {"error": f"Network error: {e}"}
    if resp.status_code == 401:
        return {"error": "Invalid Weights & Biases API key"}
    if resp.status_code == 403:
        return {"error": "Insufficient permissions for this Weights & Biases resource"}
    if resp.status_code >= 400:
        try:
            detail = resp.json().get("errors", [{}])[0].get("message", resp.text)
        except Exception:
            detail = resp.text
        return {"error": f"Weights & Biases API error (HTTP {resp.status_code}): {detail}"}
    # A 2xx response can still carry a non-JSON body (proxy/HTML error
    # pages); surface that as an error dict instead of raising, matching
    # every other failure path in this function.
    try:
        payload = resp.json()
    except Exception:
        return {"error": f"Weights & Biases API returned non-JSON response: {resp.text[:200]}"}
    if not isinstance(payload, dict):
        return {"error": "Weights & Biases API returned an unexpected response shape"}
    errors = payload.get("errors")
    if errors:
        # Guard against an empty or oddly-shaped errors list — the previous
        # unconditional [0].get(...) raised on those.
        first = errors[0] if isinstance(errors, list) and errors else {}
        msg = first.get("message", str(errors)) if isinstance(first, dict) else str(errors)
        return {"error": f"Weights & Biases GraphQL error: {msg}"}
    return payload.get("data", {})
def register_tools(
    mcp: FastMCP,
    credentials: CredentialStoreAdapter | None = None,
) -> None:
    """Register Weights & Biases experiment tracking tools with the MCP server.

    Args:
        mcp: FastMCP server to attach the tools to.
        credentials: Optional credential store; when None, tools fall back
            to the WANDB_API_KEY environment variable (see _get_creds).
    """

    @mcp.tool()
    def wandb_list_projects(entity: str) -> dict:
        """
        List all projects for a Weights & Biases entity (user or organization).
        Args:
            entity: The W&B entity name (username or organization).
        Returns:
            Dict containing the list of projects for the entity.
        """
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            # _get_creds returns a dict only on failure — propagate as-is.
            return creds
        (api_key,) = creds
        query = """
        query ListProjects($entity: String!) {
            projects(entityName: $entity) {
                edges {
                    node {
                        name
                        description
                        createdAt
                    }
                }
            }
        }
        """
        data = _graphql(api_key, query, {"entity": entity})
        if "error" in data:
            return data
        edges = data.get("projects", {}).get("edges", [])
        return {
            "entity": entity,
            "projects": [
                {
                    "name": e["node"]["name"],
                    "description": e["node"].get("description", ""),
                    "created_at": e["node"].get("createdAt", ""),
                }
                for e in edges
            ],
        }

    @mcp.tool()
    def wandb_list_runs(
        entity: str,
        project: str,
        filters: str = "",
        per_page: int = 50,
    ) -> dict:
        """
        List runs in a Weights & Biases project.
        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            filters: Optional JSON filter string to narrow results.
            per_page: Number of runs to return (default 50).
        Returns:
            Dict containing the list of runs in the project.
        """
        # Validate the filters JSON up front so bad input fails before any
        # credential lookup or HTTP call.
        parsed_filters = None
        if filters:
            try:
                parsed_filters = json.loads(filters)
            except json.JSONDecodeError:
                return {"error": "filters must be a valid JSON string"}
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds
        query = """
        query ListRuns($project: String!, $entity: String!, $perPage: Int!, $filters: JSONString) {
            project(name: $project, entityName: $entity) {
                runs(first: $perPage, filters: $filters) {
                    edges {
                        node {
                            name
                            id
                            state
                            createdAt
                            config
                            summaryMetrics
                        }
                    }
                }
            }
        }
        """
        variables: dict[str, Any] = {"project": project, "entity": entity, "perPage": per_page}
        if parsed_filters is not None:
            # NOTE(review): $filters is declared JSONString but a decoded
            # object is sent here — confirm the W&B API accepts an object;
            # the schema may expect a JSON-encoded string.
            variables["filters"] = parsed_filters
        data = _graphql(api_key, query, variables)
        if "error" in data:
            return data
        edges = data.get("project", {}).get("runs", {}).get("edges", [])
        runs = []
        for e in edges:
            node = e["node"]
            # config arrives as a JSON-encoded string; decode defensively so
            # a malformed payload degrades to an empty dict.
            try:
                config = json.loads(node.get("config") or "{}")
            except (json.JSONDecodeError, TypeError):
                config = {}
            runs.append(
                {
                    # W&B GraphQL "name" appears to hold the short run id and
                    # "id" the display name — the swap is intentional (the
                    # test suite pins it); confirm against the W&B schema.
                    "id": node.get("name"),
                    "display_name": node.get("id"),
                    "state": node.get("state"),
                    "created_at": node.get("createdAt"),
                    "config": config,
                }
            )
        return {"entity": entity, "project": project, "runs": runs}

    @mcp.tool()
    def wandb_get_run(entity: str, project: str, run_id: str) -> dict:
        """
        Get details of a specific Weights & Biases run.
        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.
        Returns:
            Dict containing full run details including config and metadata.
        """
        if not run_id:
            return {"error": "run_id is required"}
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds
        query = """
        query GetRun($project: String!, $entity: String!, $run: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $run) {
                    name
                    id
                    state
                    createdAt
                    config
                    summaryMetrics
                    tags
                    notes
                }
            }
        }
        """
        data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
        if "error" in data:
            return data
        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}
        # config / summaryMetrics arrive as JSON-encoded strings; decode
        # defensively so malformed payloads degrade to empty dicts.
        try:
            config = json.loads(node.get("config") or "{}")
        except (json.JSONDecodeError, TypeError):
            config = {}
        try:
            summary = json.loads(node.get("summaryMetrics") or "{}")
        except (json.JSONDecodeError, TypeError):
            summary = {}
        return {
            # Same name/id swap as wandb_list_runs — see note there.
            "id": node.get("name"),
            "display_name": node.get("id"),
            "state": node.get("state"),
            "created_at": node.get("createdAt"),
            "config": config,
            "summary": summary,
            "tags": node.get("tags") or [],
            "notes": node.get("notes") or "",
        }

    @mcp.tool()
    def wandb_get_run_metrics(
        entity: str,
        project: str,
        run_id: str,
        metric_keys: str = "",
    ) -> dict:
        """
        Get sampled metrics history for a specific Weights & Biases run.
        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.
            metric_keys: Comma-separated metric keys to sample (e.g. "loss,accuracy").
                At least one key is required.
        Returns:
            Dict containing sampled metric history per key.
        """
        # NOTE(review): unlike the other tools, credentials are resolved
        # before input validation here — harmless, but inconsistent.
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds
        if not run_id:
            return {"error": "run_id is required"}
        if not metric_keys:
            return {"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}
        keys = [k.strip() for k in metric_keys.split(",") if k.strip()]
        if not keys:
            return {"error": "metric_keys must include at least one non-empty key"}
        specs = json.dumps([{"key": k} for k in keys])
        # NOTE(review): specs is interpolated directly into the query text
        # rather than passed as a GraphQL variable; a metric key containing
        # quotes or braces would corrupt the query. Consider a variable.
        query = f"""
        query GetRunMetrics($project: String!, $entity: String!, $run: String!) {{
            project(name: $project, entityName: $entity) {{
                run(name: $run) {{
                    sampledHistory(specs: {specs})
                }}
            }}
        }}
        """
        data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
        if "error" in data:
            return data
        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}
        return {
            "run_id": run_id,
            "metric_keys": keys,
            "history": node.get("sampledHistory", []),
        }

    @mcp.tool()
    def wandb_list_artifacts(entity: str, project: str, run_id: str) -> dict:
        """
        List artifacts logged by a specific Weights & Biases run.
        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.
        Returns:
            Dict containing the list of output artifacts for the run.
        """
        if not run_id:
            return {"error": "run_id is required"}
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds
        query = """
        query ListArtifacts($project: String!, $entity: String!, $run: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $run) {
                    outputArtifacts {
                        edges {
                            node {
                                name
                                type
                                description
                                createdAt
                            }
                        }
                    }
                }
            }
        }
        """
        data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
        if "error" in data:
            return data
        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}
        edges = node.get("outputArtifacts", {}).get("edges", [])
        return {
            "run_id": run_id,
            "artifacts": [
                {
                    "name": e["node"]["name"],
                    "type": e["node"]["type"],
                    "description": e["node"].get("description", ""),
                    "created_at": e["node"].get("createdAt", ""),
                }
                for e in edges
            ],
        }

    @mcp.tool()
    def wandb_get_summary(entity: str, project: str, run_id: str) -> dict:
        """
        Get summary metrics for a specific Weights & Biases run.
        Args:
            entity: The W&B entity name (username or organization).
            project: The project name.
            run_id: The run ID.
        Returns:
            Dict containing the run's final summary metrics.
        """
        if not run_id:
            return {"error": "run_id is required"}
        creds = _get_creds(credentials)
        if isinstance(creds, dict):
            return creds
        (api_key,) = creds
        query = """
        query GetSummary($project: String!, $entity: String!, $run: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $run) {
                    summaryMetrics
                }
            }
        }
        """
        data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
        if "error" in data:
            return data
        node = data.get("project", {}).get("run")
        if not node:
            return {"error": "Weights & Biases resource not found"}
        try:
            summary = json.loads(node.get("summaryMetrics") or "{}")
        except (json.JSONDecodeError, TypeError):
            summary = {}
        # Filter out internal W&B keys
        summary = {k: v for k, v in summary.items() if not k.startswith("_")}
        return {"run_id": run_id, "summary": summary}
+317
View File
@@ -0,0 +1,317 @@
"""Tests for wandb_tool - Weights & Biases integration (GraphQL/httpx)."""
from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastmcp import FastMCP

from aden_tools.tools.wandb_tool.wandb_tool import register_tools

# Fake credentials injected into os.environ for tests needing an API key.
ENV = {"WANDB_API_KEY": "test-key-abcdefghij"}
# Patch target for the httpx.post call made inside the tool module.
_PATCH_POST = "aden_tools.tools.wandb_tool.wandb_tool.httpx.post"
def _mock_resp(data: Any, status_code: int = 200) -> MagicMock:
resp = MagicMock()
resp.status_code = status_code
resp.json.return_value = data
resp.text = str(data)
return resp
def _gql_ok(data: dict[str, Any]) -> MagicMock:
"""Wrap data in the GraphQL envelope: {"data": {...}}."""
return _mock_resp({"data": data})
@pytest.fixture
def tool_fns(mcp: FastMCP) -> dict[str, Any]:
    """Register the wandb tools on the shared FastMCP fixture and return
    each tool's underlying function keyed by tool name."""
    register_tools(mcp, credentials=None)
    registry = mcp._tool_manager._tools
    return {name: tool.fn for name, tool in registry.items()}
class TestWandbTool:
    """Unit tests for the Weights & Biases MCP tools.

    All HTTP traffic is stubbed by patching ``httpx.post`` inside the tool
    module (``_PATCH_POST``); credentials come from a patched WANDB_API_KEY
    environment variable, so no network or real W&B account is needed.
    """

    # --- Credential tests ---
    def test_missing_credentials_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """Missing WANDB_API_KEY must return a descriptive error dict with help."""
        with patch.dict("os.environ", {}, clear=True):
            result = tool_fns["wandb_list_projects"](entity="test-entity")
        assert "error" in result
        assert "credentials not configured" in result["error"]
        assert "help" in result

    # --- wandb_list_projects ---
    def test_wandb_list_projects_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_projects returns projects list from GraphQL."""
        gql_data = {
            "projects": {
                "edges": [
                    {
                        "node": {
                            "name": "proj-a",
                            "description": "Desc A",
                            "createdAt": "2024-01-01",
                        }
                    },
                    {
                        "node": {
                            "name": "proj-b",
                            "description": "",
                            "createdAt": "2024-02-01",
                        }
                    },
                ]
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_list_projects"](entity="test-entity")
        assert result["entity"] == "test-entity"
        assert len(result["projects"]) == 2
        assert result["projects"][0]["name"] == "proj-a"

    def test_wandb_list_projects_http_401(self, tool_fns: dict[str, Any]) -> None:
        """HTTP 401 returns an invalid key error."""
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_mock_resp({}, status_code=401)),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert result["error"] == "Invalid Weights & Biases API key"

    def test_wandb_list_projects_graphql_error(self, tool_fns: dict[str, Any]) -> None:
        """GraphQL error block is surfaced as an error dict."""
        gql_err = {"errors": [{"message": "entity not found"}]}
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_mock_resp(gql_err)),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert "error" in result
        assert "entity not found" in result["error"]

    # --- wandb_list_runs ---
    def test_wandb_list_runs_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_runs returns runs list."""
        gql_data = {
            "project": {
                "runs": {
                    "edges": [
                        {
                            "node": {
                                "name": "w854ckuu",
                                "id": "ferengi-directive-1",
                                "state": "finished",
                                "createdAt": "2024-01-01",
                                "config": "{}",
                                "summaryMetrics": "{}",
                            }
                        }
                    ]
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)) as mock_post,
        ):
            result = tool_fns["wandb_list_runs"](
                entity="testentity",
                project="testproject",
                filters='{"key": "value"}',
                per_page=50,
            )
        assert result["project"] == "testproject"
        assert len(result["runs"]) == 1
        # GraphQL "name" is the short run id — the tool maps it to "id".
        assert result["runs"][0]["id"] == "w854ckuu"
        # Verify filters and per_page were forwarded in GraphQL variables
        call_json = mock_post.call_args[1]["json"]
        assert call_json["variables"]["perPage"] == 50
        assert call_json["variables"]["filters"] == {"key": "value"}

    def test_wandb_list_runs_invalid_filters_json(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_runs returns error for invalid JSON filters before any HTTP call."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_list_runs"](entity="e", project="p", filters="not-json")
        assert "error" in result
        assert "valid JSON" in result["error"]

    # --- wandb_get_run ---
    def test_wandb_get_run_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run returns run details."""
        gql_data = {
            "project": {
                "run": {
                    "name": "run-123",
                    "id": "my-run",
                    "state": "finished",
                    "createdAt": "2024-01-01",
                    "config": '{"lr": 0.001}',
                    "summaryMetrics": '{"accuracy": 0.9}',
                    "tags": ["v1"],
                    "notes": "test",
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="run-123")
        assert result["id"] == "run-123"
        assert result["config"] == {"lr": 0.001}
        assert result["summary"] == {"accuracy": 0.9}

    def test_wandb_get_run_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run with empty run_id returns error before HTTP call."""
        result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    def test_wandb_get_run_not_found(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run returns not-found error when run is null."""
        gql_data = {"project": {"run": None}}
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="nope")
        assert "error" in result
        assert "not found" in result["error"]

    # --- wandb_get_run_metrics ---
    def test_wandb_get_run_metrics_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics returns sampled history."""
        gql_data = {
            "project": {
                "run": {
                    "sampledHistory": [{"loss": 0.5}, {"loss": 0.3}],
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_run_metrics"](
                entity="e", project="p", run_id="r1", metric_keys="loss"
            )
        assert result["run_id"] == "r1"
        assert result["metric_keys"] == ["loss"]
        assert result["history"] == [{"loss": 0.5}, {"loss": 0.3}]

    def test_wandb_get_run_metrics_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics with empty run_id returns error."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    def test_wandb_get_run_metrics_missing_keys(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_run_metrics with no metric_keys returns error."""
        with patch.dict("os.environ", ENV):
            result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="r1")
        assert "error" in result
        assert "metric_keys is required" in result["error"]

    # --- wandb_list_artifacts ---
    def test_wandb_list_artifacts_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_artifacts returns artifact list."""
        gql_data = {
            "project": {
                "run": {
                    "outputArtifacts": {
                        "edges": [
                            {
                                "node": {
                                    "name": "model:v0",
                                    "type": "model",
                                    "description": "",
                                    "createdAt": "2024-01-01",
                                }
                            }
                        ]
                    }
                }
            }
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="r1")
        assert result["run_id"] == "r1"
        assert result["artifacts"][0]["name"] == "model:v0"

    def test_wandb_list_artifacts_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_list_artifacts with empty run_id returns error."""
        result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    # --- wandb_get_summary ---
    def test_wandb_get_summary_success(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_summary returns summary filtering out _-prefixed keys."""
        gql_data = {
            "project": {"run": {"summaryMetrics": '{"accuracy": 0.9, "loss": 0.1, "_step": 5}'}}
        }
        with (
            patch.dict("os.environ", ENV),
            patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
        ):
            result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="r1")
        assert result["run_id"] == "r1"
        assert result["summary"]["accuracy"] == 0.9
        assert "_step" not in result["summary"]

    def test_wandb_get_summary_missing_id(self, tool_fns: dict[str, Any]) -> None:
        """wandb_get_summary with empty run_id returns error."""
        result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="")
        assert "error" in result
        assert result["error"] == "run_id is required"

    # --- Network/timeout errors ---
    def test_timeout_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """httpx.TimeoutException is caught and returns a timeout message."""
        with (
            patch.dict("os.environ", ENV),
            patch(
                "aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
                side_effect=httpx.TimeoutException("timeout"),
            ),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert result["error"] == "Request timed out"

    def test_network_error_returns_error(self, tool_fns: dict[str, Any]) -> None:
        """httpx.RequestError is caught and returns a network error message."""
        with (
            patch.dict("os.environ", ENV),
            patch(
                "aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
                side_effect=httpx.RequestError("Connection refused"),
            ),
        ):
            result = tool_fns["wandb_list_projects"](entity="e")
        assert "Network error" in result["error"]