Merge branch 'main' into feat/open-hive-colony
This commit is contained in:
@@ -79,3 +79,4 @@ core/tests/*dumps/*
|
||||
screenshots/*
|
||||
|
||||
.gemini/*
|
||||
.coverage
|
||||
|
||||
@@ -342,34 +342,38 @@ def _dump_failed_request(
|
||||
attempt: int,
|
||||
) -> str:
|
||||
"""Dump failed request to a file for debugging. Returns the file path."""
|
||||
FAILED_REQUESTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
FAILED_REQUESTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{error_type}_{model.replace('/', '_')}_{timestamp}.json"
|
||||
filepath = FAILED_REQUESTS_DIR / filename
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{error_type}_{model.replace('/', '_')}_{timestamp}.json"
|
||||
filepath = FAILED_REQUESTS_DIR / filename
|
||||
|
||||
# Build dump data
|
||||
messages = kwargs.get("messages", [])
|
||||
dump_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"error_type": error_type,
|
||||
"attempt": attempt,
|
||||
"estimated_tokens": _estimate_tokens(model, messages),
|
||||
"num_messages": len(messages),
|
||||
"messages": messages,
|
||||
"tools": kwargs.get("tools"),
|
||||
"max_tokens": kwargs.get("max_tokens"),
|
||||
"temperature": kwargs.get("temperature"),
|
||||
}
|
||||
# Build dump data
|
||||
messages = kwargs.get("messages", [])
|
||||
dump_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"error_type": error_type,
|
||||
"attempt": attempt,
|
||||
"estimated_tokens": _estimate_tokens(model, messages),
|
||||
"num_messages": len(messages),
|
||||
"messages": messages,
|
||||
"tools": kwargs.get("tools"),
|
||||
"max_tokens": kwargs.get("max_tokens"),
|
||||
"temperature": kwargs.get("temperature"),
|
||||
}
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(dump_data, f, indent=2, default=str)
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(dump_data, f, indent=2, default=str)
|
||||
|
||||
# Prune old dumps to prevent unbounded disk growth
|
||||
_prune_failed_request_dumps()
|
||||
# Prune old dumps to prevent unbounded disk growth
|
||||
_prune_failed_request_dumps()
|
||||
|
||||
return str(filepath)
|
||||
return str(filepath)
|
||||
except OSError as e:
|
||||
logger.warning(f"Failed to dump request debug log to {FAILED_REQUESTS_DIR}: {e}")
|
||||
return "log_write_failed"
|
||||
|
||||
|
||||
def _compute_retry_delay(
|
||||
|
||||
@@ -0,0 +1,927 @@
|
||||
"""Tests for EventBus pub/sub event system.
|
||||
|
||||
Validates subscription management, event publishing, filtering,
|
||||
concurrency handling, history operations, and convenience publishers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from framework.runtime.event_bus import (
|
||||
AgentEvent,
|
||||
EventBus,
|
||||
EventType,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AgentEvent dataclass tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAgentEvent:
    """Unit tests covering the AgentEvent dataclass contract."""

    def test_minimal_construction(self):
        """Only type and stream_id are required; everything else defaults."""
        evt = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test_stream")
        assert evt.type == EventType.EXECUTION_STARTED
        assert evt.stream_id == "test_stream"
        # Optional identifiers default to None; data defaults to an empty dict.
        assert evt.node_id is None
        assert evt.execution_id is None
        assert evt.data == {}
        assert evt.correlation_id is None

    def test_full_construction(self):
        """All explicitly supplied fields are stored verbatim."""
        evt = AgentEvent(
            type=EventType.TOOL_CALL_COMPLETED,
            stream_id="stream_1",
            node_id="node_1",
            execution_id="exec_123",
            data={"result": "success"},
            correlation_id="corr_456",
        )
        actual = (
            evt.type,
            evt.stream_id,
            evt.node_id,
            evt.execution_id,
            evt.data,
            evt.correlation_id,
        )
        expected = (
            EventType.TOOL_CALL_COMPLETED,
            "stream_1",
            "node_1",
            "exec_123",
            {"result": "success"},
            "corr_456",
        )
        assert actual == expected

    def test_timestamp_auto_generated(self):
        """A timestamp is filled in automatically when omitted."""
        lower = datetime.now()
        evt = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test")
        upper = datetime.now()
        # The auto-generated timestamp must fall inside the construction window.
        assert lower <= evt.timestamp <= upper

    def test_to_dict_serialization(self):
        """to_dict() produces a plain-dict view of the event."""
        evt = AgentEvent(
            type=EventType.EXECUTION_COMPLETED,
            stream_id="stream_1",
            node_id="node_1",
            execution_id="exec_1",
            data={"output": "result"},
            correlation_id="corr_1",
            graph_id="graph_1",
        )
        payload = evt.to_dict()
        # The enum is serialized to its string value.
        assert payload["type"] == "execution_completed"
        assert payload["stream_id"] == "stream_1"

    def test_to_dict_includes_run_id(self):
        """A set run_id shows up in the serialized dict."""
        evt = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="s1",
            run_id="run-abc",
        )
        assert evt.to_dict()["run_id"] == "run-abc"

    def test_to_dict_omits_run_id_when_none(self):
        """An unset run_id is left out of the serialized dict entirely."""
        evt = AgentEvent(
            type=EventType.EXECUTION_STARTED,
            stream_id="s1",
        )
        assert "run_id" not in evt.to_dict()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subscription management tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSubscriptionManagement:
    """Exercise subscribe() / unsubscribe() bookkeeping on EventBus."""

    def test_subscribe_returns_id(self):
        """A subscription token with the 'sub_' prefix is handed back."""
        bus = EventBus()

        async def noop(event: AgentEvent) -> None:
            pass

        token = bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=noop)
        assert token.startswith("sub_")

    def test_subscribe_increments_id(self):
        """Tokens are sequential across successive subscriptions."""
        bus = EventBus()

        async def noop(event: AgentEvent) -> None:
            pass

        subscribed_types = [
            EventType.EXECUTION_STARTED,
            EventType.EXECUTION_COMPLETED,
            EventType.EXECUTION_FAILED,
        ]
        tokens = [
            bus.subscribe(event_types=[etype], handler=noop)
            for etype in subscribed_types
        ]
        assert tokens == ["sub_1", "sub_2", "sub_3"]

    def test_unsubscribe_removes_subscription(self):
        """unsubscribe() drops the entry and reports success."""
        bus = EventBus()

        async def noop(event: AgentEvent) -> None:
            pass

        token = bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=noop)
        assert token in bus._subscriptions

        assert bus.unsubscribe(token) is True
        assert token not in bus._subscriptions

    def test_unsubscribe_nonexistent_returns_false(self):
        """Unknown tokens are rejected without raising."""
        assert EventBus().unsubscribe("sub_nonexistent") is False

    def test_multiple_subscriptions_same_event_type(self):
        """Several handlers may listen for the same event type at once."""
        bus = EventBus()
        received = []

        async def first(event: AgentEvent) -> None:
            received.append("handler1")

        async def second(event: AgentEvent) -> None:
            received.append("handler2")

        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=first)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=second)

        # Both registrations are tracked independently.
        assert len(bus._subscriptions) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event publishing tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventPublishing:
    """Tests for publish() and event delivery."""

    @pytest.mark.asyncio
    async def test_publish_delivers_to_subscriber(self):
        """Published events are delivered to matching subscribers."""
        bus = EventBus()
        received_events = []

        async def handler(event: AgentEvent) -> None:
            received_events.append(event)

        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)

        event = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test")
        await bus.publish(event)

        # NOTE(review): asserting immediately after the await implies publish()
        # completes handler delivery before returning — confirm against EventBus.
        assert len(received_events) == 1
        assert received_events[0] == event

    @pytest.mark.asyncio
    async def test_publish_to_multiple_subscribers(self):
        """Event is delivered to all matching subscribers."""
        bus = EventBus()
        received = []

        async def handler1(event: AgentEvent) -> None:
            received.append("h1")

        async def handler2(event: AgentEvent) -> None:
            received.append("h2")

        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler1)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler2)

        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test"))

        # Membership (not order) is asserted: delivery order across handlers
        # is not part of the contract being tested here.
        assert "h1" in received
        assert "h2" in received

    @pytest.mark.asyncio
    async def test_publish_non_matching_type_not_delivered(self):
        """Events with non-matching types are not delivered."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)

        # Subscription is for EXECUTION_STARTED; this COMPLETED event must be skipped.
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="test"))

        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_publish_adds_to_history(self):
        """Published events are added to history."""
        bus = EventBus()

        event = AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test")
        await bus.publish(event)

        # Even with no subscribers, the event is recorded in history.
        history = bus.get_history()
        assert len(history) == 1
        assert history[0] == event
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventFiltering:
    """Tests for subscription filters."""

    @pytest.mark.asyncio
    async def test_filter_by_stream(self):
        """filter_stream only receives events from that stream."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event.stream_id)

        bus.subscribe(
            event_types=[EventType.EXECUTION_STARTED],
            handler=handler,
            filter_stream="stream_a",
        )

        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_a"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_b"))

        # Only the stream_a event passes the filter.
        assert received == ["stream_a"]

    @pytest.mark.asyncio
    async def test_filter_by_node(self):
        """filter_node only receives events from that node."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event.node_id)

        bus.subscribe(
            event_types=[EventType.NODE_LOOP_STARTED],
            handler=handler,
            filter_node="node_x",
        )

        await bus.publish(
            AgentEvent(type=EventType.NODE_LOOP_STARTED, stream_id="s", node_id="node_x")
        )
        await bus.publish(
            AgentEvent(type=EventType.NODE_LOOP_STARTED, stream_id="s", node_id="node_y")
        )

        assert received == ["node_x"]

    @pytest.mark.asyncio
    async def test_filter_by_execution(self):
        """filter_execution only receives events from that execution."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event.execution_id)

        bus.subscribe(
            event_types=[EventType.EXECUTION_COMPLETED],
            handler=handler,
            filter_execution="exec_1",
        )

        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s", execution_id="exec_1")
        )
        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s", execution_id="exec_2")
        )

        assert received == ["exec_1"]

    @pytest.mark.asyncio
    async def test_combined_filters(self):
        """Multiple filters are AND-ed together."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(True)

        bus.subscribe(
            event_types=[EventType.TOOL_CALL_COMPLETED],
            handler=handler,
            filter_stream="stream_1",
            filter_node="node_1",
        )

        # Matches both filters
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_COMPLETED,
                stream_id="stream_1",
                node_id="node_1",
            )
        )
        # Matches stream but not node
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_COMPLETED,
                stream_id="stream_1",
                node_id="node_2",
            )
        )
        # Matches node but not stream
        await bus.publish(
            AgentEvent(
                type=EventType.TOOL_CALL_COMPLETED,
                stream_id="stream_2",
                node_id="node_1",
            )
        )

        # Only the event satisfying BOTH filters was delivered.
        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_filter_by_graph(self):
        """filter_graph only receives events from that graph."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event.graph_id)

        bus.subscribe(
            event_types=[EventType.EXECUTION_STARTED],
            handler=handler,
            filter_graph="graph_a",
        )

        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s", graph_id="graph_a")
        )
        await bus.publish(
            AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s", graph_id="graph_b")
        )

        assert received == ["graph_a"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrency tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConcurrency:
    """Tests for concurrent handler execution."""

    @pytest.mark.asyncio
    async def test_handler_error_doesnt_crash_others(self):
        """One handler's error doesn't prevent other handlers from running."""
        bus = EventBus()
        results = []

        async def failing_handler(event: AgentEvent) -> None:
            raise ValueError("Handler error!")

        async def working_handler(event: AgentEvent) -> None:
            results.append("success")

        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=failing_handler)
        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=working_handler)

        # publish() must not propagate the ValueError from failing_handler.
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test"))

        assert "success" in results

    @pytest.mark.asyncio
    async def test_max_concurrent_handlers_respected(self):
        """Semaphore limits concurrent handler executions."""
        bus = EventBus(max_concurrent_handlers=2)
        # High-water-mark tracking: each handler bumps the counter on entry
        # and drops it on exit; the sleep keeps handlers overlapped long
        # enough for concurrency to be observable.
        concurrent_count = 0
        max_concurrent = 0

        async def slow_handler(event: AgentEvent) -> None:
            nonlocal concurrent_count, max_concurrent
            concurrent_count += 1
            max_concurrent = max(max_concurrent, concurrent_count)
            await asyncio.sleep(0.1)
            concurrent_count -= 1

        # Subscribe 5 handlers
        for _ in range(5):
            bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=slow_handler)

        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="test"))

        # Max concurrent should be limited to 2
        assert max_concurrent <= 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# History and query tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestHistoryAndQueries:
    """Tests for get_history() and get_stats()."""

    @pytest.mark.asyncio
    async def test_history_returns_events_most_recent_first(self):
        """get_history() returns events in reverse chronological order."""
        bus = EventBus()

        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s1"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s2"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_FAILED, stream_id="s3"))

        history = bus.get_history()
        assert history[0].stream_id == "s3"  # Most recent
        assert history[1].stream_id == "s2"
        assert history[2].stream_id == "s1"  # Oldest

    @pytest.mark.asyncio
    async def test_history_filter_by_event_type(self):
        """get_history() can filter by event type."""
        bus = EventBus()

        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))

        history = bus.get_history(event_type=EventType.EXECUTION_STARTED)
        assert len(history) == 2
        assert all(e.type == EventType.EXECUTION_STARTED for e in history)

    @pytest.mark.asyncio
    async def test_history_filter_by_stream_id(self):
        """get_history() can filter by stream_id."""
        bus = EventBus()

        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_a"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_b"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="stream_a"))

        history = bus.get_history(stream_id="stream_a")
        assert len(history) == 2
        assert all(e.stream_id == "stream_a" for e in history)

    @pytest.mark.asyncio
    async def test_history_limit(self):
        """get_history() respects limit parameter."""
        bus = EventBus()

        for i in range(10):
            await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id=f"s{i}"))

        history = bus.get_history(limit=3)
        assert len(history) == 3

    @pytest.mark.asyncio
    async def test_max_history_enforced(self):
        """EventBus enforces max_history limit."""
        bus = EventBus(max_history=5)

        for i in range(10):
            await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id=f"s{i}"))

        # Oldest entries are evicted once max_history is exceeded.
        assert len(bus._event_history) == 5
        # Should have the 5 most recent (internal storage is oldest-first,
        # unlike get_history() which is most-recent-first).
        assert bus._event_history[-1].stream_id == "s9"
        assert bus._event_history[0].stream_id == "s5"

    @pytest.mark.asyncio
    async def test_get_stats(self):
        """get_stats() returns accurate statistics."""
        bus = EventBus()

        async def handler(event: AgentEvent) -> None:
            pass

        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)
        bus.subscribe(event_types=[EventType.EXECUTION_COMPLETED], handler=handler)

        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_STARTED, stream_id="s"))
        await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s"))

        stats = bus.get_stats()
        assert stats["total_events"] == 3
        assert stats["subscriptions"] == 2
        assert stats["events_by_type"]["execution_started"] == 2
        assert stats["events_by_type"]["execution_completed"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wait operations tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWaitOperations:
    """Tests for wait_for() async waiting.

    Fix: background publisher tasks created with asyncio.create_task() were
    fire-and-forget. The event loop holds only a weak reference to tasks, so
    an unreferenced task may be garbage-collected before it runs, and any
    exception raised inside the publisher coroutine is silently discarded.
    Each test now keeps a reference and awaits the task so failures surface.
    """

    @pytest.mark.asyncio
    async def test_wait_for_receives_event(self):
        """wait_for() returns when matching event is published."""
        bus = EventBus()

        async def publish_later():
            await asyncio.sleep(0.05)
            await bus.publish(
                AgentEvent(
                    type=EventType.EXECUTION_COMPLETED,
                    stream_id="test",
                    execution_id="exec_1",
                )
            )

        # Keep a strong reference so the task cannot be GC'd mid-flight.
        publisher = asyncio.create_task(publish_later())

        event = await bus.wait_for(
            event_type=EventType.EXECUTION_COMPLETED,
            timeout=1.0,
        )
        # Awaiting re-raises any exception from publish_later().
        await publisher

        assert event is not None
        assert event.type == EventType.EXECUTION_COMPLETED

    @pytest.mark.asyncio
    async def test_wait_for_timeout_returns_none(self):
        """wait_for() returns None on timeout."""
        bus = EventBus()

        # Nothing is published, so the short timeout must elapse.
        event = await bus.wait_for(
            event_type=EventType.EXECUTION_COMPLETED,
            timeout=0.05,
        )

        assert event is None

    @pytest.mark.asyncio
    async def test_wait_for_with_filters(self):
        """wait_for() respects filters."""
        bus = EventBus()

        async def publish_events():
            await asyncio.sleep(0.02)
            # This one shouldn't match
            await bus.publish(
                AgentEvent(
                    type=EventType.EXECUTION_COMPLETED,
                    stream_id="wrong_stream",
                )
            )
            await asyncio.sleep(0.02)
            # This one should match
            await bus.publish(
                AgentEvent(
                    type=EventType.EXECUTION_COMPLETED,
                    stream_id="correct_stream",
                )
            )

        publisher = asyncio.create_task(publish_events())

        event = await bus.wait_for(
            event_type=EventType.EXECUTION_COMPLETED,
            stream_id="correct_stream",
            timeout=1.0,
        )
        await publisher

        assert event is not None
        assert event.stream_id == "correct_stream"

    @pytest.mark.asyncio
    async def test_wait_for_cleans_up_subscription(self):
        """wait_for() removes its subscription after completion."""
        bus = EventBus()

        initial_count = len(bus._subscriptions)

        async def publish_later():
            await asyncio.sleep(0.02)
            await bus.publish(AgentEvent(type=EventType.EXECUTION_COMPLETED, stream_id="s"))

        publisher = asyncio.create_task(publish_later())

        await bus.wait_for(event_type=EventType.EXECUTION_COMPLETED, timeout=1.0)
        await publisher

        # The temporary subscription created by wait_for() must be gone.
        assert len(bus._subscriptions) == initial_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience publisher tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConveniencePublishers:
    """Tests for emit_* convenience methods.

    Each test follows the same pattern: subscribe a capturing handler for
    the expected event type, call the emit_* helper, then assert the single
    delivered event carries the right type and data payload.
    """

    @pytest.mark.asyncio
    async def test_emit_execution_started(self):
        """emit_execution_started publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.EXECUTION_STARTED], handler=handler)

        await bus.emit_execution_started(
            stream_id="test_stream",
            execution_id="exec_1",
            input_data={"key": "value"},
            correlation_id="corr_1",
        )

        assert len(received) == 1
        assert received[0].type == EventType.EXECUTION_STARTED
        assert received[0].stream_id == "test_stream"
        assert received[0].execution_id == "exec_1"
        # input_data is wrapped under the "input" key in the event payload.
        assert received[0].data == {"input": {"key": "value"}}
        assert received[0].correlation_id == "corr_1"

    @pytest.mark.asyncio
    async def test_emit_execution_completed(self):
        """emit_execution_completed publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.EXECUTION_COMPLETED], handler=handler)

        await bus.emit_execution_completed(
            stream_id="s",
            execution_id="e",
            output={"result": "success"},
        )

        assert len(received) == 1
        assert received[0].type == EventType.EXECUTION_COMPLETED
        # output is wrapped under the "output" key.
        assert received[0].data == {"output": {"result": "success"}}

    @pytest.mark.asyncio
    async def test_emit_execution_failed(self):
        """emit_execution_failed publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.EXECUTION_FAILED], handler=handler)

        await bus.emit_execution_failed(
            stream_id="s",
            execution_id="e",
            error="Something went wrong",
        )

        assert len(received) == 1
        assert received[0].type == EventType.EXECUTION_FAILED
        assert received[0].data == {"error": "Something went wrong"}

    @pytest.mark.asyncio
    async def test_emit_tool_call_started(self):
        """emit_tool_call_started publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_STARTED], handler=handler)

        await bus.emit_tool_call_started(
            stream_id="s",
            node_id="n",
            tool_use_id="tool_1",
            tool_name="web_search",
            tool_input={"query": "test"},
        )

        assert len(received) == 1
        assert received[0].data["tool_name"] == "web_search"
        assert received[0].data["tool_input"] == {"query": "test"}

    @pytest.mark.asyncio
    async def test_emit_tool_call_completed(self):
        """emit_tool_call_completed publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.TOOL_CALL_COMPLETED], handler=handler)

        await bus.emit_tool_call_completed(
            stream_id="s",
            node_id="n",
            tool_use_id="tool_1",
            tool_name="web_search",
            result="search results",
            is_error=False,
        )

        assert len(received) == 1
        assert received[0].data["result"] == "search results"
        assert received[0].data["is_error"] is False

    @pytest.mark.asyncio
    async def test_emit_webhook_received(self):
        """emit_webhook_received publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.WEBHOOK_RECEIVED], handler=handler)

        await bus.emit_webhook_received(
            source_id="webhook_1",
            path="/api/webhook",
            method="POST",
            headers={"Content-Type": "application/json"},
            payload={"data": "test"},
            query_params={"token": "abc"},
        )

        assert len(received) == 1
        assert received[0].data["path"] == "/api/webhook"
        assert received[0].data["method"] == "POST"
        assert received[0].data["payload"] == {"data": "test"}

    @pytest.mark.asyncio
    async def test_emit_tool_doom_loop(self):
        """emit_tool_doom_loop publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_TOOL_DOOM_LOOP], handler=handler)

        await bus.emit_tool_doom_loop(
            stream_id="test_stream",
            node_id="node_1",
            description="Tool called same endpoint 5 times",
            execution_id="exec_1",
        )

        assert len(received) == 1
        assert received[0].type == EventType.NODE_TOOL_DOOM_LOOP
        assert received[0].stream_id == "test_stream"
        assert received[0].node_id == "node_1"
        assert received[0].data["description"] == "Tool called same endpoint 5 times"

    @pytest.mark.asyncio
    async def test_emit_escalation_requested(self):
        """emit_escalation_requested publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.ESCALATION_REQUESTED], handler=handler)

        await bus.emit_escalation_requested(
            stream_id="test_stream",
            node_id="node_1",
            reason="Need human intervention",
            context="Complex decision required",
            execution_id="exec_1",
        )

        assert len(received) == 1
        assert received[0].type == EventType.ESCALATION_REQUESTED
        assert received[0].stream_id == "test_stream"
        assert received[0].node_id == "node_1"
        assert received[0].data["reason"] == "Need human intervention"
        assert received[0].data["context"] == "Complex decision required"

    @pytest.mark.asyncio
    async def test_emit_llm_turn_complete(self):
        """emit_llm_turn_complete publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.LLM_TURN_COMPLETE], handler=handler)

        await bus.emit_llm_turn_complete(
            stream_id="test_stream",
            node_id="node_1",
            stop_reason="end_turn",
            model="claude-sonnet-4-20250514",
            input_tokens=100,
            output_tokens=50,
            execution_id="exec_1",
            iteration=3,
        )

        assert len(received) == 1
        assert received[0].type == EventType.LLM_TURN_COMPLETE
        assert received[0].data["stop_reason"] == "end_turn"
        assert received[0].data["model"] == "claude-sonnet-4-20250514"
        assert received[0].data["input_tokens"] == 100
        assert received[0].data["output_tokens"] == 50
        assert received[0].data["iteration"] == 3

    @pytest.mark.asyncio
    async def test_emit_node_action_plan(self):
        """emit_node_action_plan publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.NODE_ACTION_PLAN], handler=handler)

        await bus.emit_node_action_plan(
            stream_id="test_stream",
            node_id="node_1",
            plan="1. Search for data\n2. Analyze results\n3. Generate report",
            execution_id="exec_1",
        )

        assert len(received) == 1
        assert received[0].type == EventType.NODE_ACTION_PLAN
        assert (
            received[0].data["plan"] == "1. Search for data\n2. Analyze results\n3. Generate report"
        )

    @pytest.mark.asyncio
    async def test_emit_subagent_report(self):
        """emit_subagent_report publishes correct event."""
        bus = EventBus()
        received = []

        async def handler(event: AgentEvent) -> None:
            received.append(event)

        bus.subscribe(event_types=[EventType.SUBAGENT_REPORT], handler=handler)

        await bus.emit_subagent_report(
            stream_id="test_stream",
            node_id="queen",
            subagent_id="worker-1",
            message="Task 50% complete",
            data={"progress": 0.5},
        )

        assert len(received) == 1
        assert received[0].type == EventType.SUBAGENT_REPORT
        assert received[0].data["subagent_id"] == "worker-1"
        assert received[0].data["message"] == "Task 50% complete"
        assert received[0].data["data"]["progress"] == 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EventType enum tests
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEventType:
|
||||
"""Tests for EventType enum."""
|
||||
|
||||
def test_all_event_types_are_strings(self):
|
||||
"""All EventType values are strings."""
|
||||
for event_type in EventType:
|
||||
assert isinstance(event_type.value, str)
|
||||
|
||||
def test_event_types_are_unique(self):
|
||||
"""All EventType values are unique."""
|
||||
values = [e.value for e in EventType]
|
||||
assert len(values) == len(set(values))
|
||||
|
||||
def test_key_event_types_exist(self):
|
||||
"""Key event types are defined."""
|
||||
assert EventType.EXECUTION_STARTED
|
||||
assert EventType.EXECUTION_COMPLETED
|
||||
assert EventType.EXECUTION_FAILED
|
||||
assert EventType.EXECUTION_PAUSED
|
||||
assert EventType.EXECUTION_RESUMED
|
||||
assert EventType.TOOL_CALL_STARTED
|
||||
assert EventType.TOOL_CALL_COMPLETED
|
||||
assert EventType.WEBHOOK_RECEIVED
|
||||
assert EventType.NODE_TOOL_DOOM_LOOP
|
||||
assert EventType.ESCALATION_REQUESTED
|
||||
assert EventType.LLM_TURN_COMPLETE
|
||||
assert EventType.NODE_ACTION_PLAN
|
||||
assert EventType.WORKER_GRAPH_LOADED
|
||||
assert EventType.CREDENTIALS_REQUIRED
|
||||
assert EventType.EXECUTION_RESURRECTED
|
||||
assert EventType.DRAFT_GRAPH_UPDATED
|
||||
assert EventType.FLOWCHART_MAP_UPDATED
|
||||
assert EventType.QUEEN_PHASE_CHANGED
|
||||
assert EventType.QUEEN_PERSONA_SELECTED
|
||||
assert EventType.SUBAGENT_REPORT
|
||||
assert EventType.TRIGGER_AVAILABLE
|
||||
assert EventType.TRIGGER_FIRED
|
||||
@@ -18,6 +18,24 @@ from framework.agents.queen.recall_selector import (
|
||||
from framework.graph.prompting import build_system_prompt_for_node_context
|
||||
from framework.tools.queen_lifecycle_tools import QueenPhaseState
|
||||
|
||||
|
||||
def _make_litellm_response(tool_calls: list[dict] | None = None, content: str = ""):
|
||||
"""Build a mock that mirrors litellm ModelResponse structure."""
|
||||
if tool_calls:
|
||||
tc_objects = []
|
||||
for tc in tool_calls:
|
||||
fn = SimpleNamespace(
|
||||
name=tc["name"],
|
||||
arguments=json.dumps(tc.get("input", {})),
|
||||
)
|
||||
tc_objects.append(SimpleNamespace(id=tc["id"], function=fn))
|
||||
message = SimpleNamespace(tool_calls=tc_objects)
|
||||
else:
|
||||
message = SimpleNamespace(tool_calls=None)
|
||||
raw = SimpleNamespace(choices=[SimpleNamespace(message=message)])
|
||||
return MagicMock(content=content, raw_response=raw)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_frontmatter
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -244,28 +262,25 @@ async def test_short_reflection(tmp_path: Path):
|
||||
llm = AsyncMock()
|
||||
llm.acomplete.side_effect = [
|
||||
# Turn 1: LLM writes a global memory file
|
||||
MagicMock(
|
||||
content="",
|
||||
raw_response={
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_1",
|
||||
"name": "write_memory_file",
|
||||
"input": {
|
||||
"filename": "user-likes-tests.md",
|
||||
"content": (
|
||||
"---\nname: user-likes-tests\n"
|
||||
"type: preference\n"
|
||||
"description: User values thorough testing\n"
|
||||
"---\nObserved emphasis on test coverage."
|
||||
),
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
_make_litellm_response(
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tc_1",
|
||||
"name": "write_memory_file",
|
||||
"input": {
|
||||
"filename": "user-likes-tests.md",
|
||||
"content": (
|
||||
"---\nname: user-likes-tests\n"
|
||||
"type: preference\n"
|
||||
"description: User values thorough testing\n"
|
||||
"---\nObserved emphasis on test coverage."
|
||||
),
|
||||
},
|
||||
}
|
||||
]
|
||||
),
|
||||
# Turn 2: done
|
||||
MagicMock(content="Done reflecting.", raw_response={}),
|
||||
_make_litellm_response(content="Done reflecting."),
|
||||
]
|
||||
|
||||
session_dir = tmp_path / "session"
|
||||
@@ -312,39 +327,33 @@ async def test_long_reflection(tmp_path: Path):
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.acomplete.side_effect = [
|
||||
MagicMock(
|
||||
content="",
|
||||
raw_response={
|
||||
"tool_calls": [
|
||||
{"id": "tc_1", "name": "list_memory_files", "input": {}},
|
||||
]
|
||||
},
|
||||
_make_litellm_response(
|
||||
tool_calls=[
|
||||
{"id": "tc_1", "name": "list_memory_files", "input": {}},
|
||||
]
|
||||
),
|
||||
MagicMock(
|
||||
content="",
|
||||
raw_response={
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_2",
|
||||
"name": "write_memory_file",
|
||||
"input": {
|
||||
"filename": "dup-a.md",
|
||||
"content": (
|
||||
"---\nname: dup-a\ntype: profile\n"
|
||||
"description: profile A (merged)\n"
|
||||
"---\nProfile A details. Also same profile A."
|
||||
),
|
||||
},
|
||||
_make_litellm_response(
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tc_2",
|
||||
"name": "write_memory_file",
|
||||
"input": {
|
||||
"filename": "dup-a.md",
|
||||
"content": (
|
||||
"---\nname: dup-a\ntype: profile\n"
|
||||
"description: profile A (merged)\n"
|
||||
"---\nProfile A details. Also same profile A."
|
||||
),
|
||||
},
|
||||
{
|
||||
"id": "tc_3",
|
||||
"name": "delete_memory_file",
|
||||
"input": {"filename": "dup-b.md"},
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "tc_3",
|
||||
"name": "delete_memory_file",
|
||||
"input": {"filename": "dup-b.md"},
|
||||
},
|
||||
]
|
||||
),
|
||||
MagicMock(content="Housekeeping complete.", raw_response={}),
|
||||
_make_litellm_response(content="Housekeeping complete."),
|
||||
]
|
||||
|
||||
await run_long_reflection(llm, memory_dir=mem_dir)
|
||||
|
||||
@@ -137,7 +137,7 @@ export MOCK_MODE=1
|
||||
# Fernet encryption key for credential store at ~/.hive/credentials
|
||||
export HIVE_CREDENTIAL_KEY="your-fernet-key"
|
||||
|
||||
# Custom agent storage path (default: /tmp)
|
||||
# Custom agent storage path (default: ~/.hive/agents/{agent_name}/)
|
||||
export AGENT_STORAGE_PATH="/custom/storage"
|
||||
```
|
||||
|
||||
@@ -152,7 +152,7 @@ CONFIG = {
|
||||
"max_tokens": 8192, # default: DEFAULT_MAX_TOKENS from framework.graph
|
||||
"temperature": 0.7,
|
||||
"tools": ["web_search", "pdf_read"], # MCP tools to enable
|
||||
"storage_path": "/tmp/my_agent", # Runtime data location
|
||||
"storage_path": "~/.hive/agents/my_agent/", # Runtime data location (default)
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -446,7 +446,7 @@ Quickstart also supports selecting OpenRouter and Hive LLM interactively. See [c
|
||||
# Fernet encryption key for credential store at ~/.hive/credentials
|
||||
export HIVE_CREDENTIAL_KEY="your-fernet-key"
|
||||
|
||||
# Agent storage location (default: /tmp)
|
||||
# Agent storage location (default: ~/.hive/agents/{agent_name}/)
|
||||
export AGENT_STORAGE_PATH="/custom/storage"
|
||||
```
|
||||
|
||||
|
||||
@@ -139,6 +139,7 @@ from .trello import TRELLO_CREDENTIALS
|
||||
from .twilio import TWILIO_CREDENTIALS
|
||||
from .twitter import TWITTER_CREDENTIALS
|
||||
from .vercel import VERCEL_CREDENTIALS
|
||||
from .wandb import WANDB_CREDENTIALS
|
||||
from .youtube import YOUTUBE_CREDENTIALS
|
||||
from .zendesk import ZENDESK_CREDENTIALS
|
||||
from .zoho_crm import ZOHO_CRM_CREDENTIALS
|
||||
@@ -219,6 +220,7 @@ CREDENTIAL_SPECS = {
|
||||
**TWITTER_CREDENTIALS,
|
||||
**VERCEL_CREDENTIALS,
|
||||
**YOUTUBE_CREDENTIALS,
|
||||
**WANDB_CREDENTIALS,
|
||||
**ZENDESK_CREDENTIALS,
|
||||
**ZOHO_CRM_CREDENTIALS,
|
||||
**ZOOM_CREDENTIALS,
|
||||
@@ -313,6 +315,7 @@ __all__ = [
|
||||
"TWILIO_CREDENTIALS",
|
||||
"TWITTER_CREDENTIALS",
|
||||
"VERCEL_CREDENTIALS",
|
||||
"WANDB_CREDENTIALS",
|
||||
"YOUTUBE_CREDENTIALS",
|
||||
"ZENDESK_CREDENTIALS",
|
||||
"ZOHO_CRM_CREDENTIALS",
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Weights & Biases integration credentials.
|
||||
|
||||
Contains credentials for the W&B GraphQL API.
|
||||
Requires WANDB_API_KEY only — no host configuration needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import CredentialSpec
|
||||
|
||||
WANDB_CREDENTIALS = {
|
||||
"wandb_api_key": CredentialSpec(
|
||||
env_var="WANDB_API_KEY",
|
||||
tools=[
|
||||
"wandb_list_projects",
|
||||
"wandb_list_runs",
|
||||
"wandb_get_run",
|
||||
"wandb_get_run_metrics",
|
||||
"wandb_list_artifacts",
|
||||
"wandb_get_summary",
|
||||
],
|
||||
required=True,
|
||||
startup_required=False,
|
||||
help_url="https://wandb.ai/authorize",
|
||||
description="Weights & Biases API Key",
|
||||
direct_api_key_supported=True,
|
||||
api_key_instructions="""To set up W&B API access:
|
||||
1. Create a W&B account at https://wandb.ai
|
||||
2. Go to https://wandb.ai/authorize
|
||||
3. Copy your API key
|
||||
4. Set environment variable:
|
||||
export WANDB_API_KEY=your-api-key""",
|
||||
health_check_endpoint="",
|
||||
credential_id="wandb_api_key",
|
||||
credential_key="api_key",
|
||||
),
|
||||
}
|
||||
@@ -133,6 +133,7 @@ from .twilio_tool import register_tools as register_twilio
|
||||
from .twitter_tool import register_tools as register_twitter
|
||||
from .vercel_tool import register_tools as register_vercel
|
||||
from .vision_tool import register_tools as register_vision
|
||||
from .wandb_tool import register_tools as register_wandb
|
||||
|
||||
try:
|
||||
from .web_scrape_tool import register_tools as register_web_scrape
|
||||
@@ -306,6 +307,7 @@ def _register_unverified(
|
||||
register_zendesk(mcp, credentials=credentials)
|
||||
register_zoho_crm(mcp, credentials=credentials)
|
||||
register_zoom(mcp, credentials=credentials)
|
||||
register_wandb(mcp, credentials=credentials)
|
||||
register_freshdesk(mcp, credentials=credentials)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
# Weights & Biases Tool
|
||||
|
||||
Query ML experiment runs, metrics, and artifacts from Weights & Biases using the W&B GraphQL API.
|
||||
|
||||
## Tools
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `wandb_list_projects` | List all projects for a W&B entity (user or organization) |
|
||||
| `wandb_list_runs` | List runs in a project with optional filters |
|
||||
| `wandb_get_run` | Get full details of a specific run (config, state, summary) |
|
||||
| `wandb_get_run_metrics` | Get sampled metric history for a run |
|
||||
| `wandb_list_artifacts` | List output artifacts logged by a run |
|
||||
| `wandb_get_summary` | Get final summary metrics for a run |
|
||||
|
||||
## Setup
|
||||
|
||||
Requires a W&B account and API key.
|
||||
|
||||
1. Create a W&B account at [wandb.ai](https://wandb.ai)
|
||||
2. Get your API key at [wandb.ai/authorize](https://wandb.ai/authorize)
|
||||
3. Set the environment variable:
|
||||
|
||||
```bash
|
||||
export WANDB_API_KEY=your-api-key
|
||||
```
|
||||
|
||||
Or configure via the Aden credential store as `wandb_api_key`.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### List projects for an entity
|
||||
|
||||
```python
|
||||
wandb_list_projects(entity="my-team")
|
||||
```
|
||||
|
||||
### List recent runs in a project
|
||||
|
||||
```python
|
||||
wandb_list_runs(entity="my-team", project="my-project", per_page=10)
|
||||
```
|
||||
|
||||
### Filter runs by state
|
||||
|
||||
```python
|
||||
wandb_list_runs(
|
||||
entity="my-team",
|
||||
project="my-project",
|
||||
filters='{"state": "finished"}',
|
||||
)
|
||||
```
|
||||
|
||||
### Get details of a specific run
|
||||
|
||||
```python
|
||||
wandb_get_run(entity="my-team", project="my-project", run_id="abc123")
|
||||
```
|
||||
|
||||
### Get training metrics for a run
|
||||
|
||||
```python
|
||||
wandb_get_run_metrics(
|
||||
entity="my-team",
|
||||
project="my-project",
|
||||
run_id="abc123",
|
||||
metric_keys="loss,accuracy",
|
||||
)
|
||||
```
|
||||
|
||||
### Get final summary metrics
|
||||
|
||||
```python
|
||||
wandb_get_summary(entity="my-team", project="my-project", run_id="abc123")
|
||||
```
|
||||
|
||||
### List artifacts produced by a run
|
||||
|
||||
```python
|
||||
wandb_list_artifacts(entity="my-team", project="my-project", run_id="abc123")
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
All tools return error dicts on failure:
|
||||
|
||||
```python
|
||||
{"error": "Weights & Biases credentials not configured", "help": "Set WANDB_API_KEY..."}
|
||||
{"error": "Invalid Weights & Biases API key"}
|
||||
{"error": "Weights & Biases resource not found"}
|
||||
{"error": "Request timed out"}
|
||||
{"error": "filters must be a valid JSON string"}
|
||||
{"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}
|
||||
```
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Weights & Biases experiment tracking tool for Aden Tools."""
|
||||
|
||||
from .wandb_tool import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Weights & Biases ML experiment tracking tool.
|
||||
|
||||
Uses the W&B GraphQL API via httpx — no SDK dependency.
|
||||
|
||||
Authentication: Bearer token (WANDB_API_KEY)
|
||||
GraphQL endpoint: https://api.wandb.ai/graphql
|
||||
|
||||
API Reference: https://github.com/wandb/wandb/blob/main/wandb/proto/wandb_internal.proto
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
GRAPHQL_URL = "https://api.wandb.ai/graphql"
|
||||
|
||||
|
||||
def _get_creds(
|
||||
credentials: CredentialStoreAdapter | None,
|
||||
) -> tuple[str] | dict[str, Any]:
|
||||
"""Return (api_key,) or an error dict."""
|
||||
if credentials is not None:
|
||||
api_key = credentials.get("wandb_api_key")
|
||||
else:
|
||||
api_key = os.getenv("WANDB_API_KEY")
|
||||
|
||||
if not api_key:
|
||||
return {
|
||||
"error": "Weights & Biases credentials not configured",
|
||||
"help": (
|
||||
"Set WANDB_API_KEY environment variable or configure via credential store. "
|
||||
"Get your API key at https://wandb.ai/authorize"
|
||||
),
|
||||
}
|
||||
return (api_key,)
|
||||
|
||||
|
||||
def _graphql(
|
||||
api_key: str,
|
||||
query: str,
|
||||
variables: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute a GraphQL query and return the parsed response."""
|
||||
try:
|
||||
resp = httpx.post(
|
||||
GRAPHQL_URL,
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json={"query": query, "variables": variables or {}},
|
||||
timeout=30.0,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
return {"error": "Request timed out"}
|
||||
except httpx.RequestError as e:
|
||||
return {"error": f"Network error: {e}"}
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"error": "Invalid Weights & Biases API key"}
|
||||
if resp.status_code == 403:
|
||||
return {"error": "Insufficient permissions for this Weights & Biases resource"}
|
||||
if resp.status_code >= 400:
|
||||
try:
|
||||
detail = resp.json().get("errors", [{}])[0].get("message", resp.text)
|
||||
except Exception:
|
||||
detail = resp.text
|
||||
return {"error": f"Weights & Biases API error (HTTP {resp.status_code}): {detail}"}
|
||||
|
||||
payload = resp.json()
|
||||
if "errors" in payload:
|
||||
msg = payload["errors"][0].get("message", str(payload["errors"]))
|
||||
return {"error": f"Weights & Biases GraphQL error: {msg}"}
|
||||
|
||||
return payload.get("data", {})
|
||||
|
||||
|
||||
def register_tools(
|
||||
mcp: FastMCP,
|
||||
credentials: CredentialStoreAdapter | None = None,
|
||||
) -> None:
|
||||
"""Register Weights & Biases experiment tracking tools with the MCP server."""
|
||||
|
||||
@mcp.tool()
|
||||
def wandb_list_projects(entity: str) -> dict:
|
||||
"""
|
||||
List all projects for a Weights & Biases entity (user or organization).
|
||||
|
||||
Args:
|
||||
entity: The W&B entity name (username or organization).
|
||||
|
||||
Returns:
|
||||
Dict containing the list of projects for the entity.
|
||||
"""
|
||||
creds = _get_creds(credentials)
|
||||
if isinstance(creds, dict):
|
||||
return creds
|
||||
(api_key,) = creds
|
||||
|
||||
query = """
|
||||
query ListProjects($entity: String!) {
|
||||
projects(entityName: $entity) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
description
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
data = _graphql(api_key, query, {"entity": entity})
|
||||
if "error" in data:
|
||||
return data
|
||||
|
||||
edges = data.get("projects", {}).get("edges", [])
|
||||
return {
|
||||
"entity": entity,
|
||||
"projects": [
|
||||
{
|
||||
"name": e["node"]["name"],
|
||||
"description": e["node"].get("description", ""),
|
||||
"created_at": e["node"].get("createdAt", ""),
|
||||
}
|
||||
for e in edges
|
||||
],
|
||||
}
|
||||
|
||||
@mcp.tool()
|
||||
def wandb_list_runs(
|
||||
entity: str,
|
||||
project: str,
|
||||
filters: str = "",
|
||||
per_page: int = 50,
|
||||
) -> dict:
|
||||
"""
|
||||
List runs in a Weights & Biases project.
|
||||
|
||||
Args:
|
||||
entity: The W&B entity name (username or organization).
|
||||
project: The project name.
|
||||
filters: Optional JSON filter string to narrow results.
|
||||
per_page: Number of runs to return (default 50).
|
||||
|
||||
Returns:
|
||||
Dict containing the list of runs in the project.
|
||||
"""
|
||||
parsed_filters = None
|
||||
if filters:
|
||||
try:
|
||||
parsed_filters = json.loads(filters)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": "filters must be a valid JSON string"}
|
||||
|
||||
creds = _get_creds(credentials)
|
||||
if isinstance(creds, dict):
|
||||
return creds
|
||||
(api_key,) = creds
|
||||
|
||||
query = """
|
||||
query ListRuns($project: String!, $entity: String!, $perPage: Int!, $filters: JSONString) {
|
||||
project(name: $project, entityName: $entity) {
|
||||
runs(first: $perPage, filters: $filters) {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
id
|
||||
state
|
||||
createdAt
|
||||
config
|
||||
summaryMetrics
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables: dict[str, Any] = {"project": project, "entity": entity, "perPage": per_page}
|
||||
if parsed_filters is not None:
|
||||
variables["filters"] = parsed_filters
|
||||
data = _graphql(api_key, query, variables)
|
||||
if "error" in data:
|
||||
return data
|
||||
|
||||
edges = data.get("project", {}).get("runs", {}).get("edges", [])
|
||||
runs = []
|
||||
for e in edges:
|
||||
node = e["node"]
|
||||
try:
|
||||
config = json.loads(node.get("config") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
config = {}
|
||||
runs.append(
|
||||
{
|
||||
"id": node.get("name"),
|
||||
"display_name": node.get("id"),
|
||||
"state": node.get("state"),
|
||||
"created_at": node.get("createdAt"),
|
||||
"config": config,
|
||||
}
|
||||
)
|
||||
return {"entity": entity, "project": project, "runs": runs}
|
||||
|
||||
@mcp.tool()
|
||||
def wandb_get_run(entity: str, project: str, run_id: str) -> dict:
|
||||
"""
|
||||
Get details of a specific Weights & Biases run.
|
||||
|
||||
Args:
|
||||
entity: The W&B entity name (username or organization).
|
||||
project: The project name.
|
||||
run_id: The run ID.
|
||||
|
||||
Returns:
|
||||
Dict containing full run details including config and metadata.
|
||||
"""
|
||||
if not run_id:
|
||||
return {"error": "run_id is required"}
|
||||
|
||||
creds = _get_creds(credentials)
|
||||
if isinstance(creds, dict):
|
||||
return creds
|
||||
(api_key,) = creds
|
||||
|
||||
query = """
|
||||
query GetRun($project: String!, $entity: String!, $run: String!) {
|
||||
project(name: $project, entityName: $entity) {
|
||||
run(name: $run) {
|
||||
name
|
||||
id
|
||||
state
|
||||
createdAt
|
||||
config
|
||||
summaryMetrics
|
||||
tags
|
||||
notes
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
|
||||
if "error" in data:
|
||||
return data
|
||||
|
||||
node = data.get("project", {}).get("run")
|
||||
if not node:
|
||||
return {"error": "Weights & Biases resource not found"}
|
||||
|
||||
try:
|
||||
config = json.loads(node.get("config") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
config = {}
|
||||
try:
|
||||
summary = json.loads(node.get("summaryMetrics") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
summary = {}
|
||||
|
||||
return {
|
||||
"id": node.get("name"),
|
||||
"display_name": node.get("id"),
|
||||
"state": node.get("state"),
|
||||
"created_at": node.get("createdAt"),
|
||||
"config": config,
|
||||
"summary": summary,
|
||||
"tags": node.get("tags") or [],
|
||||
"notes": node.get("notes") or "",
|
||||
}
|
||||
|
||||
@mcp.tool()
|
||||
def wandb_get_run_metrics(
|
||||
entity: str,
|
||||
project: str,
|
||||
run_id: str,
|
||||
metric_keys: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Get sampled metrics history for a specific Weights & Biases run.
|
||||
|
||||
Args:
|
||||
entity: The W&B entity name (username or organization).
|
||||
project: The project name.
|
||||
run_id: The run ID.
|
||||
metric_keys: Comma-separated metric keys to sample (e.g. "loss,accuracy").
|
||||
At least one key is required.
|
||||
|
||||
Returns:
|
||||
Dict containing sampled metric history per key.
|
||||
"""
|
||||
creds = _get_creds(credentials)
|
||||
if isinstance(creds, dict):
|
||||
return creds
|
||||
(api_key,) = creds
|
||||
|
||||
if not run_id:
|
||||
return {"error": "run_id is required"}
|
||||
if not metric_keys:
|
||||
return {"error": "metric_keys is required (comma-separated, e.g. 'loss,accuracy')"}
|
||||
|
||||
keys = [k.strip() for k in metric_keys.split(",") if k.strip()]
|
||||
if not keys:
|
||||
return {"error": "metric_keys must include at least one non-empty key"}
|
||||
specs = json.dumps([{"key": k} for k in keys])
|
||||
|
||||
query = f"""
|
||||
query GetRunMetrics($project: String!, $entity: String!, $run: String!) {{
|
||||
project(name: $project, entityName: $entity) {{
|
||||
run(name: $run) {{
|
||||
sampledHistory(specs: {specs})
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
|
||||
if "error" in data:
|
||||
return data
|
||||
|
||||
node = data.get("project", {}).get("run")
|
||||
if not node:
|
||||
return {"error": "Weights & Biases resource not found"}
|
||||
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"metric_keys": keys,
|
||||
"history": node.get("sampledHistory", []),
|
||||
}
|
||||
|
||||
@mcp.tool()
|
||||
def wandb_list_artifacts(entity: str, project: str, run_id: str) -> dict:
|
||||
"""
|
||||
List artifacts logged by a specific Weights & Biases run.
|
||||
|
||||
Args:
|
||||
entity: The W&B entity name (username or organization).
|
||||
project: The project name.
|
||||
run_id: The run ID.
|
||||
|
||||
Returns:
|
||||
Dict containing the list of output artifacts for the run.
|
||||
"""
|
||||
if not run_id:
|
||||
return {"error": "run_id is required"}
|
||||
|
||||
creds = _get_creds(credentials)
|
||||
if isinstance(creds, dict):
|
||||
return creds
|
||||
(api_key,) = creds
|
||||
|
||||
query = """
|
||||
query ListArtifacts($project: String!, $entity: String!, $run: String!) {
|
||||
project(name: $project, entityName: $entity) {
|
||||
run(name: $run) {
|
||||
outputArtifacts {
|
||||
edges {
|
||||
node {
|
||||
name
|
||||
type
|
||||
description
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
|
||||
if "error" in data:
|
||||
return data
|
||||
|
||||
node = data.get("project", {}).get("run")
|
||||
if not node:
|
||||
return {"error": "Weights & Biases resource not found"}
|
||||
|
||||
edges = node.get("outputArtifacts", {}).get("edges", [])
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"artifacts": [
|
||||
{
|
||||
"name": e["node"]["name"],
|
||||
"type": e["node"]["type"],
|
||||
"description": e["node"].get("description", ""),
|
||||
"created_at": e["node"].get("createdAt", ""),
|
||||
}
|
||||
for e in edges
|
||||
],
|
||||
}
|
||||
|
||||
@mcp.tool()
|
||||
def wandb_get_summary(entity: str, project: str, run_id: str) -> dict:
|
||||
"""
|
||||
Get summary metrics for a specific Weights & Biases run.
|
||||
|
||||
Args:
|
||||
entity: The W&B entity name (username or organization).
|
||||
project: The project name.
|
||||
run_id: The run ID.
|
||||
|
||||
Returns:
|
||||
Dict containing the run's final summary metrics.
|
||||
"""
|
||||
if not run_id:
|
||||
return {"error": "run_id is required"}
|
||||
|
||||
creds = _get_creds(credentials)
|
||||
if isinstance(creds, dict):
|
||||
return creds
|
||||
(api_key,) = creds
|
||||
|
||||
query = """
|
||||
query GetSummary($project: String!, $entity: String!, $run: String!) {
|
||||
project(name: $project, entityName: $entity) {
|
||||
run(name: $run) {
|
||||
summaryMetrics
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
data = _graphql(api_key, query, {"project": project, "entity": entity, "run": run_id})
|
||||
if "error" in data:
|
||||
return data
|
||||
|
||||
node = data.get("project", {}).get("run")
|
||||
if not node:
|
||||
return {"error": "Weights & Biases resource not found"}
|
||||
|
||||
try:
|
||||
summary = json.loads(node.get("summaryMetrics") or "{}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
summary = {}
|
||||
|
||||
# Filter out internal W&B keys
|
||||
summary = {k: v for k, v in summary.items() if not k.startswith("_")}
|
||||
return {"run_id": run_id, "summary": summary}
|
||||
@@ -0,0 +1,317 @@
|
||||
"""Tests for wandb_tool - Weights & Biases integration (GraphQL/httpx)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.wandb_tool.wandb_tool import register_tools
|
||||
|
||||
ENV = {"WANDB_API_KEY": "test-key-abcdefghij"}
|
||||
_PATCH_POST = "aden_tools.tools.wandb_tool.wandb_tool.httpx.post"
|
||||
|
||||
|
||||
def _mock_resp(data: Any, status_code: int = 200) -> MagicMock:
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.json.return_value = data
|
||||
resp.text = str(data)
|
||||
return resp
|
||||
|
||||
|
||||
def _gql_ok(data: dict[str, Any]) -> MagicMock:
|
||||
"""Wrap data in the GraphQL envelope: {"data": {...}}."""
|
||||
return _mock_resp({"data": data})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fns(mcp: FastMCP) -> dict[str, Any]:
|
||||
register_tools(mcp, credentials=None)
|
||||
tools = mcp._tool_manager._tools
|
||||
return {name: tools[name].fn for name in tools}
|
||||
|
||||
|
||||
class TestWandbTool:
|
||||
# --- Credential tests ---
|
||||
|
||||
def test_missing_credentials_returns_error(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""Missing WANDB_API_KEY must return a descriptive error dict with help."""
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
result = tool_fns["wandb_list_projects"](entity="test-entity")
|
||||
assert "error" in result
|
||||
assert "credentials not configured" in result["error"]
|
||||
assert "help" in result
|
||||
|
||||
# --- wandb_list_projects ---
|
||||
|
||||
def test_wandb_list_projects_success(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_list_projects returns projects list from GraphQL."""
|
||||
gql_data = {
|
||||
"projects": {
|
||||
"edges": [
|
||||
{
|
||||
"node": {
|
||||
"name": "proj-a",
|
||||
"description": "Desc A",
|
||||
"createdAt": "2024-01-01",
|
||||
}
|
||||
},
|
||||
{
|
||||
"node": {
|
||||
"name": "proj-b",
|
||||
"description": "",
|
||||
"createdAt": "2024-02-01",
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
|
||||
):
|
||||
result = tool_fns["wandb_list_projects"](entity="test-entity")
|
||||
|
||||
assert result["entity"] == "test-entity"
|
||||
assert len(result["projects"]) == 2
|
||||
assert result["projects"][0]["name"] == "proj-a"
|
||||
|
||||
def test_wandb_list_projects_http_401(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""HTTP 401 returns an invalid key error."""
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_mock_resp({}, status_code=401)),
|
||||
):
|
||||
result = tool_fns["wandb_list_projects"](entity="e")
|
||||
assert result["error"] == "Invalid Weights & Biases API key"
|
||||
|
||||
def test_wandb_list_projects_graphql_error(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""GraphQL error block is surfaced as an error dict."""
|
||||
gql_err = {"errors": [{"message": "entity not found"}]}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_mock_resp(gql_err)),
|
||||
):
|
||||
result = tool_fns["wandb_list_projects"](entity="e")
|
||||
assert "error" in result
|
||||
assert "entity not found" in result["error"]
|
||||
|
||||
# --- wandb_list_runs ---
|
||||
|
||||
def test_wandb_list_runs_success(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_list_runs returns runs list."""
|
||||
gql_data = {
|
||||
"project": {
|
||||
"runs": {
|
||||
"edges": [
|
||||
{
|
||||
"node": {
|
||||
"name": "w854ckuu",
|
||||
"id": "ferengi-directive-1",
|
||||
"state": "finished",
|
||||
"createdAt": "2024-01-01",
|
||||
"config": "{}",
|
||||
"summaryMetrics": "{}",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_gql_ok(gql_data)) as mock_post,
|
||||
):
|
||||
result = tool_fns["wandb_list_runs"](
|
||||
entity="testentity",
|
||||
project="testproject",
|
||||
filters='{"key": "value"}',
|
||||
per_page=50,
|
||||
)
|
||||
|
||||
assert result["project"] == "testproject"
|
||||
assert len(result["runs"]) == 1
|
||||
assert result["runs"][0]["id"] == "w854ckuu"
|
||||
# Verify filters and per_page were forwarded in GraphQL variables
|
||||
call_json = mock_post.call_args[1]["json"]
|
||||
assert call_json["variables"]["perPage"] == 50
|
||||
assert call_json["variables"]["filters"] == {"key": "value"}
|
||||
|
||||
def test_wandb_list_runs_invalid_filters_json(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_list_runs returns error for invalid JSON filters before any HTTP call."""
|
||||
with patch.dict("os.environ", ENV):
|
||||
result = tool_fns["wandb_list_runs"](entity="e", project="p", filters="not-json")
|
||||
assert "error" in result
|
||||
assert "valid JSON" in result["error"]
|
||||
|
||||
# --- wandb_get_run ---
|
||||
|
||||
def test_wandb_get_run_success(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_run returns run details."""
|
||||
gql_data = {
|
||||
"project": {
|
||||
"run": {
|
||||
"name": "run-123",
|
||||
"id": "my-run",
|
||||
"state": "finished",
|
||||
"createdAt": "2024-01-01",
|
||||
"config": '{"lr": 0.001}',
|
||||
"summaryMetrics": '{"accuracy": 0.9}',
|
||||
"tags": ["v1"],
|
||||
"notes": "test",
|
||||
}
|
||||
}
|
||||
}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
|
||||
):
|
||||
result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="run-123")
|
||||
|
||||
assert result["id"] == "run-123"
|
||||
assert result["config"] == {"lr": 0.001}
|
||||
assert result["summary"] == {"accuracy": 0.9}
|
||||
|
||||
def test_wandb_get_run_missing_id(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_run with empty run_id returns error before HTTP call."""
|
||||
result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="")
|
||||
assert "error" in result
|
||||
assert result["error"] == "run_id is required"
|
||||
|
||||
def test_wandb_get_run_not_found(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_run returns not-found error when run is null."""
|
||||
gql_data = {"project": {"run": None}}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
|
||||
):
|
||||
result = tool_fns["wandb_get_run"](entity="e", project="p", run_id="nope")
|
||||
assert "error" in result
|
||||
assert "not found" in result["error"]
|
||||
|
||||
# --- wandb_get_run_metrics ---
|
||||
|
||||
def test_wandb_get_run_metrics_success(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_run_metrics returns sampled history."""
|
||||
gql_data = {
|
||||
"project": {
|
||||
"run": {
|
||||
"sampledHistory": [{"loss": 0.5}, {"loss": 0.3}],
|
||||
}
|
||||
}
|
||||
}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
|
||||
):
|
||||
result = tool_fns["wandb_get_run_metrics"](
|
||||
entity="e", project="p", run_id="r1", metric_keys="loss"
|
||||
)
|
||||
|
||||
assert result["run_id"] == "r1"
|
||||
assert result["metric_keys"] == ["loss"]
|
||||
assert result["history"] == [{"loss": 0.5}, {"loss": 0.3}]
|
||||
|
||||
def test_wandb_get_run_metrics_missing_id(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_run_metrics with empty run_id returns error."""
|
||||
with patch.dict("os.environ", ENV):
|
||||
result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="")
|
||||
assert "error" in result
|
||||
assert result["error"] == "run_id is required"
|
||||
|
||||
def test_wandb_get_run_metrics_missing_keys(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_run_metrics with no metric_keys returns error."""
|
||||
with patch.dict("os.environ", ENV):
|
||||
result = tool_fns["wandb_get_run_metrics"](entity="e", project="p", run_id="r1")
|
||||
assert "error" in result
|
||||
assert "metric_keys is required" in result["error"]
|
||||
|
||||
# --- wandb_list_artifacts ---
|
||||
|
||||
def test_wandb_list_artifacts_success(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_list_artifacts returns artifact list."""
|
||||
gql_data = {
|
||||
"project": {
|
||||
"run": {
|
||||
"outputArtifacts": {
|
||||
"edges": [
|
||||
{
|
||||
"node": {
|
||||
"name": "model:v0",
|
||||
"type": "model",
|
||||
"description": "",
|
||||
"createdAt": "2024-01-01",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
|
||||
):
|
||||
result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="r1")
|
||||
|
||||
assert result["run_id"] == "r1"
|
||||
assert result["artifacts"][0]["name"] == "model:v0"
|
||||
|
||||
def test_wandb_list_artifacts_missing_id(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_list_artifacts with empty run_id returns error."""
|
||||
result = tool_fns["wandb_list_artifacts"](entity="e", project="p", run_id="")
|
||||
assert "error" in result
|
||||
assert result["error"] == "run_id is required"
|
||||
|
||||
# --- wandb_get_summary ---
|
||||
|
||||
def test_wandb_get_summary_success(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_summary returns summary filtering out _-prefixed keys."""
|
||||
gql_data = {
|
||||
"project": {"run": {"summaryMetrics": '{"accuracy": 0.9, "loss": 0.1, "_step": 5}'}}
|
||||
}
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(_PATCH_POST, return_value=_gql_ok(gql_data)),
|
||||
):
|
||||
result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="r1")
|
||||
|
||||
assert result["run_id"] == "r1"
|
||||
assert result["summary"]["accuracy"] == 0.9
|
||||
assert "_step" not in result["summary"]
|
||||
|
||||
def test_wandb_get_summary_missing_id(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""wandb_get_summary with empty run_id returns error."""
|
||||
result = tool_fns["wandb_get_summary"](entity="e", project="p", run_id="")
|
||||
assert "error" in result
|
||||
assert result["error"] == "run_id is required"
|
||||
|
||||
# --- Network/timeout errors ---
|
||||
|
||||
def test_timeout_returns_error(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""httpx.TimeoutException is caught and returns a timeout message."""
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(
|
||||
"aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
|
||||
side_effect=httpx.TimeoutException("timeout"),
|
||||
),
|
||||
):
|
||||
result = tool_fns["wandb_list_projects"](entity="e")
|
||||
assert result["error"] == "Request timed out"
|
||||
|
||||
def test_network_error_returns_error(self, tool_fns: dict[str, Any]) -> None:
|
||||
"""httpx.RequestError is caught and returns a network error message."""
|
||||
with (
|
||||
patch.dict("os.environ", ENV),
|
||||
patch(
|
||||
"aden_tools.tools.wandb_tool.wandb_tool.httpx.post",
|
||||
side_effect=httpx.RequestError("Connection refused"),
|
||||
),
|
||||
):
|
||||
result = tool_fns["wandb_list_projects"](entity="e")
|
||||
assert "Network error" in result["error"]
|
||||
Reference in New Issue
Block a user