Compare commits
175 Commits
| SHA1 |
|---|
| adf1a10318 |
| a3916a6932 |
| cbd2c86bbf |
| f921846879 |
| a370403b16 |
| ad6d504ea4 |
| 65962ddf58 |
| bba44430c4 |
| 69c71d77fb |
| 7b98a6613a |
| 26481e27a6 |
| bb227b3d73 |
| 8a0cf5e0ae |
| 69218d5699 |
| 7d1433af21 |
| 0bfbf1e9c5 |
| 1ca4f5b22b |
| 0984e4c1e8 |
| 4cbf5a7434 |
| b33178c5be |
| dc6a336c60 |
| b855336448 |
| de021977fd |
| cd2b3fcd16 |
| b64024ede5 |
| a280d23113 |
| 41785abdba |
| de494c7e55 |
| 5fa0903ea8 |
| 7bd99fe074 |
| c838e1ca6d |
| f475923353 |
| 43f43c92e3 |
| 5463134322 |
| 3fbb392103 |
| a162da17e1 |
| b565134d57 |
| 3aafc89912 |
| 93449f92fe |
| d766e68d42 |
| 1d8b1f9774 |
| 5ea9abae83 |
| 15957499c5 |
| 0b50d9e874 |
| a1e54922bd |
| 63c0ca34ea |
| 135477e516 |
| 8cac49cd91 |
| 28dce63682 |
| 313ac952e0 |
| 0633d5130b |
| 995e487b49 |
| 64b58b57e0 |
| c6465908df |
| ca96bcc09f |
| 65ee628fae |
| 02043614e5 |
| 212b9bf9d4 |
| 6070c30a88 |
| 8a653e51bc |
| d562670425 |
| 677bee6fe5 |
| de27bfe76f |
| 1c1dcb9c33 |
| 4ba950f155 |
| 9c3a11d7bb |
| b7d357aea2 |
| b2fed68346 |
| 0e996928be |
| 6ff4ec3643 |
| a0eda3e492 |
| 099f9514ef |
| b2096e4a55 |
| 1bf2164745 |
| 48205bbde7 |
| 296aab6ecb |
| 14182c45fc |
| 2fa8f4283c |
| ad3cec2361 |
| eddb628298 |
| f63b226d8d |
| cc5bd61d86 |
| 8bd14fb16f |
| 30b5472e33 |
| bc836db0f9 |
| bd3b0fb8eb |
| 7f28474967 |
| 09460b28bc |
| 5d8ba1e49c |
| ccb394675b |
| 931487a7d4 |
| 3654c57f66 |
| fb28280ced |
| 6215441b58 |
| 52f16d5bb6 |
| e5b6c8581a |
| 5dcca99913 |
| 890b906f15 |
| 6a8286d4cf |
| 680024f790 |
| 6f7bfb92a8 |
| 335a9603e8 |
| 5e8a6202e7 |
| 55a4cdefd7 |
| 2b63135afb |
| 49d8c3572d |
| 4b40962186 |
| 779b376c6e |
| 4e2a9a247a |
| b1f3d6b155 |
| ea28a9d3c3 |
| 69a03e463f |
| e7da62e61c |
| 7176745e1c |
| cce0e26f5c |
| 641af16dfc |
| a335c427ef |
| 9ea6c959ae |
| 20efd523c9 |
| 8fc7fff496 |
| edf51e6996 |
| 6b867883ce |
| 35a05f4120 |
| e0e78a97ce |
| e4e476f463 |
| c4c8917ecb |
| 1524d2ef00 |
| 5032834034 |
| 0b83f6ea99 |
| 415201f467 |
| 73005a8498 |
| 4edb960fbd |
| 42d11ead01 |
| 5e18f85b10 |
| 85b25bf006 |
| c1ba108489 |
| 214098aaae |
| 241a0b7adc |
| 9a7b41a4be |
| 746f026654 |
| 8294cd3dd9 |
| 337fb6d922 |
| bda6b18e8a |
| d256ff929f |
| b00203702e |
| 754e33a1ae |
| 51154a3070 |
| b11b43bbe1 |
| 86f4645d1c |
| 2d05e96cd5 |
| 9c44d3b793 |
| 9b89ac694e |
| 630d8208cf |
| 9b342dc593 |
| ad879de6ff |
| 795266aab4 |
| 4e4ef121f9 |
| ddb9126955 |
| bac6d6dd68 |
| 3451570541 |
| e5e939f344 |
| 0d51d25482 |
| a0a5b10df0 |
| 04bac93c14 |
| 047f4a1a0c |
| 7994b90dfa |
| 04b6a80370 |
| fc0c3e169f |
| 4760f95bda |
| a04a8a866d |
| 8c9baa62b0 |
| 262eaa6d84 |
| fc1a48f3bc |
| 060f320cd1 |
| bff32bcaa3 |
@@ -195,7 +195,7 @@ class DeepResearchAgent:
            max_tokens=self.config.max_tokens,
            loop_config={
                "max_iterations": 100,
-               "max_tool_calls_per_turn": 20,
+               "max_tool_calls_per_turn": 30,
                "max_history_tokens": 32000,
            },
            conversation_mode="continuous",
@@ -71,6 +71,12 @@ Important:
- Track which URL each finding comes from (you'll need citations later)
- Call set_output for each key in a SEPARATE turn (not in the same turn as other tool calls)

+Context management:
+- Your tool results are automatically saved to files. After compaction, the file \
+references remain in the conversation — use load_data() to recover any content you need.
+- Use append_data('research_notes.md', ...) to maintain a running log of key findings \
+as you go. This survives compaction and helps the report node produce a detailed report.

When done, use set_output (one key at a time, separate turns):
- set_output("findings", "Structured summary: key findings with source URLs for each claim. \
Include themes, contradictions, and confidence levels.")
@@ -161,6 +167,9 @@ Requirements:
- Every factual claim must cite its source with [n] notation
- Be objective — present multiple viewpoints where sources disagree
- Answer the original research questions from the brief
+- If findings appear incomplete or summarized, call list_data_files() and load_data() \
+to access the detailed source material from the research phase. The research node's \
+tool results and research_notes.md contain the full data.

Save the HTML:
save_data(filename="report.html", data="<html>...</html>")
@@ -70,6 +70,7 @@ exports/*
.agent-builder-sessions/*

.claude/settings.local.json
.claude/skills/ship-it/

.venv
@@ -14,7 +14,7 @@
</p>

<p align="center">
-  <a href="https://github.com/adenhq/hive/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="Apache 2.0 License" /></a>
+  <a href="https://github.com/aden-hive/hive/blob/main/LICENSE"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="Apache 2.0 License" /></a>
  <a href="https://www.ycombinator.com/companies/aden"><img src="https://img.shields.io/badge/Y%20Combinator-Aden-orange" alt="Y Combinator" /></a>
  <a href="https://discord.com/invite/MXE49hrKDk"><img src="https://img.shields.io/discord/1172610340073242735?logo=discord&labelColor=%235462eb&logoColor=%23f5f5f5&color=%235462eb" alt="Discord" /></a>
  <a href="https://x.com/aden_hq"><img src="https://img.shields.io/twitter/follow/teamaden?logo=X&color=%23f5f5f5" alt="Twitter Follow" /></a>
@@ -37,11 +37,11 @@

## Overview

-Build autonomous, reliable, self-improving AI agents without hardcoding workflows. Define your goal through conversation with a coding agent, and the framework generates a node graph with dynamically created connection code. When things break, the framework captures failure data, evolves the agent through the coding agent, and redeploys. Built-in human-in-the-loop nodes, credential management, and real-time monitoring give you control without sacrificing adaptability.
+Build autonomous, reliable, self-improving AI agents without hardcoding workflows. Define your goal through conversation with the Hive coding agent (Queen), and the framework generates a node graph with dynamically created connection code. When things break, the framework captures failure data, evolves the agent through the coding agent, and redeploys. Built-in human-in-the-loop nodes, credential management, and real-time monitoring give you control without sacrificing adaptability.

Visit [adenhq.com](https://adenhq.com) for complete documentation, examples, and guides.

-https://github.com/user-attachments/assets/846c0cc7-ffd6-47fa-b4b7-495494857a55
+[Watch the demo video](https://www.youtube.com/watch?v=XDOG9fOaLjU)

## Who Is Hive For?
@@ -50,7 +50,7 @@ Hive is designed for developers and teams who want to build **production-grade A
Hive is a good fit if you:

- Want AI agents that **execute real business processes**, not demos
- Prefer **goal-driven development** over hardcoded workflows
+- Need **fast, high-volume agent execution** rather than open-ended workflows
- Need **self-healing and adaptive agents** that improve over time
- Require **human-in-the-loop control**, observability, and cost limits
- Plan to run agents in **production environments**
@@ -71,7 +71,7 @@ Use Hive when you need:

- **[Documentation](https://docs.adenhq.com/)** - Complete guides and API reference
- **[Self-Hosting Guide](https://docs.adenhq.com/getting-started/quickstart)** - Deploy Hive on your infrastructure
-- **[Changelog](https://github.com/adenhq/hive/releases)** - Latest updates and releases
+- **[Changelog](https://github.com/aden-hive/hive/releases)** - Latest updates and releases
- **[Roadmap](docs/roadmap.md)** - Upcoming features and plans
- **[Report Issues](https://github.com/adenhq/hive/issues)** - Bug reports and feature requests
- **[Contributing](CONTRIBUTING.md)** - How to contribute and submit PRs
@@ -81,7 +81,7 @@ Use Hive when you need:

### Prerequisites

- Python 3.11+ for agent development
- Claude Code, Codex CLI, or Cursor for using agent skills
- An LLM provider to power the agents

> **Note for Windows Users:** It is strongly recommended to use **WSL (Windows Subsystem for Linux)** or **Git Bash** to run this framework. Some core automation scripts may not execute correctly in standard Command Prompt or PowerShell.
@@ -94,9 +94,10 @@ Use Hive when you need:

```bash
# Clone the repository
-git clone https://github.com/adenhq/hive.git
+git clone https://github.com/aden-hive/hive.git
cd hive

# Run quickstart setup
./quickstart.sh
```
@@ -109,77 +110,41 @@ This sets up:

- **LLM provider** - Interactive default model configuration
- All required Python dependencies with `uv`
+- Finally, it opens the Hive interface in your browser

<img width="2500" height="1214" alt="home-screen" src="https://github.com/user-attachments/assets/134d897f-5e75-4874-b00b-e0505f6b45c4" />

### Build Your First Agent

-```bash
-# Build an agent using Claude Code
-claude> /hive
+Type the agent you want to build in the home input box.

-# Test your agent
-claude> /hive-debugger
+<img width="2500" height="1214" alt="Image" src="https://github.com/user-attachments/assets/1ce19141-a78b-46f5-8d64-dbf987e048f4" />

-# (at separate terminal) Launch the interactive dashboard
-hive tui
+### Use Template Agents

-# Or run directly
-hive run exports/your_agent_name --input '{"key": "value"}'
-```
+Click "Try a sample agent" and check the templates. You can run a template directly, or build your own version on top of an existing template.
-## Coding Agent Support
+### Run Agents

-### Codex CLI
+Now you can run an agent by selecting it (either an existing agent or an example agent). Click the Run button at the top left, or ask the Queen agent to run it for you.

-Hive includes native support for [OpenAI Codex CLI](https://github.com/openai/codex) (v0.101.0+).

-1. **Config:** `.codex/config.toml` with `agent-builder` MCP server (tracked in git)
-2. **Skills:** `.agents/skills/` symlinks to Hive skills (tracked in git)
-3. **Launch:** Run `codex` in the repo root, then type `use hive`

-Example:

-```
-codex> use hive
-```
-### Opencode

-Hive includes native support for [Opencode](https://github.com/opencode-ai/opencode).

-1. **Setup:** Run the quickstart script
-2. **Launch:** Open Opencode in the project root.
-3. **Activate:** Type `/hive` in the chat to switch to the Hive Agent.
-4. **Verify:** Ask the agent _"List your tools"_ to confirm the connection.

-The agent has access to all Hive skills and can scaffold agents, add tools, and debug workflows directly from the chat.

-**[📖 Complete Setup Guide](docs/environment-setup.md)** - Detailed instructions for agent development

-### Antigravity IDE Support

-Skills and MCP servers are also available in [Antigravity IDE](https://antigravity.google/) (Google's AI-powered IDE). **Easiest:** open a terminal in the hive repo folder and run (use `./` — the script is inside the repo):

-```bash
-./scripts/setup-antigravity-mcp.sh
-```

-**Important:** Always restart/refresh Antigravity IDE after running the setup script—MCP servers only load on startup. After restart, **agent-builder** and **tools** MCP servers should connect. Skills are under `.agent/skills/` (symlinks to `.claude/skills/`). See [docs/antigravity-setup.md](docs/antigravity-setup.md) for manual setup and troubleshooting.
+<img width="2500" height="1214" alt="Image" src="https://github.com/user-attachments/assets/71c38206-2ad5-49aa-bde8-6698d0bc55f5" />
## Features

-- **[Goal-Driven Development](docs/key_concepts/goals_outcome.md)** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
+- **Browser-Use** - Control the browser on your computer to achieve hard tasks
+- **Parallel Execution** - Execute the generated graph in parallel, so multiple agents can complete jobs for you
+- **[Goal-Driven Generation](docs/key_concepts/goals_outcome.md)** - Define objectives in natural language; the coding agent generates the agent graph and connection code to achieve them
- **[Adaptiveness](docs/key_concepts/evolution.md)** - Framework captures failures, calibrates according to the objectives, and evolves the agent graph
- **[Dynamic Node Connections](docs/key_concepts/graph.md)** - No predefined edges; connection code is generated by any capable LLM based on your goals
- **SDK-Wrapped Nodes** - Every node gets shared memory, local RLM memory, monitoring, tools, and LLM access out of the box
- **[Human-in-the-Loop](docs/key_concepts/graph.md#human-in-the-loop)** - Intervention nodes that pause execution for human input with configurable timeouts and escalation
- **Real-time Observability** - WebSocket streaming for live monitoring of agent execution, decisions, and node-to-node communication
- **Interactive TUI Dashboard** - Terminal-based dashboard with live graph view, event log, and chat interface for agent interaction
- **Cost & Budget Control** - Set spending limits, throttles, and automatic model degradation policies
- **Production-Ready** - Self-hostable, built for scale and reliability

## Integration

-<a href="https://github.com/adenhq/hive/tree/main/tools/src/aden_tools/tools"><img width="100%" alt="Integration" src="https://github.com/user-attachments/assets/a1573f93-cf02-4bb8-b3d5-b305b05b1e51" /></a>
+<a href="https://github.com/aden-hive/hive/tree/main/tools/src/aden_tools/tools"><img width="100%" alt="Integration" src="https://github.com/user-attachments/assets/a1573f93-cf02-4bb8-b3d5-b305b05b1e51" /></a>

Hive is built to be model-agnostic and system-agnostic.

- **LLM flexibility** - Hive Framework is designed to support various types of LLMs, including hosted and local models through LiteLLM-compatible providers.
@@ -240,35 +205,10 @@ flowchart LR

4. **Control Plane Monitors** → Real-time metrics, budget enforcement, policy management
5. **[Adaptiveness](docs/key_concepts/evolution.md)** → On failure, the system evolves the graph and redeploys automatically

## Run Agents

The `hive` CLI is the primary interface for running agents.

```bash
# Browse and run agents interactively (Recommended)
hive tui

# Run a specific agent directly
hive run exports/my_agent --input '{"task": "Your input here"}'

# Run a specific agent with the TUI dashboard
hive run exports/my_agent --tui

# Interactive REPL
hive shell
```

The TUI scans both `exports/` and `examples/templates/` for available agents.

> **Using Python directly (alternative):** You can also run agents with `PYTHONPATH=exports uv run python -m agent_name run --input '{...}'`
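For example, with a hypothetical agent package named `deep_research` living under `exports/`, that invocation would look like:

```bash
# Illustrative only: the agent package name and input keys are placeholders
PYTHONPATH=exports uv run python -m deep_research run --input '{"topic": "solid-state batteries"}'
```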

See [environment-setup.md](docs/environment-setup.md) for complete setup instructions.
## Documentation

- **[Developer Guide](docs/developer-guide.md)** - Comprehensive guide for developers
- [Getting Started](docs/getting-started.md) - Quick setup instructions
- [TUI Guide](docs/tui-selection-guide.md) - Interactive dashboard usage
- [Configuration Guide](docs/configuration.md) - All configuration options
- [Architecture Overview](docs/architecture/README.md) - System design and structure
@@ -398,8 +338,7 @@ flowchart TB
```

## Contributing

-We welcome contributions from the community! We’re especially looking for help building tools, integrations, and example agents for the framework ([check #2805](https://github.com/adenhq/hive/issues/2805)). If you’re interested in extending its functionality, this is the perfect place to start. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
+We welcome contributions from the community! We’re especially looking for help building tools, integrations, and example agents for the framework ([check #2805](https://github.com/aden-hive/hive/issues/2805)). If you’re interested in extending its functionality, this is the perfect place to start. Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.

**Important:** Please get assigned to an issue before submitting a PR. Comment on an issue to claim it, and a maintainer will assign you. Issues with reproducible steps and proposals are prioritized. This helps prevent duplicate work.
@@ -436,7 +375,7 @@ This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENS

**Q: What LLM providers does Hive support?**

-Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name.
+Hive supports 100+ LLM providers through LiteLLM integration, including OpenAI (GPT-4, GPT-4o), Anthropic (Claude models), Google Gemini, DeepSeek, Mistral, Groq, and many more. Simply set the appropriate API key environment variable and specify the model name. We recommend Claude, GLM, and Gemini, as they currently perform best.

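As a minimal sketch of what that looks like in code (the import path and model id below are assumptions for illustration; the `LiteLLMProvider(model=..., api_key=..., api_base=...)` signature itself appears elsewhere in this diff):

```python
import os

from framework.llm import LiteLLMProvider  # import path assumed

# Any LiteLLM-style "provider/model" id works once the matching key is set.
llm = LiteLLMProvider(
    model="anthropic/claude-sonnet-4-5",      # illustrative model id
    api_key=os.environ["ANTHROPIC_API_KEY"],  # provider key, LiteLLM convention
    api_base=None,  # only needed for OpenAI-compatible endpoints
)
```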
**Q: Can I use Hive with local AI models like Ollama?**

@@ -478,14 +417,6 @@ Visit [docs.adenhq.com](https://docs.adenhq.com/) for complete guides, API refer

Contributions are welcome! Fork the repository, create your feature branch, implement your changes, and submit a pull request. See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines.

-**Q: When will my team start seeing results from Aden's adaptive agents?**

-Aden's adaptation loop begins working from the first execution. When an agent fails, the framework captures the failure data, helping developers evolve the agent graph through the coding agent. How quickly this translates to measurable results depends on the complexity of your use case, the quality of your goal definitions, and the volume of executions generating feedback.

**Q: How does Hive compare to other agent frameworks?**

Hive focuses on generating agents that run real business processes, rather than generic agents. This vision emphasizes outcome-driven design, adaptability, and an easy-to-use set of tools and integrations.

---

<p align="center">
@@ -64,7 +64,7 @@ To use the agent builder with Claude Desktop or other MCP clients, add this to y
    "agent-builder": {
      "command": "python",
      "args": ["-m", "framework.mcp.agent_builder_server"],
-     "cwd": "/path/to/goal-agent"
+     "cwd": "/path/to/hive/core"
    }
  }
}
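For context, a complete client config would look roughly like the sketch below; the top-level `mcpServers` key follows the common Claude Desktop convention and is an assumption here, since the snippet above shows only the inner entry:

```json
{
  "mcpServers": {
    "agent-builder": {
      "command": "python",
      "args": ["-m", "framework.mcp.agent_builder_server"],
      "cwd": "/path/to/hive/core"
    }
  }
}
```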
@@ -85,14 +85,14 @@ The MCP server provides tools for:
Run an LLM-powered calculator:

```bash
-# Single calculation
-uv run python -m framework calculate "2 + 3 * 4"
+# Run an exported agent
+uv run python -m framework run exports/calculator --input '{"expression": "2 + 3 * 4"}'

-# Interactive mode
-uv run python -m framework interactive
+# Interactive shell session
+uv run python -m framework shell exports/calculator

-# Analyze runs with Builder
-uv run python -m framework analyze calculator
+# Show agent info
+uv run python -m framework info exports/calculator
```

### Using the Runtime
@@ -141,8 +141,8 @@ uv run python -m framework test-run <agent_path> --goal <goal_id> --parallel 4
# Debug failed tests
uv run python -m framework test-debug <agent_path> <test_name>

-# List tests for a goal
-uv run python -m framework test-list <goal_id>
+# List tests for an agent
+uv run python -m framework test-list <agent_path>
```

For detailed testing workflows, see the [hive-test skill](../.claude/skills/hive-test/SKILL.md).
@@ -15,6 +15,7 @@ import base64
import hashlib
import http.server
import json
+import os
import platform
import secrets
import subprocess

@@ -150,8 +151,9 @@ def save_credentials(token_data: dict, account_id: str) -> None:
    if "id_token" in token_data:
        auth_data["tokens"]["id_token"] = token_data["id_token"]

-   CODEX_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
-   with open(CODEX_AUTH_FILE, "w") as f:
+   CODEX_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
+   fd = os.open(CODEX_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
+   with os.fdopen(fd, "w") as f:
        json.dump(auth_data, f, indent=2)
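The point of switching from `open()` to `os.open()` is that the auth file is created with owner-only permissions atomically, instead of existing briefly with default permissions before a later `chmod`. A self-contained sketch of the same pattern (function name is illustrative):

```python
import json
import os
from pathlib import Path


def write_secret_json(path: Path, payload: dict) -> None:
    """Create a file that is owner-only (0o600) from the moment it exists."""
    path.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
    # O_CREAT honors the 0o600 mode at creation time; no unsafe window.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        json.dump(payload, f, indent=2)
```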
@@ -1768,7 +1768,7 @@ async def _run_pipeline(websocket, initial_message: str):
    judge=judge,
    config=LoopConfig(
        max_iterations=30,
-       max_tool_calls_per_turn=15,
+       max_tool_calls_per_turn=30,
        max_history_tokens=64000,
        max_tool_result_chars=8_000,
        spillover_dir=str(_DATA_DIR),
@@ -751,7 +751,7 @@ async def _run_pipeline(websocket, topic: str):
    judge=None,  # implicit judge: accept when output_keys filled
    config=LoopConfig(
        max_iterations=20,
-       max_tool_calls_per_turn=10,
+       max_tool_calls_per_turn=30,
        max_history_tokens=32_000,
    ),
    conversation_store=store_a,
@@ -849,7 +849,7 @@ async def _run_pipeline(websocket, topic: str):
    judge=None,  # implicit judge
    config=LoopConfig(
        max_iterations=10,
-       max_tool_calls_per_turn=5,
+       max_tool_calls_per_turn=30,
        max_history_tokens=32_000,
    ),
    conversation_store=store_b,
@@ -1257,7 +1257,7 @@ async def _run_org_pipeline(websocket, topic: str):
    judge=judge,
    config=LoopConfig(
        max_iterations=30,
-       max_tool_calls_per_turn=25,
+       max_tool_calls_per_turn=30,
        max_history_tokens=32_000,
    ),
    conversation_store=store,
@@ -453,7 +453,7 @@ identity_prompt = (
)
loop_config = {
    "max_iterations": 50,
-   "max_tool_calls_per_turn": 10,
+   "max_tool_calls_per_turn": 30,
    "max_history_tokens": 32000,
}
@@ -539,7 +539,7 @@ class CredentialTesterAgent:
    max_tokens=self.config.max_tokens,
    loop_config={
        "max_iterations": 50,
-       "max_tool_calls_per_turn": 10,
+       "max_tool_calls_per_turn": 30,
        "max_history_tokens": 32000,
    },
    conversation_mode="continuous",
@@ -127,7 +127,7 @@ identity_prompt = (
)
loop_config = {
    "max_iterations": 100,
-   "max_tool_calls_per_turn": 20,
+   "max_tool_calls_per_turn": 30,
    "max_history_tokens": 32000,
}
@@ -160,8 +160,8 @@ queen_graph = GraphSpec(
    edges=[],
    conversation_mode="continuous",
    loop_config={
-       "max_iterations": 200,
-       "max_tool_calls_per_turn": 10,
+       "max_iterations": 999_999,
+       "max_tool_calls_per_turn": 30,
        "max_history_tokens": 32000,
    },
)
@@ -10,13 +10,35 @@ _ref_dir = Path(__file__).parent.parent / "reference"
_framework_guide = (_ref_dir / "framework_guide.md").read_text()
_file_templates = (_ref_dir / "file_templates.md").read_text()
_anti_patterns = (_ref_dir / "anti_patterns.md").read_text()
+_gcu_guide_path = _ref_dir / "gcu_guide.md"
+_gcu_guide = _gcu_guide_path.read_text() if _gcu_guide_path.exists() else ""


+def _is_gcu_enabled() -> bool:
+    try:
+        from framework.config import get_gcu_enabled
+
+        return get_gcu_enabled()
+    except Exception:
+        return False


+def _build_appendices() -> str:
+    parts = (
+        "\n\n# Appendix: Framework Reference\n\n"
+        + _framework_guide
+        + "\n\n# Appendix: File Templates\n\n"
+        + _file_templates
+        + "\n\n# Appendix: Anti-Patterns\n\n"
+        + _anti_patterns
+    )
+    if _is_gcu_enabled() and _gcu_guide:
+        parts += "\n\n# Appendix: GCU Browser Automation Guide\n\n" + _gcu_guide
+    return parts


# Shared appendices — appended to every coding node's system prompt.
-_appendices = (
-    "\n\n# Appendix: Framework Reference\n\n" + _framework_guide
-    + "\n\n# Appendix: File Templates\n\n" + _file_templates
-    + "\n\n# Appendix: Anti-Patterns\n\n" + _anti_patterns
-)
+_appendices = _build_appendices()

# Tools available to both coder (worker) and queen.
_SHARED_TOOLS = [
@@ -348,7 +370,7 @@ value. These DO NOT EXIST.

```python
loop_config = {
    "max_iterations": 100,
-   "max_tool_calls_per_turn": 20,
+   "max_tool_calls_per_turn": 30,
    "max_history_tokens": 32000,
}
```
@@ -388,7 +410,10 @@ If list_agent_tools() shows these don't exist, use alternatives \

**Node rules**:
- **2-4 nodes MAX.** Never exceed 4. Merge thin nodes aggressively.
- A node with 0 tools is NOT a real node — merge it.
-- node_type always "event_loop"
+- node_type "event_loop" for all regular graph nodes. Use "gcu" ONLY for
+  browser automation subagents (see GCU appendix). GCU nodes MUST be in a
+  parent node's sub_agents list, NEVER connected via edges, and NEVER used
+  as entry/terminal nodes.
- max_node_visits default is 0 (unbounded) — correct for forever-alive. \
  Only set >0 in one-shot agents with bounded feedback loops.
- Feedback inputs: nullable_output_keys
@@ -466,7 +491,7 @@ Most agents use `terminal_nodes=[]` (forever-alive). This means \
terminal node that doesn't exist. Agent tests MUST be structural:
- Validate graph, node specs, edges, tools, prompts
- Check goal/constraints/success criteria definitions
-- Test `AgentRunner.load()` + `_setup()` (skip if no API key)
+- Test `AgentRunner.load()` succeeds (structural, no API key needed)
- NEVER call `runner.run()` or `trigger_and_wait()` in tests for \
  forever-alive agents — they will hang and time out.
When you restructure an agent (change nodes/edges), always update \
@@ -533,14 +558,35 @@ critical issue. Use sparingly.

## Agent Loading
- load_built_agent(agent_path) — Load a newly built agent as the worker in \
-this session. Call after building and validating an agent to make it \
-available immediately. The user sees the graph update and can interact \
-with it without leaving the session.
+this session. If a worker is already loaded, it is automatically unloaded \
+first. Call after building and validating an agent to make it available \
+immediately.

+## Credentials
+- list_credentials(credential_id?) — List all authorized credentials in the \
+local store. Returns IDs, aliases, status, and identity metadata (never \
+secrets). Optionally filter by credential_id.
"""
_queen_behavior = """
# Behavior

## Greeting and identity

When the user greets you ("hi", "hello") or asks what you can do / \
what you are, respond concisely. DO NOT list internal processes \
(validation steps, AgentRunner.load, tool discovery). Focus on \
user-facing capabilities:

1. Direct capabilities: file operations, shell commands, coding, \
agent building & debugging.
2. Delegation: describe what the loaded worker does in one sentence \
(read the Worker Profile at the end of this prompt). If no worker \
is loaded, say so.
3. End with a short prompt: "What do you need?"

Keep it under 10 lines. No bullet-point dumps of every tool you have.

## Direct coding
You can do any coding task directly — reading files, writing code, running \
commands, building agents, debugging. For quick tasks, do them yourself.
@@ -556,23 +602,73 @@ subtasks to justify delegation.
- Building, modifying, or configuring agents is ALWAYS your job. Never \
delegate agent construction to the worker, even as a "research" subtask.

## When the user says "run", "execute", or "start" (without specifics)

The loaded worker is described in the Worker Profile below. Ask what \
task or topic they want — do NOT call list_agents() or list directories. \
The worker is already loaded. Just ask for the input the worker needs \
(e.g., a research topic, a target domain, a job description).

If NO worker is loaded, say so and offer to build one.

## When idle (worker not running):
-- Greet the user. Mention what the worker can do.
+- Greet the user. Mention what the worker can do in one sentence.
- For tasks matching the worker's goal, call start_worker(task).
- For everything else, do it directly.

## When the user clicks Run (external event notification)
When you receive an event that the user clicked Run:
- If the worker started successfully, briefly acknowledge it — do NOT \
repeat the full status. The user can see the graph is running.
- If the worker failed to start (credential or structural error), \
explain the problem clearly and help fix it. For credential errors, \
guide the user to set up the missing credentials. For structural \
issues, offer to fix the agent graph directly.

## When worker is running:
-- If the user asks about progress, call get_worker_status().
+- If the user asks about progress, call get_worker_status() ONCE and \
+report the result. Do NOT poll in a loop.
+- NEVER call get_worker_status() repeatedly without user input in between. \
+The worker will surface results through client-facing nodes. You do not \
+need to monitor it. One check per user request is enough.
- If the user has a concern or instruction for the worker, call \
inject_worker_message(content) to relay it.
- You can still do coding tasks directly while the worker runs.
- If an escalation ticket arrives from the judge, assess severity:
  - Low/transient: acknowledge silently, do not disturb the user.
  - High/critical: notify the user with a brief analysis and suggested action.
- After starting the worker or checking its status, WAIT for the user's \
next message. Do not take autonomous actions unless the user asks.

## When worker asks user a question:
- The system will route the user's response directly to the worker. \
You do not need to relay it. The user will come back to you after responding.

## Showing or describing the loaded worker

When the user asks to "show the graph", "describe the agent", or \
"re-generate the graph", read the Worker Profile and present the \
worker's current architecture as an ASCII diagram. Use the processing \
stages, tools, and edges from the loaded worker. Do NOT enter the \
agent building workflow — you are describing what already exists, not \
building something new.

## Modifying the loaded worker

When the user asks to change, modify, or update the loaded worker \
(e.g., "change the report node", "add a node", "delete node X"):

1. Use the **Path** from the Worker Profile to locate the agent files.
2. Read the relevant files (nodes/__init__.py, agent.py, etc.).
3. Make the requested changes using edit_file / write_file.
4. Run validation (default_agent.validate(), AgentRunner.load(), \
validate_agent_tools()).
5. **Reload the modified worker**: call load_built_agent("{path}") \
so the changes take effect immediately. If a worker is already loaded, \
stop it first, then reload.

Do NOT skip step 5 — without reloading, the user will still be \
interacting with the old version.
"""

_queen_phase_7 = """
@@ -622,7 +718,8 @@ coder_node = NodeSpec(
        "A complete, validated Hive agent package exists at "
        "exports/{agent_name}/ and passes structural validation."
    ),
-   tools=_SHARED_TOOLS + [
+   tools=_SHARED_TOOLS
+   + [
        # Graph lifecycle tools (multi-graph sessions)
        "load_agent",
        "unload_agent",

@@ -711,7 +808,8 @@ queen_node = NodeSpec(
        "User's intent is understood, coding tasks are completed correctly, "
        "and the worker is managed effectively when delegated to."
    ),
-   tools=_SHARED_TOOLS + [
+   tools=_SHARED_TOOLS
+   + [
        # Worker lifecycle
        "start_worker",
        "stop_worker",
@@ -722,6 +820,8 @@ queen_node = NodeSpec(
        "notify_operator",
        # Agent loading
        "load_built_agent",
+       # Credentials
+       "list_credentials",
    ],
    system_prompt=(
        "You are the Queen — the user's primary interface. You are a coding agent "

@@ -747,6 +847,8 @@ ALL_QUEEN_TOOLS = _SHARED_TOOLS + [
    "notify_operator",
    # Agent loading
    "load_built_agent",
+   # Credentials
+   "list_credentials",
]

__all__ = [
@@ -80,7 +80,7 @@ One client-facing node handles ALL user interaction (setup, logging, reports). O
- Validate graph structure (nodes, edges, entry points)
- Verify node specs (tools, prompts, client-facing flag)
- Check goal/constraints/success criteria definitions
-- Test that `AgentRunner.load()` + `_setup()` succeeds (skip if no API key)
+- Test that `AgentRunner.load()` succeeds (structural, no API key needed)

**What NOT to do:**

@@ -105,3 +105,7 @@ def test_research_routes_back_to_interact(self):
23. **Forgetting sys.path setup in conftest.py** — Tests need `exports/` and `core/` on sys.path.

+24. **Not using auto_responder for client-facing nodes** — Tests with client-facing nodes hang without an auto-responder that injects input. But note: even WITH auto_responder, forever-alive agents still hang because the graph never terminates. Auto-responder only helps for agents with terminal nodes.

+25. **Manually wiring browser tools on event_loop nodes** — If the agent needs browser automation, use `node_type="gcu"` which auto-includes all browser tools and prepends best-practices guidance. Do NOT manually list browser tool names on event_loop nodes — they may not exist in the MCP server or may be incomplete. See the GCU Guide appendix.

+26. **Using GCU nodes as regular graph nodes** — GCU nodes (`node_type="gcu"`) are exclusively subagents. They must ONLY appear in a parent node's `sub_agents=["gcu-node-id"]` list and be invoked via `delegate_to_sub_agent()`. They must NEVER be connected via edges, used as entry nodes, or used as terminal nodes. If a GCU node appears as an edge source or target, the graph will fail pre-load validation.
@@ -235,16 +235,14 @@ class MyAgent:
        identity_prompt=identity_prompt,
    )

-   def _setup(self, mock_mode=False):
+   def _setup(self):
        self._storage_path = Path.home() / ".hive" / "agents" / "my_agent"
        self._storage_path.mkdir(parents=True, exist_ok=True)
        self._tool_registry = ToolRegistry()
        mcp_config = Path(__file__).parent / "mcp_servers.json"
        if mcp_config.exists():
            self._tool_registry.load_mcp_config(mcp_config)
-       llm = None
-       if not mock_mode:
-           llm = LiteLLMProvider(model=self.config.model, api_key=self.config.api_key, api_base=self.config.api_base)
+       llm = LiteLLMProvider(model=self.config.model, api_key=self.config.api_key, api_base=self.config.api_base)
        tools = list(self._tool_registry.get_tools().values())
        tool_executor = self._tool_registry.get_executor()
        self._graph = self._build_graph()
@@ -257,9 +255,9 @@ class MyAgent:
        checkpoint_max_age_days=7, async_checkpoint=True),
    )

-   async def start(self, mock_mode=False):
+   async def start(self):
        if self._agent_runtime is None:
-           self._setup(mock_mode=mock_mode)
+           self._setup()
        if not self._agent_runtime.is_running:
            await self._agent_runtime.start()

@@ -274,8 +272,8 @@ class MyAgent:
        return await self._agent_runtime.trigger_and_wait(
            entry_point_id=entry_point, input_data=input_data or {}, session_state=session_state)

-   async def run(self, context, mock_mode=False, session_state=None):
-       await self.start(mock_mode=mock_mode)
+   async def run(self, context, session_state=None):
+       await self.start()
        try:
            result = await self.trigger_and_wait("default", context, session_state=session_state)
            return result or ExecutionResult(success=False, error="Execution timeout")
@@ -471,19 +469,17 @@ def cli():

@cli.command()
@click.option("--topic", "-t", required=True)
-@click.option("--mock", is_flag=True)
@click.option("--verbose", "-v", is_flag=True)
-def run(topic, mock, verbose):
+def run(topic, verbose):
    """Execute the agent."""
    setup_logging(verbose=verbose)
-   result = asyncio.run(default_agent.run({"topic": topic}, mock_mode=mock))
+   result = asyncio.run(default_agent.run({"topic": topic}))
    click.echo(json.dumps({"success": result.success, "output": result.output}, indent=2, default=str))
    sys.exit(0 if result.success else 1)


@cli.command()
-@click.option("--mock", is_flag=True)
-def tui(mock):
+def tui():
    """Launch TUI dashboard."""
    from pathlib import Path
    from framework.tui.app import AdenTUI

@@ -499,7 +495,7 @@ def tui(mock):
    storage.mkdir(parents=True, exist_ok=True)
    mcp_cfg = Path(__file__).parent / "mcp_servers.json"
    if mcp_cfg.exists(): agent._tool_registry.load_mcp_config(mcp_cfg)
-   llm = None if mock else LiteLLMProvider(model=agent.config.model, api_key=agent.config.api_key, api_base=agent.config.api_base)
+   llm = LiteLLMProvider(model=agent.config.model, api_key=agent.config.api_key, api_base=agent.config.api_base)
    runtime = create_agent_runtime(
        graph=agent._build_graph(), goal=agent.goal, storage_path=storage,
        entry_points=[EntryPointSpec(id="start", name="Start", entry_node="intake", trigger_type="manual", isolation_level="isolated")],
@@ -564,7 +560,6 @@ import sys
from pathlib import Path

import pytest
-import pytest_asyncio

_repo_root = Path(__file__).resolve().parents[3]
for _p in ["exports", "core"]:

@@ -576,18 +571,17 @@ AGENT_PATH = str(Path(__file__).resolve().parents[1])


@pytest.fixture(scope="session")
-def mock_mode():
-    return True
+def agent_module():
+    """Import the agent package for structural validation."""
+    import importlib
+    return importlib.import_module(Path(AGENT_PATH).name)


-@pytest_asyncio.fixture(scope="session")
-async def runner(tmp_path_factory, mock_mode):
+@pytest.fixture(scope="session")
+def runner_loaded():
+    """Load the agent through AgentRunner (structural only, no LLM needed)."""
    from framework.runner.runner import AgentRunner
-    storage = tmp_path_factory.mktemp("agent_storage")
-    r = AgentRunner.load(AGENT_PATH, mock_mode=mock_mode, storage_path=storage)
-    r._setup()
-    yield r
-    await r.cleanup_async()
+    return AgentRunner.load(AGENT_PATH)
```
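A structural test built on these fixtures stays cheap and cannot hang; a minimal sketch (test names and assertions are illustrative):

```python
def test_agent_package_imports(agent_module):
    # Structural check only: the package imports cleanly.
    assert agent_module is not None


def test_runner_loads(runner_loaded):
    # AgentRunner.load() needs no API key. Do NOT call runner.run() or
    # trigger_and_wait() here: forever-alive agents would hang the test.
    assert runner_loaded is not None
```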
## entry_points Format

@@ -72,7 +72,7 @@ goal = Goal(
| id | str | required | kebab-case identifier |
| name | str | required | Display name |
| description | str | required | What the node does |
-| node_type | str | required | Always `"event_loop"` |
+| node_type | str | required | `"event_loop"` or `"gcu"` (browser automation — see GCU Guide appendix) |
| input_keys | list[str] | required | Memory keys this node reads |
| output_keys | list[str] | required | Memory keys this node writes via set_output |
| system_prompt | str | "" | LLM instructions |
@@ -0,0 +1,119 @@
# GCU Browser Automation Guide

## When to Use GCU Nodes

Use `node_type="gcu"` when:
- The user's workflow requires **navigating real websites** (scraping, form-filling, social media interaction, testing web UIs)
- The task involves **dynamic/JS-rendered pages** that `web_scrape` cannot handle (SPAs, infinite scroll, login-gated content)
- The agent needs to **interact with a website** — clicking, typing, scrolling, selecting, uploading files

Do NOT use GCU for:
- Static content that `web_scrape` handles fine
- API-accessible data (use the API directly)
- PDF/file processing
- Anything that doesn't require a browser UI

## What GCU Nodes Are

- `node_type="gcu"` — a declarative enhancement over `event_loop`
- Framework auto-prepends browser best-practices system prompt
- Framework auto-includes all 31 browser tools from `gcu-tools` MCP server
- Same underlying `EventLoopNode` class — no new imports needed
- `tools=[]` is correct — tools are auto-populated at runtime

## GCU Architecture Pattern

GCU nodes are **subagents** — invoked via `delegate_to_sub_agent()`, not connected via edges.

- Primary nodes (`event_loop`, client-facing) orchestrate; GCU nodes do browser work
- Parent node declares `sub_agents=["gcu-node-id"]` and calls `delegate_to_sub_agent(agent_id="gcu-node-id", task="...")`
- GCU nodes set `max_node_visits=1` (single execution per delegation), `client_facing=False`
- GCU nodes use `output_keys=["result"]` and return structured JSON via `set_output("result", ...)`

## GCU Node Definition Template

```python
gcu_browser_node = NodeSpec(
    id="gcu-browser-worker",
    name="Browser Worker",
    description="Browser subagent that does X.",
    node_type="gcu",
    client_facing=False,
    max_node_visits=1,
    input_keys=[],
    output_keys=["result"],
    tools=[],  # Auto-populated with all browser tools
    system_prompt="""\
You are a browser agent. Your job: [specific task].

## Workflow
1. browser_start (only if no browser is running yet)
2. browser_open(url=TARGET_URL) — note the returned targetId
3. browser_snapshot to read the page
4. [task-specific steps]
5. set_output("result", JSON)

## Output format
set_output("result", JSON) with:
- [field]: [type and description]
""",
)
```
## Parent Node Template (orchestrating GCU subagents)

```python
orchestrator_node = NodeSpec(
    id="orchestrator",
    ...
    node_type="event_loop",
    sub_agents=["gcu-browser-worker"],
    system_prompt="""\
...
delegate_to_sub_agent(
    agent_id="gcu-browser-worker",
    task="Navigate to [URL]. Do [specific task]. Return JSON with [fields]."
)
...
""",
    tools=[],  # Orchestrator doesn't need browser tools
)
```
## mcp_servers.json with GCU

```json
{
  "hive-tools": { ... },
  "gcu-tools": {
    "transport": "stdio",
    "command": "uv",
    "args": ["run", "python", "-m", "gcu.server", "--stdio"],
    "cwd": "../../tools",
    "description": "GCU tools for browser automation"
  }
}
```

Note: `gcu-tools` is auto-added if any node uses `node_type="gcu"`, but including it explicitly is fine.

## GCU System Prompt Best Practices

Key rules to bake into GCU node prompts:

- Prefer `browser_snapshot` over `browser_get_text("body")` — compact accessibility tree vs 100KB+ raw HTML
- Always `browser_wait` after navigation
- Use large scroll amounts (~2000-5000) for lazy-loaded content
- For spillover files, use `run_command` with grep, not `read_file`
- If auth wall detected, report immediately — don't attempt login
- Keep tool calls per turn ≤10
- Tab isolation: when browser is already running, use `browser_open(background=true)` and pass `target_id` to every call

## GCU Anti-Patterns

- Using `browser_screenshot` to read text (use `browser_snapshot`)
- Re-navigating after scrolling (resets scroll position)
- Attempting login on auth walls
- Forgetting `target_id` in multi-tab scenarios
- Putting browser tools directly on `event_loop` nodes instead of using GCU subagent pattern
- Making GCU nodes `client_facing=True` (they should be autonomous subagents)
@@ -761,7 +761,7 @@ class GraphBuilder:
    path = self.storage_path / f"{session_id}.json"
    if not path.exists():
        raise FileNotFoundError(f"Session not found: {session_id}")
-   return BuildSession.model_validate_json(path.read_text())
+   return BuildSession.model_validate_json(path.read_text(encoding="utf-8"))

@classmethod
def list_sessions(cls, storage_path: Path | str | None = None) -> list[str]:
@@ -90,6 +90,11 @@ def get_api_key() -> str | None:
    return None


+def get_gcu_enabled() -> bool:
+    """Return whether GCU (browser automation) is enabled in user config."""
+    return get_hive_config().get("gcu_enabled", False)


def get_api_base() -> str | None:
    """Return the api_base URL for OpenAI-compatible endpoints, if configured."""
    llm = get_hive_config().get("llm", {})
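For reference, the matching user-config entry would look something like this; the shape is an assumption inferred from the `get_hive_config()` lookups above (`gcu_enabled` at the top level, an `llm` section holding `api_base`):

```json
{
  "gcu_enabled": true,
  "llm": {
    "api_base": "http://localhost:11434/v1"
  }
}
```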
@@ -42,6 +42,14 @@ For Vault integration:
    from core.framework.credentials.vault import HashiCorpVaultStorage
"""

+from .key_storage import (
+    delete_aden_api_key,
+    generate_and_save_credential_key,
+    load_aden_api_key,
+    load_credential_key,
+    save_aden_api_key,
+    save_credential_key,
+)
from .models import (
    CredentialDecryptionError,
    CredentialError,

@@ -132,6 +140,13 @@ __all__ = [
    "CredentialRefreshError",
    "CredentialValidationError",
    "CredentialDecryptionError",
+   # Key storage (bootstrap credentials)
+   "load_credential_key",
+   "save_credential_key",
+   "generate_and_save_credential_key",
+   "load_aden_api_key",
+   "save_aden_api_key",
+   "delete_aden_api_key",
    # Validation
    "ensure_credential_key_env",
    "validate_agent_credentials",
@@ -26,7 +26,7 @@ Usage:
    storage = AdenCachedStorage(
        local_storage=EncryptedFileStorage(),
        aden_provider=provider,
-       cache_ttl_seconds=300,  # Re-check Aden every 5 minutes
+       cache_ttl_seconds=600,  # Re-check Aden every 10 minutes
    )

    # Create store

@@ -77,7 +77,7 @@ class AdenCachedStorage(CredentialStorage):
    storage = AdenCachedStorage(
        local_storage=EncryptedFileStorage(),
        aden_provider=provider,
-       cache_ttl_seconds=300,  # 5 minutes
+       cache_ttl_seconds=600,  # 10 minutes
    )

    store = CredentialStore(
@@ -0,0 +1,201 @@
"""
Dedicated file-based storage for bootstrap credentials.

HIVE_CREDENTIAL_KEY -> ~/.hive/secrets/credential_key (plain text, chmod 600)
ADEN_API_KEY -> ~/.hive/credentials/ (encrypted via EncryptedFileStorage)

Boot order:
1. load_credential_key() -- reads/generates the Fernet key, sets os.environ
2. load_aden_api_key() -- uses the encrypted store (which needs the key from step 1)
"""

from __future__ import annotations

import logging
import os
import stat
from pathlib import Path

logger = logging.getLogger(__name__)

CREDENTIAL_KEY_PATH = Path.home() / ".hive" / "secrets" / "credential_key"
CREDENTIAL_KEY_ENV_VAR = "HIVE_CREDENTIAL_KEY"
ADEN_CREDENTIAL_ID = "aden_api_key"
ADEN_ENV_VAR = "ADEN_API_KEY"


# ---------------------------------------------------------------------------
# HIVE_CREDENTIAL_KEY
# ---------------------------------------------------------------------------


def load_credential_key() -> str | None:
    """Load HIVE_CREDENTIAL_KEY with priority: env > file > shell config.

    Sets ``os.environ["HIVE_CREDENTIAL_KEY"]`` as a side-effect when found.
    Returns the key string, or ``None`` if unavailable everywhere.
    """
    # 1. Already in environment (set by parent process, CI, Windows Registry, etc.)
    key = os.environ.get(CREDENTIAL_KEY_ENV_VAR)
    if key:
        return key

    # 2. Dedicated secrets file
    key = _read_credential_key_file()
    if key:
        os.environ[CREDENTIAL_KEY_ENV_VAR] = key
        return key

    # 3. Shell config fallback (backward compat for old installs)
    key = _read_from_shell_config(CREDENTIAL_KEY_ENV_VAR)
    if key:
        os.environ[CREDENTIAL_KEY_ENV_VAR] = key
        return key

    return None


def save_credential_key(key: str) -> Path:
    """Save HIVE_CREDENTIAL_KEY to ``~/.hive/secrets/credential_key``.

    Creates parent dirs with mode 700, writes the file with mode 600.
    Also sets ``os.environ["HIVE_CREDENTIAL_KEY"]``.

    Returns:
        The path that was written.
    """
    path = CREDENTIAL_KEY_PATH
    path.parent.mkdir(parents=True, exist_ok=True)
    # Restrict the secrets directory itself
    path.parent.chmod(stat.S_IRWXU)  # 0o700

    path.write_text(key)
    path.chmod(stat.S_IRUSR | stat.S_IWUSR)  # 0o600

    os.environ[CREDENTIAL_KEY_ENV_VAR] = key
    return path
def generate_and_save_credential_key() -> str:
    """Generate a new Fernet key and persist it to ``~/.hive/secrets/credential_key``.

    Returns:
        The generated key string.
    """
    from cryptography.fernet import Fernet

    key = Fernet.generate_key().decode()
    save_credential_key(key)
    return key


# ---------------------------------------------------------------------------
# ADEN_API_KEY
# ---------------------------------------------------------------------------


def load_aden_api_key() -> str | None:
    """Load ADEN_API_KEY with priority: env > encrypted store > shell config.

    **Must** be called after ``load_credential_key()`` because the encrypted
    store depends on HIVE_CREDENTIAL_KEY.

    Sets ``os.environ["ADEN_API_KEY"]`` as a side-effect when found.
    Returns the key string, or ``None`` if unavailable everywhere.
    """
    # 1. Already in environment
    key = os.environ.get(ADEN_ENV_VAR)
    if key:
        return key

    # 2. Encrypted credential store
    key = _read_aden_from_encrypted_store()
    if key:
        os.environ[ADEN_ENV_VAR] = key
        return key

    # 3. Shell config fallback (backward compat)
    key = _read_from_shell_config(ADEN_ENV_VAR)
    if key:
        os.environ[ADEN_ENV_VAR] = key
        return key

    return None


def save_aden_api_key(key: str) -> None:
    """Save ADEN_API_KEY to the encrypted credential store.

    Also sets ``os.environ["ADEN_API_KEY"]``.
    """
    from pydantic import SecretStr

    from .models import CredentialKey, CredentialObject
    from .storage import EncryptedFileStorage

    storage = EncryptedFileStorage()
    cred = CredentialObject(
        id=ADEN_CREDENTIAL_ID,
        keys={"api_key": CredentialKey(name="api_key", value=SecretStr(key))},
    )
    storage.save(cred)
    os.environ[ADEN_ENV_VAR] = key


def delete_aden_api_key() -> None:
    """Remove ADEN_API_KEY from the encrypted store and ``os.environ``."""
    try:
        from .storage import EncryptedFileStorage

        storage = EncryptedFileStorage()
        storage.delete(ADEN_CREDENTIAL_ID)
    except Exception:
        logger.debug("Could not delete %s from encrypted store", ADEN_CREDENTIAL_ID)

    os.environ.pop(ADEN_ENV_VAR, None)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _read_credential_key_file() -> str | None:
    """Read the credential key from ``~/.hive/secrets/credential_key``."""
    try:
        if CREDENTIAL_KEY_PATH.is_file():
            value = CREDENTIAL_KEY_PATH.read_text(encoding="utf-8").strip()
            if value:
                return value
    except Exception:
        logger.debug("Could not read %s", CREDENTIAL_KEY_PATH)
    return None


def _read_from_shell_config(env_var: str) -> str | None:
    """Fallback: read an env var from ~/.zshrc or ~/.bashrc."""
    try:
        from aden_tools.credentials.shell_config import check_env_var_in_shell_config

        found, value = check_env_var_in_shell_config(env_var)
        if found and value:
            return value
    except ImportError:
        pass
    return None


def _read_aden_from_encrypted_store() -> str | None:
    """Try to load ADEN_API_KEY from the encrypted credential store."""
    if not os.environ.get(CREDENTIAL_KEY_ENV_VAR):
        return None
    try:
        from .storage import EncryptedFileStorage

        storage = EncryptedFileStorage()
        cred = storage.load(ADEN_CREDENTIAL_ID)
        if cred:
            return cred.get_key("api_key")
    except Exception:
        logger.debug("Could not load %s from encrypted store", ADEN_CREDENTIAL_ID)
    return None
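A minimal boot sequence using these helpers might look like this; the import path is an assumption based on the `core.framework.credentials` package layout referenced earlier, and the re-exports added to the package `__init__` above:

```python
from framework.credentials import load_aden_api_key, load_credential_key

# Order matters: the encrypted store cannot decrypt anything until
# HIVE_CREDENTIAL_KEY is present in os.environ.
if load_credential_key() is None:
    raise RuntimeError("No HIVE_CREDENTIAL_KEY found; run quickstart or set it manually")

aden_key = load_aden_api_key()  # may be None if never configured
```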
@@ -256,57 +256,23 @@ class CredentialSetupSession:

    def _ensure_credential_key(self) -> bool:
        """Ensure HIVE_CREDENTIAL_KEY is available for encrypted storage."""
        if os.environ.get("HIVE_CREDENTIAL_KEY"):
        from .key_storage import generate_and_save_credential_key, load_credential_key

        if load_credential_key():
            return True

        # Try to load from shell config
        try:
            from aden_tools.credentials.shell_config import check_env_var_in_shell_config

            found, value = check_env_var_in_shell_config("HIVE_CREDENTIAL_KEY")
            if found and value:
                os.environ["HIVE_CREDENTIAL_KEY"] = value
                return True
        except ImportError:
            pass

        # Generate a new key
        self._print(f"{Colors.YELLOW}Initializing credential store...{Colors.NC}")
        try:
            from cryptography.fernet import Fernet

            generated_key = Fernet.generate_key().decode()
            os.environ["HIVE_CREDENTIAL_KEY"] = generated_key

            # Save to shell config
            self._save_key_to_shell_config(generated_key)
            generate_and_save_credential_key()
            self._print(
                f"{Colors.GREEN}✓ Encryption key saved to ~/.hive/secrets/credential_key{Colors.NC}"
            )
            return True
        except Exception as e:
            self._print(f"{Colors.RED}Failed to initialize credential store: {e}{Colors.NC}")
            return False

    def _save_key_to_shell_config(self, key: str) -> None:
        """Save HIVE_CREDENTIAL_KEY to shell config."""
        try:
            from aden_tools.credentials.shell_config import (
                add_env_var_to_shell_config,
            )

            success, config_path = add_env_var_to_shell_config(
                "HIVE_CREDENTIAL_KEY",
                key,
                comment="Encryption key for Hive credential store",
            )
            if success:
                self._print(f"{Colors.GREEN}✓ Encryption key saved to {config_path}{Colors.NC}")
        except Exception:
            # Fallback: just tell the user
            self._print("\n")
            self._print(
                f"{Colors.YELLOW}Add this to your shell config (~/.zshrc or ~/.bashrc):{Colors.NC}"
            )
            self._print(f' export HIVE_CREDENTIAL_KEY="{key}"')

    def _setup_single_credential(self, cred: MissingCredential) -> bool:
        """Set up a single credential. Returns True if configured."""
        self._print(f"\n{Colors.CYAN}{'─' * 60}{Colors.NC}")
@@ -444,19 +410,10 @@ class CredentialSetupSession:

            self._print(f"{Colors.YELLOW}No key entered. Skipping.{Colors.NC}")
            return False

        os.environ["ADEN_API_KEY"] = aden_key
        # Persist to encrypted store and set os.environ
        from .key_storage import save_aden_api_key

        # Save to shell config
        try:
            from aden_tools.credentials.shell_config import add_env_var_to_shell_config

            add_env_var_to_shell_config(
                "ADEN_API_KEY",
                aden_key,
                comment="Aden Platform API key",
            )
        except Exception:
            pass
        save_aden_api_key(aden_key)

        # Sync from Aden
        try:
@@ -14,43 +14,46 @@ logger = logging.getLogger(__name__)


def ensure_credential_key_env() -> None:
    """Load credentials from shell config if not in environment.
    """Load bootstrap credentials into ``os.environ``.

    The quickstart.sh and setup-credentials skill write API keys to ~/.zshrc
    or ~/.bashrc. If the user hasn't sourced their config in the current shell,
    this reads them directly so the runner (and any MCP subprocesses) can use them.
    Priority chain for each credential:
    1. ``os.environ`` (already set — nothing to do)
    2. Dedicated file storage (``~/.hive/secrets/`` or encrypted store)
    3. Shell config fallback (``~/.zshrc`` / ``~/.bashrc``) for backward compat

    Loads:
    - HIVE_CREDENTIAL_KEY (encrypted credential store)
    - ADEN_API_KEY (Aden OAuth sync)
    - All LLM API keys (ANTHROPIC_API_KEY, OPENAI_API_KEY, ZAI_API_KEY, etc.)
    Boot order matters: HIVE_CREDENTIAL_KEY must load BEFORE ADEN_API_KEY
    because the encrypted store depends on it.

    Remaining LLM/tool API keys still load from shell config.
    """
    from .key_storage import load_aden_api_key, load_credential_key

    # Step 1: HIVE_CREDENTIAL_KEY (must come first — encrypted store depends on it)
    load_credential_key()

    # Step 2: ADEN_API_KEY (uses encrypted store, then shell config fallback)
    load_aden_api_key()

    # Step 3: Load remaining LLM/tool API keys from shell config
    try:
        from aden_tools.credentials.shell_config import check_env_var_in_shell_config
    except ImportError:
        return

    # Core credentials that are always checked
    env_vars_to_load = ["HIVE_CREDENTIAL_KEY", "ADEN_API_KEY"]

    # Add all LLM/tool API keys from CREDENTIAL_SPECS
    try:
        from aden_tools.credentials import CREDENTIAL_SPECS

        for spec in CREDENTIAL_SPECS.values():
            if spec.env_var and spec.env_var not in env_vars_to_load:
                env_vars_to_load.append(spec.env_var)
            var_name = spec.env_var
            if var_name and var_name not in ("HIVE_CREDENTIAL_KEY", "ADEN_API_KEY"):
                if not os.environ.get(var_name):
                    found, value = check_env_var_in_shell_config(var_name)
                    if found and value:
                        os.environ[var_name] = value
                        logger.debug("Loaded %s from shell config", var_name)
    except ImportError:
        pass

    for var_name in env_vars_to_load:
        if os.environ.get(var_name):
            continue
        found, value = check_env_var_in_shell_config(var_name)
        if found and value:
            os.environ[var_name] = value
            logger.debug("Loaded %s from shell config", var_name)


@dataclass
class CredentialStatus:

@@ -71,6 +74,7 @@ class CredentialStatus:

    direct_api_key_supported: bool
    credential_key: str
    aden_not_connected: bool  # Aden-only cred, ADEN_API_KEY set, but integration missing
    alternative_group: str | None = None  # non-None when multiple providers can satisfy a tool


@dataclass

@@ -82,8 +86,34 @@ class CredentialValidationResult:

    @property
    def failed(self) -> list[CredentialStatus]:
        """Credentials that are missing, invalid, or Aden-not-connected."""
        return [c for c in self.credentials if not c.available or c.valid is False]
        """Credentials that are missing, invalid, or Aden-not-connected.

        For alternative groups (multi-provider tools like send_email), the group
        is satisfied if ANY member is available and valid — only report failures
        when the entire group is unsatisfied.
        """
        # Check which alternative groups are satisfied
        alt_satisfied: dict[str, bool] = {}
        for c in self.credentials:
            if not c.alternative_group:
                continue
            if c.alternative_group not in alt_satisfied:
                alt_satisfied[c.alternative_group] = False
            if c.available and c.valid is not False:
                alt_satisfied[c.alternative_group] = True

        result = []
        for c in self.credentials:
            if c.alternative_group:
                # Skip if any alternative in the group is satisfied
                if alt_satisfied.get(c.alternative_group, False):
                    continue
                if not c.available or c.valid is False:
                    result.append(c)
            else:
                if not c.available or c.valid is False:
                    result.append(c)
        return result
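
The same grouping rule, reduced to a standalone toy (minimal hypothetical types, not the real CredentialStatus): resend is available, google is not, so the "send_email" group is satisfied and neither member is reported as failed.

from dataclasses import dataclass

@dataclass
class Status:
    name: str
    available: bool
    valid: bool | None
    alternative_group: str | None = None

def failed(statuses: list[Status]) -> list[Status]:
    ok_groups = {s.alternative_group for s in statuses
                 if s.alternative_group and s.available and s.valid is not False}
    return [s for s in statuses
            if s.alternative_group not in ok_groups
            and (not s.available or s.valid is False)]

creds = [Status("resend", True, True, "send_email"),
         Status("google", False, None, "send_email")]
assert failed(creds) == []  # group satisfied by resend alone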

    @property
    def has_errors(self) -> bool:

@@ -129,11 +159,7 @@ class CredentialValidationResult:

                f" {c.env_var} for {_label(c)}"
                f"\n Connect this integration at hive.adenhq.com first."
            )
        lines.append(
            "\nTo fix: run /hive-credentials in Claude Code."
            "\nIf you've already set up credentials, "
            "restart your terminal to load them."
        )
        lines.append("\nIf you've already set up credentials, restart your terminal to load them.")
        return "\n".join(lines)


@@ -146,7 +172,7 @@ def _label(c: CredentialStatus) -> str:

    return c.credential_name


def _presync_aden_tokens(credential_specs: dict) -> None:
def _presync_aden_tokens(credential_specs: dict, *, force: bool = False) -> None:
    """Sync Aden-backed OAuth tokens into env vars for validation.

    When ADEN_API_KEY is available, fetches fresh OAuth tokens from the Aden

@@ -154,6 +180,11 @@ def _presync_aden_tokens(credential_specs: dict) -> None:

    tokens instead of stale or mis-stored values in the encrypted store.
    Only touches credentials that are ``aden_supported`` AND whose env var
    is not already set (so explicit user exports always win).

    Args:
        force: When True, overwrite env vars that are already set. Used by
            the credentials modal to pick up freshly reauthorized tokens
            from Aden instead of reusing stale values from a prior sync.
    """
    from framework.credentials.store import CredentialStore


@@ -166,7 +197,7 @@ def _presync_aden_tokens(credential_specs: dict) -> None:

    for name, spec in credential_specs.items():
        if not spec.aden_supported:
            continue
        if os.environ.get(spec.env_var):
        if not force and os.environ.get(spec.env_var):
            continue  # Already set — don't overwrite
        cred_id = spec.credential_id or name
        # sync_all() already fetched everything available from Aden.
@@ -200,6 +231,7 @@ def validate_agent_credentials(

    quiet: bool = False,
    verify: bool = True,
    raise_on_error: bool = True,
    force_refresh: bool = False,
) -> CredentialValidationResult:
    """Check that required credentials are available and valid before running an agent.

@@ -214,6 +246,9 @@ def validate_agent_credentials(

        verify: If True (default), run health checks on present credentials.
        raise_on_error: If True (default), raise CredentialError when validation
            fails. Set to False to get the result without raising.
        force_refresh: If True, force re-sync of Aden OAuth tokens even when
            env vars are already set. Used by the credentials modal after
            reauthorization.

    Returns:
        CredentialValidationResult with status of ALL required credentials.

@@ -245,7 +280,7 @@ def validate_agent_credentials(

    # into env vars so validation sees fresh tokens instead of stale values
    # in the encrypted store (e.g., a previously mis-stored google.enc).
    if os.environ.get("ADEN_API_KEY"):
        _presync_aden_tokens(CREDENTIAL_SPECS)
        _presync_aden_tokens(CREDENTIAL_SPECS, force=force_refresh)

    env_mapping = {
        (spec.credential_id or name): spec.env_var for name, spec in CREDENTIAL_SPECS.items()

@@ -257,12 +292,12 @@ def validate_agent_credentials(

        storage = env_storage
    store = CredentialStore(storage=storage)

    # Build reverse mappings
    tool_to_cred: dict[str, str] = {}
    # Build reverse mappings — 1:many for multi-provider tools (e.g. send_email → resend OR google)
    tool_to_creds: dict[str, list[str]] = {}
    node_type_to_cred: dict[str, str] = {}
    for cred_name, spec in CREDENTIAL_SPECS.items():
        for tool_name in spec.tools:
            tool_to_cred[tool_name] = cred_name
            tool_to_creds.setdefault(tool_name, []).append(cred_name)
        for nt in spec.node_types:
            node_type_to_cred[nt] = cred_name

@@ -272,7 +307,13 @@ def validate_agent_credentials(

    # Credentials that are present and should be health-checked
    to_verify: list[int] = []  # indices into all_credentials

    def _check_credential(spec, cred_name: str, affected_tools: list[str], affected_node_types: list[str]) -> None:
    def _check_credential(
        spec,
        cred_name: str,
        affected_tools: list[str],
        affected_node_types: list[str],
        alternative_group: str | None = None,
    ) -> None:
        cred_id = spec.credential_id or cred_name
        available = store.is_available(cred_id)

@@ -300,6 +341,7 @@ def validate_agent_credentials(

            direct_api_key_supported=spec.direct_api_key_supported,
            credential_key=spec.credential_key,
            aden_not_connected=is_aden_nc,
            alternative_group=alternative_group,
        )
        all_credentials.append(status)

@@ -308,15 +350,56 @@ def validate_agent_credentials(

    # Check tool credentials
    for tool_name in sorted(required_tools):
        cred_name = tool_to_cred.get(tool_name)
        if cred_name is None or cred_name in checked:
        cred_names = tool_to_creds.get(tool_name)
        if cred_names is None:
            continue
        checked.add(cred_name)
        spec = CREDENTIAL_SPECS[cred_name]
        if not spec.required:

        # Filter to credentials we haven't already checked
        unchecked = [cn for cn in cred_names if cn not in checked]
        if not unchecked:
            continue
        affected = sorted(t for t in required_tools if t in spec.tools)
        _check_credential(spec, cred_name, affected_tools=affected, affected_node_types=[])

        # Single provider — existing behavior
        if len(unchecked) == 1:
            cred_name = unchecked[0]
            checked.add(cred_name)
            spec = CREDENTIAL_SPECS[cred_name]
            if not spec.required:
                continue
            affected = sorted(t for t in required_tools if t in spec.tools)
            _check_credential(spec, cred_name, affected_tools=affected, affected_node_types=[])
            continue

        # Multi-provider (e.g. send_email → resend OR google):
        # satisfied if ANY provider credential is available.
        available_cn = None
        for cn in unchecked:
            spec = CREDENTIAL_SPECS[cn]
            cred_id = spec.credential_id or cn
            if store.is_available(cred_id):
                available_cn = cn
                break

        if available_cn is not None:
            # Found an available provider — check (and health-check) it
            checked.add(available_cn)
            spec = CREDENTIAL_SPECS[available_cn]
            affected = sorted(t for t in required_tools if t in spec.tools)
            _check_credential(spec, available_cn, affected_tools=affected, affected_node_types=[])
        else:
            # None available — report ALL alternatives so the modal can show them
            group_key = tool_name  # e.g. "send_email"
            for cn in unchecked:
                checked.add(cn)
                spec = CREDENTIAL_SPECS[cn]
                affected = sorted(t for t in required_tools if t in spec.tools)
                _check_credential(
                    spec,
                    cn,
                    affected_tools=affected,
                    affected_node_types=[],
                    alternative_group=group_key,
                )
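
A toy illustration of the 1:many reverse mapping the loop above consumes (spec names are examples):

tool_to_creds: dict[str, list[str]] = {}
specs = {"resend": ["send_email"], "google": ["send_email", "read_gmail"]}
for cred_name, tools in specs.items():
    for tool in tools:
        tool_to_creds.setdefault(tool, []).append(cred_name)
assert tool_to_creds["send_email"] == ["resend", "google"]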

    # Check node type credentials (e.g., ANTHROPIC_API_KEY for LLM nodes)
    for nt in sorted(node_types):

@@ -400,12 +483,10 @@ def build_setup_session_from_error(

        nodes: Graph nodes (preferred — avoids re-loading from disk).
        agent_path: Agent directory path (used when nodes aren't available).
    """
    from framework.credentials.setup import CredentialSetupSession, MissingCredential
    from framework.credentials.setup import CredentialSetupSession

    # Prefer the validation result attached to the exception
    result: CredentialValidationResult | None = getattr(
        credential_error, "validation_result", None
    )
    result: CredentialValidationResult | None = getattr(credential_error, "validation_result", None)
    if result is not None:
        missing = [_status_to_missing(c) for c in result.failed]
        return CredentialSetupSession(missing)


@@ -46,9 +46,11 @@ class ActiveNodeClientIO(NodeClientIO):

        self,
        node_id: str,
        event_bus: EventBus | None = None,
        execution_id: str = "",
    ) -> None:
        self.node_id = node_id
        self._event_bus = event_bus
        self._execution_id = execution_id

        self._output_queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._output_snapshot = ""

@@ -66,6 +68,7 @@ class ActiveNodeClientIO(NodeClientIO):

            node_id=self.node_id,
            content=content,
            snapshot=self._output_snapshot,
            execution_id=self._execution_id or None,
        )

        if is_final:

@@ -83,6 +86,7 @@ class ActiveNodeClientIO(NodeClientIO):

            stream_id=self.node_id,
            node_id=self.node_id,
            prompt=prompt,
            execution_id=self._execution_id or None,
        )

        try:

@@ -158,11 +162,12 @@ class ClientIOGateway:

    def __init__(self, event_bus: EventBus | None = None) -> None:
        self._event_bus = event_bus

    def create_io(self, node_id: str, client_facing: bool) -> NodeClientIO:
    def create_io(self, node_id: str, client_facing: bool, execution_id: str = "") -> NodeClientIO:
        if client_facing:
            return ActiveNodeClientIO(
                node_id=node_id,
                event_bus=self._event_bus,
                execution_id=execution_id,
            )
        return InertNodeClientIO(
            node_id=node_id,

@@ -5,6 +5,7 @@ from __future__ import annotations

import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Protocol, runtime_checkable


@@ -90,15 +91,67 @@ class Message:


def _extract_spillover_filename(content: str) -> str | None:
    """Extract spillover filename from a truncated tool result.
    """Extract spillover filename from a tool result annotation.

    Matches the pattern produced by EventLoopNode._truncate_tool_result():
    "saved to 'tool_github_list_stargazers_abc123.txt'"
    Matches patterns produced by EventLoopNode._truncate_tool_result():
    - Large result: "saved to 'web_search_1.txt'"
    - Small result: "[Saved to 'web_search_1.txt']"
    """
    match = re.search(r"saved to '([^']+)'", content)
    match = re.search(r"[Ss]aved to '([^']+)'", content)
    return match.group(1) if match else None
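
A quick check of the case-insensitive first letter in the new pattern (inputs are illustrative):

import re
pat = re.compile(r"[Ss]aved to '([^']+)'")
assert pat.search("... saved to 'web_search_1.txt'").group(1) == "web_search_1.txt"
assert pat.search("[Saved to 'web_search_1.txt']").group(1) == "web_search_1.txt"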


_TC_ARG_LIMIT = 200  # max chars per tool_call argument after compaction


def _compact_tool_calls(tool_calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Truncate tool_call arguments to save context tokens during compaction.

    Preserves ``id``, ``type``, and ``function.name`` exactly. When arguments
    exceed ``_TC_ARG_LIMIT``, replaces the full JSON string with a compact
    **valid** JSON summary. The Anthropic API parses tool_call arguments and
    rejects requests with malformed JSON (e.g. unterminated strings), so we
    must never produce broken JSON here.
    """
    compact = []
    for tc in tool_calls:
        func = tc.get("function", {})
        args = func.get("arguments", "")
        if len(args) > _TC_ARG_LIMIT:
            # Build a valid JSON summary instead of slicing mid-string.
            # Try to extract top-level keys for a meaningful preview.
            try:
                parsed = json.loads(args)
                if isinstance(parsed, dict):
                    # Preserve key names, truncate values
                    summary_parts = []
                    for k, v in parsed.items():
                        v_str = str(v)
                        if len(v_str) > 60:
                            v_str = v_str[:60] + "..."
                        summary_parts.append(f"{k}={v_str}")
                    summary = ", ".join(summary_parts)
                    if len(summary) > _TC_ARG_LIMIT:
                        summary = summary[:_TC_ARG_LIMIT] + "..."
                    args = json.dumps({"_compacted": summary})
                else:
                    args = json.dumps({"_compacted": str(parsed)[:_TC_ARG_LIMIT]})
            except (json.JSONDecodeError, TypeError):
                # Args were already invalid JSON — wrap the preview safely
                args = json.dumps({"_compacted": args[:_TC_ARG_LIMIT]})
        compact.append(
            {
                "id": tc.get("id", ""),
                "type": tc.get("type", "function"),
                "function": {
                    "name": func.get("name", ""),
                    "arguments": args,
                },
            }
        )
    return compact
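
What the compaction produces for an oversized call, run in this module's context (values illustrative):

# tc = {"id": "call_1", "type": "function",
#       "function": {"name": "web_search",
#                    "arguments": json.dumps({"query": "x" * 300})}}
# out = _compact_tool_calls([tc])[0]
# json.loads(out["function"]["arguments"])  # parses cleanly:
# -> {"_compacted": "query=xxxxxx...xxx..."} — id/type/name untouched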


# ---------------------------------------------------------------------------
# ConversationStore protocol (Phase 2)
# ---------------------------------------------------------------------------

@@ -353,12 +406,20 @@ class NodeConversation:

        """Best available token estimate.

        Uses actual API input token count when available (set via
        :meth:`update_token_count`), otherwise falls back to the rough
        ``total_chars / 4`` heuristic.
        :meth:`update_token_count`), otherwise falls back to a
        ``total_chars / 4`` heuristic that includes both message content
        AND tool_call argument sizes.
        """
        if self._last_api_input_tokens is not None:
            return self._last_api_input_tokens
        total_chars = sum(len(m.content) for m in self._messages)
        total_chars = 0
        for m in self._messages:
            total_chars += len(m.content)
            if m.tool_calls:
                for tc in m.tool_calls:
                    func = tc.get("function", {})
                    total_chars += len(func.get("arguments", ""))
                    total_chars += len(func.get("name", ""))
        return total_chars // 4
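
Back-of-envelope check of the widened heuristic: 8,000 chars of message content plus 1,200 chars of tool-call names and arguments estimates to (8000 + 1200) // 4 = 2300 tokens.

assert (8_000 + 1_200) // 4 == 2_300  # chars/4 heuristic, tool_calls included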

    def update_token_count(self, actual_input_tokens: int) -> None:

@@ -587,6 +648,138 @@ class NodeConversation:

        self._messages = [summary_msg] + recent_messages
        self._last_api_input_tokens = None  # reset; next LLM call will recalibrate

    async def compact_preserving_structure(
        self,
        spillover_dir: str,
        keep_recent: int = 4,
        phase_graduated: bool = False,
    ) -> None:
        """Structure-preserving compaction: save freeform text to file, keep tool messages.

        Unlike ``compact()`` which replaces ALL old messages with a single LLM
        summary, this method preserves the tool call structure (assistant
        messages with tool_calls + tool result messages) that are already tiny
        after pruning. Only freeform text exchanges (user messages,
        text-only assistant messages) are saved to a file and removed.

        The result: the agent retains exact knowledge of what tools it called,
        where each result is stored, and can load the conversation text if
        needed. No LLM summary call. No heuristics. Nothing lost.
        """
        if not self._messages:
            return

        total = len(self._messages)

        # Determine split point (same logic as compact)
        if phase_graduated and self._current_phase:
            split = self._find_phase_graduated_split()
        else:
            split = None

        if split is None:
            keep_recent = max(0, min(keep_recent, total - 1))
            split = total - keep_recent if keep_recent > 0 else total

        # Advance split past orphaned tool results at the boundary
        while split < total and self._messages[split].role == "tool":
            split += 1

        if split == 0:
            return

        old_messages = self._messages[:split]

        # Classify old messages: structural (keep) vs freeform (save to file)
        kept_structural: list[Message] = []
        freeform_lines: list[str] = []

        for msg in old_messages:
            if msg.role == "tool":
                # Tool results — already pruned to ~30 tokens (file reference).
                # Keep in conversation.
                kept_structural.append(msg)
            elif msg.role == "assistant" and msg.tool_calls:
                # Assistant message with tool_calls — keep the tool_calls
                # with truncated arguments, clear the freeform text content.
                compact_tcs = _compact_tool_calls(msg.tool_calls)
                kept_structural.append(
                    Message(
                        seq=msg.seq,
                        role=msg.role,
                        content="",
                        tool_calls=compact_tcs,
                        is_error=msg.is_error,
                        phase_id=msg.phase_id,
                        is_transition_marker=msg.is_transition_marker,
                    )
                )
            else:
                # Freeform text (user messages, text-only assistant messages)
                # — save to file and remove from conversation.
                role_label = msg.role
                text = msg.content
                if len(text) > 2000:
                    text = text[:2000] + "…"
                freeform_lines.append(f"[{role_label}] (seq={msg.seq}): {text}")

        # Write freeform text to a numbered conversation file
        spill_path = Path(spillover_dir)
        spill_path.mkdir(parents=True, exist_ok=True)

        # Find next conversation file number
        existing = sorted(spill_path.glob("conversation_*.md"))
        next_n = len(existing) + 1
        conv_filename = f"conversation_{next_n}.md"

        if freeform_lines:
            header = f"## Compacted conversation (messages 1-{split})\n\n"
            conv_text = header + "\n\n".join(freeform_lines)
            (spill_path / conv_filename).write_text(conv_text, encoding="utf-8")
        else:
            # Nothing to save — skip file creation
            conv_filename = ""

        # Build reference message
        if conv_filename:
            ref_content = (
                f"[Previous conversation saved to '{conv_filename}'. "
                f"Use load_data('{conv_filename}') to review if needed.]"
            )
        else:
            ref_content = "[Previous freeform messages compacted.]"
        # Use a seq just before the first kept message
        recent_messages = list(self._messages[split:])
        if kept_structural:
            ref_seq = kept_structural[0].seq - 1
        elif recent_messages:
            ref_seq = recent_messages[0].seq - 1
        else:
            ref_seq = self._next_seq
            self._next_seq += 1

        ref_msg = Message(seq=ref_seq, role="user", content=ref_content)

        # Persist: delete old messages from store, write reference + kept structural
        if self._store:
            first_kept_seq = (
                kept_structural[0].seq
                if kept_structural
                else (recent_messages[0].seq if recent_messages else self._next_seq)
            )
            # Delete everything before the first structural message we're keeping
            await self._store.delete_parts_before(first_kept_seq)
            # Write the reference message
            await self._store.write_part(ref_msg.seq, ref_msg.to_storage_dict())
            # Write kept structural messages (they may have been modified)
            for msg in kept_structural:
                await self._store.write_part(msg.seq, msg.to_storage_dict())
            await self._store.write_cursor({"next_seq": self._next_seq})

        # Reassemble: reference + kept structural (in original order) + recent
        self._messages = [ref_msg] + kept_structural + recent_messages
        self._last_api_input_tokens = None
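
The shape of the result, schematically (our notation, not from the diff): U = freeform user text, A(tc) = assistant with tool_calls, T = pruned tool result.

# before: [U, A(tc), T, U, A(tc), T, U, A]       with keep_recent=2
# after:  [U("see conversation_1.md"), A(tc), T, A(tc), T, U, A]
#          ^ reference message          ^ structure kept  ^ recent kept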

    def _find_phase_graduated_split(self) -> int | None:
        """Find split point that preserves current + previous phase.


@@ -103,7 +103,12 @@ FEEDBACK: (reason if RETRY, empty if ACCEPT)"""


def _extract_recent_context(conversation: NodeConversation, max_messages: int = 10) -> str:
    """Extract recent conversation messages for evaluation."""
    """Extract recent conversation messages for evaluation.

    Includes tool-call summaries from assistant messages so the judge
    can see what tools were invoked (especially set_output values) even
    when the assistant message body is empty.
    """
    messages = conversation.messages
    recent = messages[-max_messages:] if len(messages) > max_messages else messages

@@ -112,8 +117,24 @@ def _extract_recent_context(conversation: NodeConversation, max_messages: int =

        role = msg.role.upper()
        content = msg.content or ""
        # Truncate long tool results
        if msg.role == "tool" and len(content) > 200:
            content = content[:200] + "..."
        if msg.role == "tool" and len(content) > 500:
            content = content[:500] + "..."
        # For assistant messages with empty content but tool_calls,
        # summarise the tool calls so the judge knows what happened.
        if msg.role == "assistant" and not content.strip():
            tool_calls = getattr(msg, "tool_calls", None)
            if tool_calls:
                tc_parts = []
                for tc in tool_calls:
                    fn = tc.get("function", {}) if isinstance(tc, dict) else {}
                    name = fn.get("name", "")
                    args = fn.get("arguments", "")
                    if name == "set_output":
                        # Show the value so the judge can evaluate content quality
                        tc_parts.append(f" called {name}({args[:1000]})")
                    else:
                        tc_parts.append(f" called {name}(...)")
                content = "Tool calls:\n" + "\n".join(tc_parts)
        if content.strip():
            parts.append(f"[{role}]: {content.strip()}")

@@ -125,6 +146,10 @@ def _format_outputs(accumulator_state: dict[str, Any]) -> str:

    Lists and dicts get structural formatting so the judge can assess
    quantity and structure, not just a truncated stringification.

    String values are given a generous limit (2000 chars) so the judge
    can verify substantive content (e.g. a research brief with key
    questions, scope boundaries, and deliverables).
    """
    if not accumulator_state:
        return "(none)"

@@ -144,12 +169,12 @@ def _format_outputs(accumulator_state: dict[str, Any]) -> str:

                val_str += f"\n ... and {len(value) - 8} more"
        elif isinstance(value, dict):
            val_str = str(value)
            if len(val_str) > 400:
                val_str = val_str[:400] + "..."
            if len(val_str) > 2000:
                val_str = val_str[:2000] + "..."
        else:
            val_str = str(value)
            if len(val_str) > 300:
                val_str = val_str[:300] + "..."
            if len(val_str) > 2000:
                val_str = val_str[:2000] + "..."
        parts.append(f" {key}: {val_str}")
    return "\n".join(parts)


@@ -338,6 +338,10 @@ class AsyncEntryPointSpec(BaseModel):

    max_concurrent: int = Field(
        default=10, description="Maximum concurrent executions for this entry point"
    )
    max_resurrections: int = Field(
        default=3,
        description="Auto-restart on non-fatal failure (0 to disable)",
    )

    model_config = {"extra": "allow"}

@@ -644,6 +648,13 @@ class GraphSpec(BaseModel):

            for edge in self.get_outgoing_edges(current):
                to_visit.append(edge.target)

        # Also mark sub-agents as reachable (they're invoked via delegate_to_sub_agent, not edges)
        for node in self.nodes:
            if node.id in reachable:
                sub_agents = getattr(node, "sub_agents", []) or []
                for sub_agent_id in sub_agents:
                    reachable.add(sub_agent_id)

        # Build set of async entry point nodes for quick lookup
        async_entry_nodes = {ep.entry_node for ep in self.async_entry_points}

@@ -695,4 +706,48 @@ class GraphSpec(BaseModel):

            else:
                seen_keys[key] = node_id

        # GCU nodes must only be used as subagents
        gcu_node_ids = {n.id for n in self.nodes if n.node_type == "gcu"}
        if gcu_node_ids:
            # GCU nodes must not be entry nodes
            if self.entry_node in gcu_node_ids:
                errors.append(
                    f"GCU node '{self.entry_node}' is used as entry node. "
                    "GCU nodes must only be used as subagents via delegate_to_sub_agent()."
                )

            # GCU nodes must not be terminal nodes
            for term in self.terminal_nodes:
                if term in gcu_node_ids:
                    errors.append(
                        f"GCU node '{term}' is used as terminal node. "
                        "GCU nodes must only be used as subagents."
                    )

            # GCU nodes must not be connected via edges
            for edge in self.edges:
                if edge.source in gcu_node_ids:
                    errors.append(
                        f"GCU node '{edge.source}' is used as edge source (edge '{edge.id}'). "
                        "GCU nodes must only be used as subagents, not connected via edges."
                    )
                if edge.target in gcu_node_ids:
                    errors.append(
                        f"GCU node '{edge.target}' is used as edge target (edge '{edge.id}'). "
                        "GCU nodes must only be used as subagents, not connected via edges."
                    )

            # GCU nodes must be referenced in at least one parent's sub_agents
            referenced_subagents = set()
            for node in self.nodes:
                for sa_id in node.sub_agents or []:
                    referenced_subagents.add(sa_id)

            orphaned = gcu_node_ids - referenced_subagents
            for nid in orphaned:
                errors.append(
                    f"GCU node '{nid}' is not referenced in any node's sub_agents list. "
                    "GCU nodes must be declared as subagents of a parent node."
                )

        return errors
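
A spec fragment that satisfies all four GCU rules above (field subset is illustrative — real specs carry more fields): the gcu node appears only in a parent's sub_agents, never as entry, terminal, or edge endpoint.

graph = {
    "entry_node": "research",
    "terminal_nodes": ["research"],
    "edges": [],  # nothing may connect to "browse" via edges
    "nodes": [
        {"id": "research", "node_type": "event_loop", "sub_agents": ["browse"]},
        {"id": "browse", "node_type": "gcu", "sub_agents": []},
    ],
}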
|
||||
|
||||
@@ -193,6 +193,9 @@ class GraphExecutor:

        # Pause/resume control
        self._pause_requested = asyncio.Event()

        # Track the currently executing node for external injection routing
        self.current_node_id: str | None = None

    def _write_progress(
        self,
        current_node: str,

@@ -338,6 +341,9 @@ class GraphExecutor:

        cumulative_tool_names: set[str] = set()
        cumulative_output_keys: list[str] = []  # Output keys from all visited nodes

        # Build node registry for subagent lookup
        node_registry: dict[str, NodeSpec] = {node.id: node for node in graph.nodes}

        # Initialize checkpoint store if checkpointing is enabled
        checkpoint_store: CheckpointStore | None = None
        if checkpoint_config and checkpoint_config.enabled and self._storage_path:

@@ -694,6 +700,9 @@ class GraphExecutor:

                # Execute this node, then pause
                # (We'll check again after execution and save state)

            # Expose current node for external injection routing
            self.current_node_id = current_node_id

            self.logger.info(f"\n▶ Step {steps}: {node_spec.name} ({node_spec.node_type})")
            self.logger.info(f" Inputs: {node_spec.input_keys}")
            self.logger.info(f" Outputs: {node_spec.output_keys}")

@@ -709,6 +718,14 @@ class GraphExecutor:

                    if k not in cumulative_output_keys:
                        cumulative_output_keys.append(k)

            # Build resume narrative (Layer 2) when restoring a session
            # so the EventLoopNode can rebuild the full 3-layer system prompt.
            _resume_narrative = ""
            if _is_resuming and path:
                from framework.graph.prompt_composer import build_narrative

                _resume_narrative = build_narrative(memory, path, graph)

            # Build context for node
            ctx = self._build_context(
                node_spec=node_spec,

@@ -721,6 +738,10 @@ class GraphExecutor:

                override_tools=cumulative_tools if is_continuous else None,
                cumulative_output_keys=cumulative_output_keys if is_continuous else None,
                event_triggered=_event_triggered,
                node_registry=node_registry,
                identity_prompt=getattr(graph, "identity_prompt", ""),
                narrative=_resume_narrative,
                graph=graph,
            )

            # Log actual input data being read

@@ -1120,6 +1141,7 @@ class GraphExecutor:

                    source_result=result,
                    source_node_spec=node_spec,
                    path=path,
                    node_registry=node_registry,
                )

                total_tokens += branch_tokens

@@ -1277,19 +1299,48 @@ class GraphExecutor:

                        protect_tokens=2000,
                    )
                    if continuous_conversation.needs_compaction():
                        _phase_ratio = continuous_conversation.usage_ratio()
                        self.logger.info(
                            " Phase-boundary compaction (%.0f%% usage)",
                            continuous_conversation.usage_ratio() * 100,
                            _phase_ratio * 100,
                        )
                        summary = (
                            f"Summary of earlier phases (before {next_spec.name}). "
                            "See transition markers for phase details."
                        )
                        await continuous_conversation.compact(
                            summary,
                            keep_recent=4,
                            phase_graduated=True,
                        _data_dir = (
                            str(self._storage_path / "data") if self._storage_path else None
                        )
                        if _data_dir:
                            await continuous_conversation.compact_preserving_structure(
                                spillover_dir=_data_dir,
                                keep_recent=4,
                                phase_graduated=True,
                            )
                            # Circuit breaker: if still over budget, fall back
                            _post_ratio = continuous_conversation.usage_ratio()
                            if _post_ratio >= 0.9 * _phase_ratio:
                                self.logger.warning(
                                    " Structure-preserving compaction ineffective "
                                    "(%.0f%% -> %.0f%%), falling back to summary",
                                    _phase_ratio * 100,
                                    _post_ratio * 100,
                                )
                                summary = (
                                    f"Summary of earlier phases (before {next_spec.name}). "
                                    "See transition markers for phase details."
                                )
                                await continuous_conversation.compact(
                                    summary,
                                    keep_recent=4,
                                    phase_graduated=True,
                                )
                        else:
                            summary = (
                                f"Summary of earlier phases (before {next_spec.name}). "
                                "See transition markers for phase details."
                            )
                            await continuous_conversation.compact(
                                summary,
                                keep_recent=4,
                                phase_graduated=True,
                            )
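
The circuit breaker in numbers (illustrative): at 80% usage before the structure-preserving pass, anything still at or above 72% (0.9 × 80%) afterwards counts as ineffective and triggers the summary fallback.

_phase_ratio, _post_ratio = 0.80, 0.75
assert _post_ratio >= 0.9 * _phase_ratio  # 0.75 >= 0.72 -> fall back to summary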

            # Update input_data for next node
            input_data = result.output

@@ -1541,6 +1592,10 @@ class GraphExecutor:

        override_tools: list | None = None,
        cumulative_output_keys: list[str] | None = None,
        event_triggered: bool = False,
        identity_prompt: str = "",
        narrative: str = "",
        node_registry: dict[str, NodeSpec] | None = None,
        graph: "GraphSpec | None" = None,
    ) -> NodeContext:
        """Build execution context for a node."""
        # Filter tools to those available to this node

@@ -1569,6 +1624,8 @@ class GraphExecutor:

            node_tool_names=node_spec.tools,
        )

        goal_context = goal.to_prompt_context()

        return NodeContext(
            runtime=self.runtime,
            node_id=node_spec.id,

@@ -1577,7 +1634,7 @@ class GraphExecutor:

            input_data=input_data,
            llm=self.llm,
            available_tools=available_tools,
            goal_context=goal.to_prompt_context(),
            goal_context=goal_context,
            goal=goal,  # Pass Goal object for LLM-powered routers
            max_tokens=max_tokens,
            runtime_logger=self.runtime_logger,

@@ -1587,12 +1644,18 @@ class GraphExecutor:

            cumulative_output_keys=cumulative_output_keys or [],
            event_triggered=event_triggered,
            accounts_prompt=node_accounts_prompt,
            identity_prompt=identity_prompt,
            narrative=narrative,
            execution_id=self._execution_id,
            stream_id=self._stream_id,
            node_registry=node_registry or {},
            all_tools=list(self.tools),  # Full catalog for subagent tool resolution
            shared_node_registry=self.node_registry,  # For subagent escalation routing
        )

    VALID_NODE_TYPES = {
        "event_loop",
        "gcu",
    }
    # Node types removed in v0.5 — provide migration guidance
    REMOVED_NODE_TYPES = {

@@ -1627,8 +1690,8 @@ class GraphExecutor:

                f"Must be one of: {sorted(self.VALID_NODE_TYPES)}."
            )

        # Create based on type (only event_loop is valid)
        if node_spec.node_type == "event_loop":
        # Create based on type
        if node_spec.node_type in ("event_loop", "gcu"):
            # Auto-create EventLoopNode with sensible defaults.
            # Custom configs can still be pre-registered via node_registry.
            from framework.graph.event_loop_node import EventLoopNode, LoopConfig

@@ -1658,11 +1721,11 @@ class GraphExecutor:

                judge=None,  # implicit judge: accept when output_keys are filled
                config=LoopConfig(
                    max_iterations=lc.get("max_iterations", default_max_iter),
                    max_tool_calls_per_turn=lc.get("max_tool_calls_per_turn", 10),
                    max_tool_calls_per_turn=lc.get("max_tool_calls_per_turn", 30),
                    tool_call_overflow_margin=lc.get("tool_call_overflow_margin", 0.5),
                    stall_detection_threshold=lc.get("stall_detection_threshold", 3),
                    max_history_tokens=lc.get("max_history_tokens", 32000),
                    max_tool_result_chars=lc.get("max_tool_result_chars", 3_000),
                    max_tool_result_chars=lc.get("max_tool_result_chars", 30_000),
                    spillover_dir=spillover,
                ),
                tool_executor=self.tool_executor,

@@ -1845,6 +1908,7 @@ class GraphExecutor:

        source_result: NodeResult,
        source_node_spec: Any,
        path: list[str],
        node_registry: dict[str, NodeSpec] | None = None,
    ) -> tuple[dict[str, NodeResult], int, int]:
        """
        Execute multiple branches in parallel using asyncio.gather.

@@ -1942,7 +2006,15 @@ class GraphExecutor:

                    branch.retry_count = attempt

                    # Build context for this branch
                    ctx = self._build_context(node_spec, memory, goal, mapped, graph.max_tokens)
                    ctx = self._build_context(
                        node_spec,
                        memory,
                        goal,
                        mapped,
                        graph.max_tokens,
                        node_registry=node_registry,
                        graph=graph,
                    )
                    node_impl = self._get_node_implementation(node_spec, graph.cleanup_llm_model)

                    # Emit node-started event (skip event_loop nodes)
@@ -0,0 +1,23 @@

"""File tools MCP server constants.

Analogous to ``gcu.py`` — defines the server name and default stdio config
so the runner can auto-register the files MCP server for any agent that has
``event_loop`` or ``gcu`` nodes.
"""

# ---------------------------------------------------------------------------
# MCP server identity
# ---------------------------------------------------------------------------

FILES_MCP_SERVER_NAME = "files-tools"
"""Name used to identify the file tools MCP server in ``mcp_servers.json``."""

FILES_MCP_SERVER_CONFIG: dict = {
    "name": FILES_MCP_SERVER_NAME,
    "transport": "stdio",
    "command": "uv",
    "args": ["run", "python", "files_server.py", "--stdio"],
    "cwd": "../../tools",
    "description": "File tools for reading, writing, editing, and searching files",
}
"""Default stdio config for the file tools MCP server (relative to exports/<agent>/)."""
@@ -0,0 +1,86 @@

"""GCU (browser automation) node type constants.

A ``gcu`` node is an ``event_loop`` node with two automatic enhancements:
1. A canonical browser best-practices system prompt is prepended.
2. All tools from the GCU MCP server are auto-included.

No new ``NodeProtocol`` subclass — the ``gcu`` type is purely a declarative
signal processed by the runner and executor at setup time.
"""

# ---------------------------------------------------------------------------
# MCP server identity
# ---------------------------------------------------------------------------

GCU_SERVER_NAME = "gcu-tools"
"""Name used to identify the GCU MCP server in ``mcp_servers.json``."""

GCU_MCP_SERVER_CONFIG: dict = {
    "name": GCU_SERVER_NAME,
    "transport": "stdio",
    "command": "uv",
    "args": ["run", "python", "-m", "gcu.server", "--stdio"],
    "cwd": "../../tools",
    "description": "GCU tools for browser automation",
}
"""Default stdio config for the GCU MCP server (relative to exports/<agent>/)."""

# ---------------------------------------------------------------------------
# Browser best-practices system prompt
# ---------------------------------------------------------------------------

GCU_BROWSER_SYSTEM_PROMPT = """\
# Browser Automation Best Practices

Follow these rules for reliable, efficient browser interaction.

## Reading Pages
- ALWAYS prefer `browser_snapshot` over `browser_get_text("body")`
  — it returns a compact ~1-5 KB accessibility tree vs 100+ KB of raw HTML.
- Use `browser_snapshot_aria` when you need full ARIA properties
  for detailed element inspection.
- Do NOT use `browser_screenshot` for reading text content
  — it produces huge base64 images with no searchable text.
- Only fall back to `browser_get_text` for extracting specific
  small elements by CSS selector.

## Navigation & Waiting
- Always call `browser_wait` after navigation actions
  (`browser_open`, `browser_navigate`, `browser_click` on links)
  to let the page load.
- NEVER re-navigate to the same URL after scrolling
  — this resets your scroll position and loses loaded content.

## Scrolling
- Use large scroll amounts (~2000) when loading more content
  — sites like Twitter and LinkedIn lazy-load content as you page through.
- After scrolling, take a new `browser_snapshot` to see updated content.

## Error Recovery
- If a tool fails, retry once with the same approach.
- If it fails a second time, STOP retrying and switch approach.
- If `browser_snapshot` fails → try `browser_get_text` with a
  specific small selector as fallback.
- If `browser_open` fails or page seems stale → `browser_stop`,
  then `browser_start`, then retry.

## Tab Management
- Use `browser_tabs` to list open tabs when managing multiple pages.
- Pass `target_id` to tools when operating on a specific tab.
- Open background tabs with `browser_open(url=..., background=true)`
  to avoid losing your current context.
- Close tabs you no longer need with `browser_close` to free resources.

## Login & Auth Walls
- If you see a "Log in" or "Sign up" prompt instead of expected
  content, report the auth wall immediately — do NOT attempt to log in.
- Check for cookie consent banners and dismiss them if they block content.

## Efficiency
- Minimize tool calls — combine actions where possible.
- When a snapshot result is saved to a spillover file, use
  `run_command` with grep to extract specific data rather than
  re-reading the full file.
- Call `set_output` in the same turn as your last browser action
  when possible — don't waste a turn.
"""
@@ -176,7 +176,17 @@ class Goal(BaseModel):

        return True

    def to_prompt_context(self) -> str:
        """Generate context string for LLM prompts."""
        """Generate context string for LLM prompts.

        Returns empty string when the goal is a stub (no success criteria,
        no constraints, no context). Stub goals are metadata-only — used for
        graph identification but not communicated to the LLM as actionable
        intent. This prevents runtime agents (e.g. the queen) from
        misinterpreting their own goal as a user request.
        """
        if not self.success_criteria and not self.constraints and not self.context:
            return ""

        lines = [
            f"# Goal: {self.name}",
            f"{self.description}",

@@ -166,7 +166,7 @@ class NodeSpec(BaseModel):

    # Node behavior type
    node_type: str = Field(
        default="event_loop",
        description="Type: 'event_loop' (recommended), 'router', 'human_input'.",
        description="Type: 'event_loop' (recommended), 'gcu' (browser automation).",
    )

    # Data flow

@@ -204,6 +204,16 @@ class NodeSpec(BaseModel):

        default=None, description="Specific model to use (defaults to graph default)"
    )

    # For subagent delegation
    sub_agents: list[str] = Field(
        default_factory=list,
        description="Node IDs that can be invoked as subagents from this node",
    )
    # For function nodes
    function: str | None = Field(
        default=None, description="Function name or path for function nodes"
    )

    # For router nodes
    routes: dict[str, str] = Field(
        default_factory=dict, description="Condition -> target_node_id mapping for routers"

@@ -505,6 +515,11 @@ class NodeContext:

    # Connected accounts prompt (injected from runner)
    accounts_prompt: str = ""

    # Resume context — Layer 1 (identity) and Layer 2 (narrative) for
    # rebuilding the full system prompt when restoring from conversation store.
    identity_prompt: str = ""
    narrative: str = ""

    # Event-triggered execution (no interactive user attached)
    event_triggered: bool = False

@@ -515,6 +530,20 @@ class NodeContext:

    # Falls back to node_id when not set (legacy / standalone executor).
    stream_id: str = ""

    # Subagent mode
    is_subagent_mode: bool = False  # True when running as a subagent (prevents nested delegation)
    report_callback: Any = None  # async (message: str, data: dict | None) -> None
    node_registry: dict[str, "NodeSpec"] = field(default_factory=dict)  # For subagent lookup

    # Full tool catalog (unfiltered) — used by _execute_subagent to resolve
    # subagent tools that aren't in the parent node's filtered available_tools.
    all_tools: list[Tool] = field(default_factory=list)

    # Shared reference to the executor's node_registry — used by subagent
    # escalation (_EscalationReceiver) to register temporary receivers that
    # the inject_input() routing chain can find.
    shared_node_registry: dict[str, Any] = field(default_factory=dict)


@dataclass
class NodeResult:
@@ -280,7 +280,7 @@ def build_transition_marker(

    ]
    if file_lines:
        sections.append(
            "\nData files (use load_data to access):\n" + "\n".join(file_lines)
            "\nData files (use read_file to access):\n" + "\n".join(file_lines)
        )

    # Agent working memory
@@ -1,11 +1,10 @@

"""Anthropic Claude LLM provider - backward compatible wrapper around LiteLLM."""

import os
from collections.abc import Callable
from typing import Any

from framework.llm.litellm import LiteLLMProvider
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.provider import LLMProvider, LLMResponse, Tool


def _get_api_key_from_credential_store() -> str | None:

@@ -83,23 +82,6 @@ class AnthropicProvider(LLMProvider):

            max_retries=max_retries,
        )

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """Run a tool-use loop until Claude produces a final response (via LiteLLM)."""
        return self._provider.complete_with_tools(
            messages=messages,
            system=system,
            tools=tools,
            tool_executor=tool_executor,
            max_iterations=max_iterations,
        )

    async def acomplete(
        self,
        messages: list[dict[str, Any]],

@@ -120,20 +102,3 @@ class AnthropicProvider(LLMProvider):

            json_mode=json_mode,
            max_retries=max_retries,
        )

    async def acomplete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """Async tool-use loop via LiteLLM."""
        return await self._provider.acomplete_with_tools(
            messages=messages,
            system=system,
            tools=tools,
            tool_executor=tool_executor,
            max_iterations=max_iterations,
        )
@@ -11,7 +11,7 @@ import asyncio

import json
import logging
import time
from collections.abc import AsyncIterator, Callable
from collections.abc import AsyncIterator
from datetime import datetime
from pathlib import Path
from typing import Any

@@ -23,7 +23,7 @@ except ImportError:

    litellm = None  # type: ignore[assignment]
    RateLimitError = Exception  # type: ignore[assignment, misc]

from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import StreamEvent

logger = logging.getLogger(__name__)

@@ -70,13 +70,59 @@ def _patch_litellm_anthropic_oauth() -> None:

    AnthropicModelInfo.validate_environment = _patched_validate_environment


def _patch_litellm_metadata_nonetype() -> None:
    """Patch litellm entry points to prevent metadata=None TypeError.

    litellm bug: the @client decorator in utils.py has four places that do
        "model_group" in kwargs.get("metadata", {})
    but kwargs["metadata"] can be explicitly None (set internally by
    litellm_params), causing:
        TypeError: argument of type 'NoneType' is not iterable
    This masks the real API error with a confusing APIConnectionError.

    Fix: wrap the four litellm entry points (completion, acompletion,
    responses, aresponses) to pop metadata=None before the @client
    decorator's error handler can crash on it.
    """
    import functools

    for fn_name in ("completion", "acompletion", "responses", "aresponses"):
        original = getattr(litellm, fn_name, None)
        if original is None:
            continue
        if asyncio.iscoroutinefunction(original):

            @functools.wraps(original)
            async def _async_wrapper(*args, _orig=original, **kwargs):
                if kwargs.get("metadata") is None:
                    kwargs.pop("metadata", None)
                return await _orig(*args, **kwargs)

            setattr(litellm, fn_name, _async_wrapper)
        else:

            @functools.wraps(original)
            def _sync_wrapper(*args, _orig=original, **kwargs):
                if kwargs.get("metadata") is None:
                    kwargs.pop("metadata", None)
                return _orig(*args, **kwargs)

            setattr(litellm, fn_name, _sync_wrapper)
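
The guard in isolation (hypothetical kwargs, no litellm required): the wrapper simply removes an explicit None before the decorated entry point can evaluate `"model_group" in None`.

def _strip_none_metadata(**kwargs):
    if kwargs.get("metadata") is None:
        kwargs.pop("metadata", None)  # avoid '"model_group" in None'
    return kwargs

assert "metadata" not in _strip_none_metadata(metadata=None)
assert _strip_none_metadata(metadata={"model_group": "x"})["metadata"]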
|
||||
|
||||
|
||||
if litellm is not None:
|
||||
_patch_litellm_anthropic_oauth()
|
||||
_patch_litellm_metadata_nonetype()
|
||||
|
||||
RATE_LIMIT_MAX_RETRIES = 10
|
||||
RATE_LIMIT_BACKOFF_BASE = 2 # seconds
|
||||
RATE_LIMIT_MAX_DELAY = 120 # seconds - cap to prevent absurd waits
|
||||
|
||||
# Empty-stream retries use a short fixed delay, not the rate-limit backoff.
|
||||
# Conversation-structure issues are deterministic — long waits don't help.
|
||||
EMPTY_STREAM_MAX_RETRIES = 3
|
||||
EMPTY_STREAM_RETRY_DELAY = 1.0 # seconds
|
||||
|
||||
# Directory for dumping failed requests
|
||||
FAILED_REQUESTS_DIR = Path.home() / ".hive" / "failed_requests"
|
||||
|
||||
@@ -191,6 +237,11 @@ def _is_stream_transient_error(exc: BaseException) -> bool:
|
||||
|
||||
Transient errors (recoverable=True): network issues, server errors, timeouts.
|
||||
Permanent errors (recoverable=False): auth, bad request, context window, etc.
|
||||
|
||||
NOTE: "Failed to parse tool call arguments" (malformed LLM output) is NOT
|
||||
transient at the stream level — retrying with the same messages produces the
|
||||
same malformed output. This error is handled at the EventLoopNode level
|
||||
where the conversation can be modified before retrying.
|
||||
"""
|
||||
try:
|
||||
from litellm.exceptions import (
|
||||
@@ -284,6 +335,12 @@ class LiteLLMProvider(LLMProvider):
|
||||
"LiteLLM is not installed. Please install it with: uv pip install litellm"
|
||||
)
|
||||
|
||||
# Note: The Codex ChatGPT backend is a Responses API endpoint at
|
||||
# chatgpt.com/backend-api/codex/responses. LiteLLM's model registry
|
||||
# correctly marks codex models with mode="responses", so we do NOT
|
||||
# override the mode. The responses_api_bridge in litellm handles
|
||||
# converting Chat Completions requests to Responses API format.
|
||||
|
||||
def _completion_with_rate_limit_retry(
|
||||
self, max_retries: int | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
@@ -396,43 +453,6 @@ class LiteLLMProvider(LLMProvider):
        # unreachable, but satisfies type checker
        raise RuntimeError("Exhausted rate limit retries")

    def _codex_sync_complete(self, kwargs: dict[str, Any]) -> "LLMResponse":
        """Collect a streaming Codex response into a single LLMResponse.

        The ChatGPT Codex backend only supports ``stream=True``, so non-streaming
        callers go through this helper which forces streaming, accumulates the
        chunks, and returns the same LLMResponse that ``complete()`` would.
        """
        kwargs["stream"] = True
        response = litellm.completion(**kwargs)  # type: ignore[union-attr]
        content = ""
        model_name = self.model
        input_tokens = 0
        output_tokens = 0
        finish_reason = ""
        for chunk in response:
            choice = chunk.choices[0] if chunk.choices else None
            if not choice:
                continue
            delta = choice.delta
            if delta and delta.content:
                content += delta.content
            if choice.finish_reason:
                finish_reason = choice.finish_reason
            if hasattr(chunk, "usage") and chunk.usage:
                input_tokens = getattr(chunk.usage, "prompt_tokens", 0) or 0
                output_tokens = getattr(chunk.usage, "completion_tokens", 0) or 0
            if chunk.model:
                model_name = chunk.model
        return LLMResponse(
            content=content,
            model=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            stop_reason=finish_reason,
            raw_response=None,
        )

    def complete(
        self,
        messages: list[dict[str, Any]],
@@ -444,6 +464,21 @@ class LiteLLMProvider(LLMProvider):
        max_retries: int | None = None,
    ) -> LLMResponse:
        """Generate a completion using LiteLLM."""
        # Codex ChatGPT backend requires streaming — delegate to the unified
        # async streaming path which properly handles tool calls.
        if self._codex_backend:
            return asyncio.run(
                self.acomplete(
                    messages=messages,
                    system=system,
                    tools=tools,
                    max_tokens=max_tokens,
                    response_format=response_format,
                    json_mode=json_mode,
                    max_retries=max_retries,
                )
            )

        # Prepare messages with system prompt
        full_messages = []
        if system:
@@ -481,11 +516,6 @@ class LiteLLMProvider(LLMProvider):
        if response_format:
            kwargs["response_format"] = response_format

        # Codex ChatGPT backend requires streaming and rejects max_output_tokens.
        if self._codex_backend:
            kwargs.pop("max_tokens", None)
            return self._codex_sync_complete(kwargs)

        # Make the call
        response = self._completion_with_rate_limit_retry(max_retries=max_retries, **kwargs)

@@ -511,127 +541,6 @@ class LiteLLMProvider(LLMProvider):
            raw_response=response,
        )

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
        max_tokens: int = 4096,
    ) -> LLMResponse:
        """Run a tool-use loop until the LLM produces a final response."""
        # Prepare messages with system prompt
        current_messages = []
        if system:
            current_messages.append({"role": "system", "content": system})
        current_messages.extend(messages)

        total_input_tokens = 0
        total_output_tokens = 0

        # Convert tools to OpenAI format
        openai_tools = [self._tool_to_openai_format(t) for t in tools]

        for _ in range(max_iterations):
            # Build kwargs
            kwargs: dict[str, Any] = {
                "model": self.model,
                "messages": current_messages,
                "max_tokens": max_tokens,
                "tools": openai_tools,
                **self.extra_kwargs,
            }

            if self.api_key:
                kwargs["api_key"] = self.api_key
            if self.api_base:
                kwargs["api_base"] = self.api_base

            response = self._completion_with_rate_limit_retry(**kwargs)

            # Track tokens
            usage = response.usage
            if usage:
                total_input_tokens += usage.prompt_tokens
                total_output_tokens += usage.completion_tokens

            choice = response.choices[0]
            message = choice.message

            # Check if we're done (no tool calls)
            if choice.finish_reason == "stop" or not message.tool_calls:
                return LLMResponse(
                    content=message.content or "",
                    model=response.model or self.model,
                    input_tokens=total_input_tokens,
                    output_tokens=total_output_tokens,
                    stop_reason=choice.finish_reason or "stop",
                    raw_response=response,
                )

            # Process tool calls.
            # Add assistant message with tool calls.
            current_messages.append(
                {
                    "role": "assistant",
                    "content": message.content,
                    "tool_calls": [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }
                        for tc in message.tool_calls
                    ],
                }
            )

            # Execute tools and add results.
            for tool_call in message.tool_calls:
                try:
                    args = json.loads(tool_call.function.arguments)
                except json.JSONDecodeError:
                    # Surface error to LLM and skip tool execution
                    current_messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": "Invalid JSON arguments provided to tool.",
                        }
                    )
                    continue

                tool_use = ToolUse(
                    id=tool_call.id,
                    name=tool_call.function.name,
                    input=args,
                )

                result = tool_executor(tool_use)

                # Add tool result message
                current_messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": result.tool_use_id,
                        "content": result.content,
                    }
                )

        # Max iterations reached
        return LLMResponse(
            content="Max tool iterations reached",
            model=self.model,
            input_tokens=total_input_tokens,
            output_tokens=total_output_tokens,
            stop_reason="max_iterations",
            raw_response=None,
        )

    # ------------------------------------------------------------------
    # Async variants — non-blocking on the event loop
    # ------------------------------------------------------------------
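The provider-level tool loop above is removed in this change; tool handling moves to the unified streaming path. A minimal sketch of the surviving call path, using only the `acomplete` parameters shown in the next hunk (the surrounding application code is hypothetical):

async def ask(provider) -> str:
    # `provider` is assumed to be a configured LiteLLMProvider.
    response = await provider.acomplete(
        messages=[{"role": "user", "content": "Summarize RFC 2119 in one line."}],
        system="Answer briefly.",
    )
    return response.content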
@@ -754,6 +663,19 @@ class LiteLLMProvider(LLMProvider):
        max_retries: int | None = None,
    ) -> LLMResponse:
        """Async version of complete(). Uses litellm.acompletion — non-blocking."""
        # Codex ChatGPT backend requires streaming — route through stream() which
        # already handles Codex quirks and has proper tool call accumulation.
        if self._codex_backend:
            stream_iter = self.stream(
                messages=messages,
                system=system,
                tools=tools,
                max_tokens=max_tokens,
                response_format=response_format,
                json_mode=json_mode,
            )
            return await self._collect_stream_to_response(stream_iter)

        full_messages: list[dict[str, Any]] = []
        if system:
            full_messages.append({"role": "system", "content": system})
@@ -782,11 +704,6 @@ class LiteLLMProvider(LLMProvider):
        if response_format:
            kwargs["response_format"] = response_format

        # Codex ChatGPT backend requires streaming and rejects max_output_tokens.
        if self._codex_backend:
            kwargs.pop("max_tokens", None)
            return await self._codex_async_complete(kwargs)

        response = await self._acompletion_with_rate_limit_retry(max_retries=max_retries, **kwargs)

        content = response.choices[0].message.content or ""
@@ -803,147 +720,6 @@ class LiteLLMProvider(LLMProvider):
            raw_response=response,
        )

    async def _codex_async_complete(self, kwargs: dict[str, Any]) -> "LLMResponse":
        """Async version of _codex_sync_complete."""
        kwargs["stream"] = True
        response = await litellm.acompletion(**kwargs)  # type: ignore[union-attr]
        content = ""
        model_name = self.model
        input_tokens = 0
        output_tokens = 0
        finish_reason = ""
        async for chunk in response:
            choice = chunk.choices[0] if chunk.choices else None
            if not choice:
                continue
            delta = choice.delta
            if delta and delta.content:
                content += delta.content
            if choice.finish_reason:
                finish_reason = choice.finish_reason
            if hasattr(chunk, "usage") and chunk.usage:
                input_tokens = getattr(chunk.usage, "prompt_tokens", 0) or 0
                output_tokens = getattr(chunk.usage, "completion_tokens", 0) or 0
            if chunk.model:
                model_name = chunk.model
        return LLMResponse(
            content=content,
            model=model_name,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            stop_reason=finish_reason,
            raw_response=None,
        )

    async def acomplete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
        max_tokens: int = 4096,
    ) -> LLMResponse:
        """Async version of complete_with_tools(). Uses litellm.acompletion — non-blocking."""
        current_messages: list[dict[str, Any]] = []
        if system:
            current_messages.append({"role": "system", "content": system})
        current_messages.extend(messages)

        total_input_tokens = 0
        total_output_tokens = 0
        openai_tools = [self._tool_to_openai_format(t) for t in tools]

        for _ in range(max_iterations):
            kwargs: dict[str, Any] = {
                "model": self.model,
                "messages": current_messages,
                "max_tokens": max_tokens,
                "tools": openai_tools,
                **self.extra_kwargs,
            }

            if self.api_key:
                kwargs["api_key"] = self.api_key
            if self.api_base:
                kwargs["api_base"] = self.api_base

            response = await self._acompletion_with_rate_limit_retry(**kwargs)

            usage = response.usage
            if usage:
                total_input_tokens += usage.prompt_tokens
                total_output_tokens += usage.completion_tokens

            choice = response.choices[0]
            message = choice.message

            if choice.finish_reason == "stop" or not message.tool_calls:
                return LLMResponse(
                    content=message.content or "",
                    model=response.model or self.model,
                    input_tokens=total_input_tokens,
                    output_tokens=total_output_tokens,
                    stop_reason=choice.finish_reason or "stop",
                    raw_response=response,
                )

            current_messages.append(
                {
                    "role": "assistant",
                    "content": message.content,
                    "tool_calls": [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }
                        for tc in message.tool_calls
                    ],
                }
            )

            for tool_call in message.tool_calls:
                try:
                    args = json.loads(tool_call.function.arguments)
                except json.JSONDecodeError:
                    current_messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": "Invalid JSON arguments provided to tool.",
                        }
                    )
                    continue

                tool_use = ToolUse(
                    id=tool_call.id,
                    name=tool_call.function.name,
                    input=args,
                )

                result = tool_executor(tool_use)

                current_messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": result.tool_use_id,
                        "content": result.content,
                    }
                )

        return LLMResponse(
            content="Max tool iterations reached",
            model=self.model,
            input_tokens=total_input_tokens,
            output_tokens=total_output_tokens,
            stop_reason="max_iterations",
            raw_response=None,
        )

    def _tool_to_openai_format(self, tool: Tool) -> dict[str, Any]:
        """Convert Tool to OpenAI function calling format."""
        return {
@@ -965,6 +741,8 @@ class LiteLLMProvider(LLMProvider):
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
        response_format: dict[str, Any] | None = None,
        json_mode: bool = False,
    ) -> AsyncIterator[StreamEvent]:
        """Stream a completion via litellm.acompletion(stream=True).

@@ -989,6 +767,31 @@ class LiteLLMProvider(LLMProvider):
            full_messages.append({"role": "system", "content": system})
        full_messages.extend(messages)

        # Codex Responses API requires an `instructions` field (system prompt).
        # Inject a minimal one when callers don't provide a system message.
        if self._codex_backend and not any(m["role"] == "system" for m in full_messages):
            full_messages.insert(0, {"role": "system", "content": "You are a helpful assistant."})

        # Add JSON mode via prompt engineering (works across all providers)
        if json_mode:
            json_instruction = "\n\nPlease respond with a valid JSON object."
            if full_messages and full_messages[0]["role"] == "system":
                full_messages[0]["content"] += json_instruction
            else:
                full_messages.insert(0, {"role": "system", "content": json_instruction.strip()})

        # Remove ghost empty assistant messages (content="" and no tool_calls).
        # These arise when a model returns an empty stream after a tool result
        # (an "expected" no-op turn). Keeping them in history confuses some
        # models (notably Codex/gpt-5.3) and causes cascading empty streams.
        full_messages = [
            m
            for m in full_messages
            if not (
                m.get("role") == "assistant" and not m.get("content") and not m.get("tool_calls")
            )
        ]

        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": full_messages,
@@ -1003,7 +806,9 @@ class LiteLLMProvider(LLMProvider):
            kwargs["api_base"] = self.api_base
        if tools:
            kwargs["tools"] = [self._tool_to_openai_format(t) for t in tools]
        # The Codex ChatGPT backend rejects max_output_tokens and stream_options.
        if response_format:
            kwargs["response_format"] = response_format
        # The Codex ChatGPT backend (Responses API) rejects several params.
        if self._codex_backend:
            kwargs.pop("max_tokens", None)
            kwargs.pop("stream_options", None)
@@ -1015,6 +820,7 @@ class LiteLLMProvider(LLMProvider):
        tail_events: list[StreamEvent] = []
        accumulated_text = ""
        tool_calls_acc: dict[int, dict[str, str]] = {}
        _last_tool_idx = 0  # tracks most recently opened tool call slot
        input_tokens = 0
        output_tokens = 0
        stream_finish_reason: str | None = None
@@ -1038,9 +844,36 @@ class LiteLLMProvider(LLMProvider):
                    )

                    # --- Tool calls (accumulate across chunks) ---
                    # The Codex/Responses API bridge (litellm bug) hardcodes
                    # index=0 on every ChatCompletionToolCallChunk, even for
                    # parallel tool calls. We work around this by using tc.id
                    # (set on output_item.added events) as a "new tool call"
                    # signal and tracking the most recently opened slot for
                    # argument deltas that arrive with id=None.
                    if delta and delta.tool_calls:
                        for tc in delta.tool_calls:
                            idx = tc.index if hasattr(tc, "index") and tc.index is not None else 0

                            if tc.id:
                                # New tool call announced (or done event re-sent).
                                # Check if this id already has a slot.
                                existing_idx = next(
                                    (k for k, v in tool_calls_acc.items() if v["id"] == tc.id),
                                    None,
                                )
                                if existing_idx is not None:
                                    idx = existing_idx
                                elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in (
                                    "",
                                    tc.id,
                                ):
                                    # Slot taken by a different call — assign new index
                                    idx = max(tool_calls_acc.keys()) + 1
                                _last_tool_idx = idx
                            else:
                                # Argument delta with no id — route to last opened slot
                                idx = _last_tool_idx

                            if idx not in tool_calls_acc:
                                tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""}
                            if tc.id:
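A standalone simulation of the slot-assignment workaround above, using synthetic deltas (hypothetical data, not part of the diff). Two parallel calls whose chunks all arrive with index=0 still land in separate slots:

tool_calls_acc: dict[int, dict[str, str]] = {}
_last_tool_idx = 0
chunks = [  # (id, name, argument-delta) — synthetic
    ("call_a", "search", '{"q":'),
    (None, None, '"cats"}'),
    ("call_b", "fetch", '{"url":'),
    (None, None, '"https://example.com"}'),
]
for tc_id, name, args in chunks:
    idx = 0  # the buggy bridge hardcodes index=0 on every chunk
    if tc_id:
        existing = next((k for k, v in tool_calls_acc.items() if v["id"] == tc_id), None)
        if existing is not None:
            idx = existing
        elif idx in tool_calls_acc and tool_calls_acc[idx]["id"] not in ("", tc_id):
            idx = max(tool_calls_acc) + 1  # slot taken — open a new one
        _last_tool_idx = idx
    else:
        idx = _last_tool_idx  # id-less argument delta routes to last opened slot
    slot = tool_calls_acc.setdefault(idx, {"id": "", "name": "", "arguments": ""})
    if tc_id:
        slot["id"], slot["name"] = tc_id, name
    slot["arguments"] += args

assert tool_calls_acc[0]["id"] == "call_a" and tool_calls_acc[1]["id"] == "call_b"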
@@ -1088,27 +921,7 @@ class LiteLLMProvider(LLMProvider):
                # (If text deltas were yielded above, has_content is True
                # and we skip the retry path — nothing was yielded in vain.)
                has_content = accumulated_text or tool_calls_acc
                if not has_content and attempt < RATE_LIMIT_MAX_RETRIES:
                    # If the conversation ends with an assistant or tool
                    # message, an empty stream is expected — the LLM has
                    # nothing new to say. Don't burn retries on this;
                    # let the caller (EventLoopNode) decide what to do.
                    # Typical case: client_facing node where the LLM set
                    # all outputs via set_output tool calls, and the tool
                    # results are the last messages.
                    last_role = next(
                        (m["role"] for m in reversed(full_messages) if m.get("role") != "system"),
                        None,
                    )
                    if last_role in ("assistant", "tool"):
                        logger.debug(
                            "[stream] Empty response after %s message — expected, not retrying.",
                            last_role,
                        )
                        for event in tail_events:
                            yield event
                        return

                if not has_content:
                    # finish_reason=length means the model exhausted
                    # max_tokens before producing content. Retrying with
                    # the same max_tokens will never help.
@@ -1126,28 +939,49 @@ class LiteLLMProvider(LLMProvider):
                        yield event
                        return

                    wait = _compute_retry_delay(attempt)
                    token_count, token_method = _estimate_tokens(
                        self.model,
                        full_messages,
                    # Empty stream — always retry regardless of last message
                    # role. Ghost empty streams after tool results are NOT
                    # expected no-ops; they create infinite loops when the
                    # conversation doesn't change between iterations.
                    # After retries, return the empty result and let the
                    # caller (EventLoopNode) decide how to handle it.
                    last_role = next(
                        (m["role"] for m in reversed(full_messages) if m.get("role") != "system"),
                        None,
                    )
                    dump_path = _dump_failed_request(
                        model=self.model,
                        kwargs=kwargs,
                        error_type="empty_stream",
                        attempt=attempt,
                    )
                    logger.warning(
                        f"[stream-retry] {self.model} returned empty stream — "
                        f"~{token_count} tokens ({token_method}). "
                        f"Request dumped to: {dump_path}. "
                        f"Retrying in {wait}s "
                        f"(attempt {attempt + 1}/{RATE_LIMIT_MAX_RETRIES})"
                    )
                    await asyncio.sleep(wait)
                    continue
                    if attempt < EMPTY_STREAM_MAX_RETRIES:
                        token_count, token_method = _estimate_tokens(
                            self.model,
                            full_messages,
                        )
                        dump_path = _dump_failed_request(
                            model=self.model,
                            kwargs=kwargs,
                            error_type="empty_stream",
                            attempt=attempt,
                        )
                        logger.warning(
                            f"[stream-retry] {self.model} returned empty stream "
                            f"after {last_role} message — "
                            f"~{token_count} tokens ({token_method}). "
                            f"Request dumped to: {dump_path}. "
                            f"Retrying in {EMPTY_STREAM_RETRY_DELAY}s "
                            f"(attempt {attempt + 1}/{EMPTY_STREAM_MAX_RETRIES})"
                        )
                        await asyncio.sleep(EMPTY_STREAM_RETRY_DELAY)
                        continue

                # Success (or final attempt) — flush remaining events.
                    # All retries exhausted — log and return the empty
                    # result. EventLoopNode's empty response guard will
                    # accept if all outputs are set, or handle the ghost
                    # stream case if outputs are still missing.
                    logger.error(
                        f"[stream] {self.model} returned empty stream after "
                        f"{EMPTY_STREAM_MAX_RETRIES} retries "
                        f"(last_role={last_role}). Returning empty result."
                    )

                # Success (or empty after exhausted retries) — flush events.
                for event in tail_events:
                    yield event
                return
@@ -1179,3 +1013,56 @@ class LiteLLMProvider(LLMProvider):
            recoverable = _is_stream_transient_error(e)
            yield StreamErrorEvent(error=str(e), recoverable=recoverable)
            return

    async def _collect_stream_to_response(
        self,
        stream: AsyncIterator[StreamEvent],
    ) -> LLMResponse:
        """Consume a stream() iterator and collect it into a single LLMResponse.

        Used by acomplete() to route through the unified streaming path so that
        all backends (including Codex) get proper tool call handling.
        """
        from framework.llm.stream_events import (
            FinishEvent,
            StreamErrorEvent,
            TextDeltaEvent,
            ToolCallEvent,
        )

        content = ""
        tool_calls: list[dict[str, Any]] = []
        input_tokens = 0
        output_tokens = 0
        stop_reason = ""
        model = self.model

        async for event in stream:
            if isinstance(event, TextDeltaEvent):
                content = event.snapshot  # snapshot is the accumulated text
            elif isinstance(event, ToolCallEvent):
                tool_calls.append(
                    {
                        "id": event.tool_use_id,
                        "name": event.tool_name,
                        "input": event.tool_input,
                    }
                )
            elif isinstance(event, FinishEvent):
                input_tokens = event.input_tokens
                output_tokens = event.output_tokens
                stop_reason = event.stop_reason
                if event.model:
                    model = event.model
            elif isinstance(event, StreamErrorEvent):
                if not event.recoverable:
                    raise RuntimeError(f"Stream error: {event.error}")

        return LLMResponse(
            content=content,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            stop_reason=stop_reason,
            raw_response={"tool_calls": tool_calls} if tool_calls else None,
        )

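A short consumer sketch for the stream() API; the event attributes are exactly those read by `_collect_stream_to_response` above, while the provider setup itself is assumed:

from framework.llm.stream_events import (
    FinishEvent,
    StreamErrorEvent,
    TextDeltaEvent,
    ToolCallEvent,
)


async def print_stream(provider) -> None:
    # Minimal sketch; `provider` is assumed to be a configured LiteLLMProvider.
    async for event in provider.stream(
        messages=[{"role": "user", "content": "Say hello."}],
        system="Be brief.",
    ):
        if isinstance(event, TextDeltaEvent):
            print(event.snapshot)  # accumulated text so far
        elif isinstance(event, ToolCallEvent):
            print(f"tool call: {event.tool_name}({event.tool_input})")
        elif isinstance(event, FinishEvent):
            print(f"done: {event.stop_reason}, "
                  f"{event.input_tokens} in / {event.output_tokens} out")
        elif isinstance(event, StreamErrorEvent):
            print(f"error (recoverable={event.recoverable}): {event.error}")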
@@ -2,10 +2,10 @@

import json
import re
from collections.abc import AsyncIterator, Callable
from collections.abc import AsyncIterator
from typing import Any

from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import (
    FinishEvent,
    StreamEvent,
@@ -146,43 +146,6 @@ class MockLLMProvider(LLMProvider):
            stop_reason="mock_complete",
        )

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """
        Generate a mock completion without tool use.

        In mock mode, we skip tool execution and return a final response immediately.

        Args:
            messages: Initial conversation (ignored in mock mode)
            system: System prompt (used to extract expected output keys)
            tools: Available tools (ignored in mock mode)
            tool_executor: Tool executor function (ignored in mock mode)
            max_iterations: Max iterations (ignored in mock mode)

        Returns:
            LLMResponse with mock content
        """
        # In mock mode, we don't execute tools - just return a final response
        # Try to generate JSON if the system prompt suggests structured output
        json_mode = "json" in system.lower() or "output_keys" in system.lower()

        content = self._generate_mock_response(system=system, json_mode=json_mode)

        return LLMResponse(
            content=content,
            model=self.model,
            input_tokens=0,
            output_tokens=0,
            stop_reason="mock_complete",
        )

    async def acomplete(
        self,
        messages: list[dict[str, Any]],
@@ -204,23 +167,6 @@ class MockLLMProvider(LLMProvider):
            max_retries=max_retries,
        )

    async def acomplete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[[ToolUse], ToolResult],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """Async mock tool-use completion (no I/O, returns immediately)."""
        return self.complete_with_tools(
            messages=messages,
            system=system,
            tools=tools,
            tool_executor=tool_executor,
            max_iterations=max_iterations,
        )

    async def stream(
        self,
        messages: list[dict[str, Any]],

@@ -2,7 +2,7 @@

import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from functools import partial
from typing import Any
@@ -90,30 +90,6 @@ class LLMProvider(ABC):
        """
        pass

    @abstractmethod
    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable[["ToolUse"], "ToolResult"],
        max_iterations: int = 10,
    ) -> LLMResponse:
        """
        Run a tool-use loop until the LLM produces a final response.

        Args:
            messages: Initial conversation
            system: System prompt
            tools: Available tools
            tool_executor: Function to execute tools: (ToolUse) -> ToolResult
            max_iterations: Max tool calls before stopping

        Returns:
            Final LLMResponse after tool use completes
        """
        pass

    async def acomplete(
        self,
        messages: list[dict[str, Any]],
@@ -144,32 +120,6 @@ class LLMProvider(ABC):
            ),
        )

    async def acomplete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list["Tool"],
        tool_executor: Callable[["ToolUse"], "ToolResult"],
        max_iterations: int = 10,
    ) -> "LLMResponse":
        """Async version of complete_with_tools(). Non-blocking on the event loop.

        Default implementation offloads the sync complete_with_tools() to a thread pool.
        Subclasses SHOULD override for native async I/O.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            partial(
                self.complete_with_tools,
                messages=messages,
                system=system,
                tools=tools,
                tool_executor=tool_executor,
                max_iterations=max_iterations,
            ),
        )

    async def stream(
        self,
        messages: list[dict[str, Any]],

@@ -10,6 +10,7 @@ Usage:
import json
import logging
import os
import shutil
import sys
from datetime import datetime
from pathlib import Path
@@ -562,16 +563,29 @@ def _validate_agent_path(agent_path: str) -> tuple[Path | None, str | None]:
    path = Path(agent_path)

    # Resolve relative paths against project root (not MCP server's cwd)
    if not path.is_absolute() and not path.exists():
        resolved = _PROJECT_ROOT / path
        if resolved.exists():
            path = resolved
    if not path.is_absolute():
        path = _PROJECT_ROOT / path

    # Restrict to allowed directories BEFORE checking existence to prevent
    # leaking whether arbitrary filesystem paths exist on disk.
    from framework.server.app import validate_agent_path

    try:
        path = validate_agent_path(path)
    except ValueError:
        return None, json.dumps(
            {
                "success": False,
                "error": "agent_path must be inside an allowed directory "
                "(exports/, examples/, or ~/.hive/agents/)",
            }
        )

    if not path.exists():
        return None, json.dumps(
            {
                "success": False,
                "error": f"Agent path not found: {path}",
                "error": f"Agent path not found: {agent_path}",
                "hint": "Run export_graph to create an agent in exports/ first",
            }
        )
@@ -586,7 +600,7 @@ def add_node(
    description: Annotated[str, "What this node does"],
    node_type: Annotated[
        str,
        "Type: event_loop (recommended), router.",
        "Type: event_loop (recommended), gcu (browser automation), router.",
    ],
    input_keys: Annotated[str, "JSON array of keys this node reads from shared memory"],
    output_keys: Annotated[str, "JSON array of keys this node writes to shared memory"],
@@ -675,8 +689,23 @@ def add_node(
    if node_type == "event_loop" and not system_prompt:
        warnings.append(f"Event loop node '{node_id}' should have a system_prompt")

    # GCU node validation
    if node_type == "gcu":
        if tools_list:
            warnings.append(
                f"GCU node '{node_id}' auto-includes all browser tools from the "
                f"gcu-tools MCP server. Manually listed tools {tools_list} will be "
                f"merged with the auto-included set."
            )
        if not system_prompt:
            warnings.append(
                f"GCU node '{node_id}' has a default browser best-practices prompt. "
                f"Consider adding a task-specific system_prompt — it will be appended "
                f"after the browser instructions."
            )

    # Warn about client_facing on nodes with tools (likely autonomous work)
    if node_type == "event_loop" and client_facing and tools_list:
    if node_type in ("event_loop", "gcu") and client_facing and tools_list:
        warnings.append(
            f"Node '{node_id}' is client_facing=True but has tools {tools_list}. "
            "Nodes with tools typically do autonomous work and should be "
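An illustrative call against the validator above (the input path is hypothetical; the JSON shape comes straight from the branches shown):

path, err = _validate_agent_path("../../etc/passwd")
assert path is None
# err is a JSON string like:
# {"success": false,
#  "error": "agent_path must be inside an allowed directory (exports/, examples/, or ~/.hive/agents/)"}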
@@ -1774,6 +1803,14 @@ def export_graph() -> str:
            enriched_criteria.append(crit_dict)
        export_data["goal"]["success_criteria"] = enriched_criteria

    # Auto-add GCU MCP server if any node uses the gcu type
    has_gcu_nodes = any(n.node_type == "gcu" for n in session.nodes)
    if has_gcu_nodes:
        from framework.graph.gcu import GCU_MCP_SERVER_CONFIG, GCU_SERVER_NAME

        if not any(s.get("name") == GCU_SERVER_NAME for s in session.mcp_servers):
            session.mcp_servers.append(dict(GCU_MCP_SERVER_CONFIG))

    # === WRITE FILES TO DISK ===
    # Create exports directory
    exports_dir = Path("exports") / session.name
@@ -1864,7 +1901,7 @@ def import_from_export(
        return json.dumps({"success": False, "error": f"File not found: {agent_json_path}"})

    try:
        data = json.loads(path.read_text())
        data = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        return json.dumps({"success": False, "error": f"Invalid JSON: {e}"})

@@ -1946,7 +1983,7 @@ def get_session_status() -> str:
@mcp.tool()
def configure_loop(
    max_iterations: Annotated[int, "Maximum loop iterations per node execution (default 50)"] = 50,
    max_tool_calls_per_turn: Annotated[int, "Maximum tool calls per LLM turn (default 10)"] = 10,
    max_tool_calls_per_turn: Annotated[int, "Maximum tool calls per LLM turn (default 30)"] = 30,
    stall_detection_threshold: Annotated[
        int, "Consecutive identical responses before stall detection triggers (default 3)"
    ] = 3,
@@ -2772,6 +2809,21 @@ def run_tests(
    import re
    import subprocess

    # Guard: pytest must be available as a subprocess command.
    # Install with: pip install 'framework[testing]'
    if shutil.which("pytest") is None:
        return json.dumps(
            {
                "goal_id": goal_id,
                "error": (
                    "pytest is not installed or not on PATH. "
                    "Hive's test runner requires pytest at runtime. "
                    "Install it with: pip install 'framework[testing]' "
                    "or: uv pip install 'framework[testing]'"
                ),
            }
        )

    path, err = _validate_agent_path(agent_path)
    if err:
        return err
@@ -2965,6 +3017,22 @@ def debug_test(
    import re
    import subprocess

    # Guard: pytest must be available as a subprocess command.
    # Install with: pip install 'framework[testing]'
    if shutil.which("pytest") is None:
        return json.dumps(
            {
                "goal_id": goal_id,
                "test_name": test_name,
                "error": (
                    "pytest is not installed or not on PATH. "
                    "Hive's test runner requires pytest at runtime. "
                    "Install it with: pip install 'framework[testing]' "
                    "or: uv pip install 'framework[testing]'"
                ),
            }
        )

    # Derive agent_path from session if not provided
    if not agent_path and _session:
        agent_path = f"exports/{_session.name}"
@@ -2986,7 +3054,7 @@ def debug_test(
    # Find which file contains the test
    test_file = None
    for py_file in tests_dir.glob("test_*.py"):
        content = py_file.read_text()
        content = py_file.read_text(encoding="utf-8")
        if f"def {test_name}" in content or f"async def {test_name}" in content:
            test_file = py_file
            break
@@ -3138,7 +3206,7 @@ def list_tests(
    tests = []
    for test_file in sorted(tests_dir.glob("test_*.py")):
        try:
            content = test_file.read_text()
            content = test_file.read_text(encoding="utf-8")
            tree = ast.parse(content)

            # Find all async function definitions that start with "test_"

@@ -428,7 +428,7 @@ def _load_resume_state(
    if not cp_path.exists():
        return None
    try:
        cp_data = json.loads(cp_path.read_text())
        cp_data = json.loads(cp_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        return None
    return {
@@ -444,7 +444,7 @@ def _load_resume_state(
    if not state_path.exists():
        return None
    try:
        state_data = json.loads(state_path.read_text())
        state_data = json.loads(state_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError):
        return None
    progress = state_data.get("progress", {})
@@ -1941,12 +1941,74 @@ def _open_browser(url: str) -> None:
        pass  # Best-effort — don't crash if browser can't open


def _build_frontend() -> bool:
    """Build the frontend if source is newer than dist. Returns True if dist exists."""
    import subprocess

    # Find the frontend directory relative to this file or cwd
    candidates = [
        Path("core/frontend"),
        Path(__file__).resolve().parent.parent.parent / "frontend",
    ]
    frontend_dir: Path | None = None
    for c in candidates:
        if (c / "package.json").is_file():
            frontend_dir = c.resolve()
            break

    if frontend_dir is None:
        return False

    dist_dir = frontend_dir / "dist"
    src_dir = frontend_dir / "src"

    # Skip build if dist is up-to-date (newest src file older than dist index.html)
    index_html = dist_dir / "index.html"
    if index_html.exists() and src_dir.is_dir():
        dist_mtime = index_html.stat().st_mtime
        needs_build = False
        for f in src_dir.rglob("*"):
            if f.is_file() and f.stat().st_mtime > dist_mtime:
                needs_build = True
                break
        if not needs_build:
            return True

    # Need to build
    print("Building frontend...")
    try:
        # Ensure deps are installed
        subprocess.run(
            ["npm", "install", "--no-fund", "--no-audit"],
            cwd=frontend_dir,
            check=True,
            capture_output=True,
        )
        subprocess.run(
            ["npm", "run", "build"],
            cwd=frontend_dir,
            check=True,
            capture_output=True,
        )
        print("Frontend built.")
        return True
    except FileNotFoundError:
        print("Node.js not found — skipping frontend build.")
        return dist_dir.is_dir()
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode(errors="replace") if exc.stderr else ""
        print(f"Frontend build failed: {stderr[:500]}")
        return dist_dir.is_dir()


def cmd_serve(args: argparse.Namespace) -> int:
    """Start the HTTP API server."""
    import logging

    from aiohttp import web

    _build_frontend()

    from framework.server.app import create_app

    logging.basicConfig(
@@ -1971,7 +2033,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
        print(f"Error loading {agent_path}: {e}")

    # Start server using AppRunner/TCPSite (same pattern as webhook_server.py)
    runner = web.AppRunner(app)
    runner = web.AppRunner(app, access_log=None)
    await runner.setup()
    site = web.TCPSite(runner, args.host, args.port)
    await site.start()

@@ -0,0 +1,185 @@
"""Pre-load validation for agent graphs.

Runs structural and credential checks before MCP servers are spawned.
Fails fast with actionable error messages.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from framework.graph.edge import GraphSpec
    from framework.graph.node import NodeSpec

logger = logging.getLogger(__name__)


class PreloadValidationError(Exception):
    """Raised when pre-load validation fails."""

    def __init__(self, errors: list[str]):
        self.errors = errors
        msg = "Pre-load validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
        super().__init__(msg)


@dataclass
class PreloadResult:
    """Result of pre-load validation."""

    valid: bool
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)


def validate_graph_structure(graph: GraphSpec) -> list[str]:
    """Run graph structural validation (includes GCU subagent-only checks).

    Delegates to GraphSpec.validate() which checks entry/terminal nodes,
    edge references, reachability, fan-out rules, and GCU constraints.
    """
    return graph.validate()


def validate_credentials(
    nodes: list[NodeSpec],
    *,
    interactive: bool = True,
    skip: bool = False,
) -> None:
    """Validate agent credentials.

    Calls ``validate_agent_credentials`` which performs two-phase validation:
    1. Presence check (env var, encrypted store, Aden sync)
    2. Health check (lightweight HTTP call to verify the key works)

    On failure raises ``CredentialError`` with ``validation_result`` and
    ``failed_cred_names`` attributes preserved from the upstream check.

    In interactive mode (CLI with TTY), attempts recovery via the
    credential setup flow before re-raising.
    """
    if skip:
        return

    from framework.credentials.validation import validate_agent_credentials

    if not interactive:
        # Non-interactive: let CredentialError propagate with full context.
        # validate_agent_credentials attaches .validation_result and
        # .failed_cred_names to the exception automatically.
        validate_agent_credentials(nodes)
        return

    import sys

    from framework.credentials.models import CredentialError

    try:
        validate_agent_credentials(nodes)
    except CredentialError as e:
        if not sys.stdin.isatty():
            raise

        print(f"\n{e}", file=sys.stderr)

        from framework.credentials.validation import build_setup_session_from_error

        session = build_setup_session_from_error(e, nodes=nodes)
        if not session.missing:
            raise

        result = session.run_interactive()
        if not result.success:
            # Preserve the original validation_result so callers can
            # inspect which credentials are still missing.
            exc = CredentialError(
                "Credential setup incomplete. Run again after configuring the required credentials."
            )
            if hasattr(e, "validation_result"):
                exc.validation_result = e.validation_result  # type: ignore[attr-defined]
            if hasattr(e, "failed_cred_names"):
                exc.failed_cred_names = e.failed_cred_names  # type: ignore[attr-defined]
            raise exc from None

        # Re-validate after successful setup — this will raise if still broken,
        # with fresh validation_result attached to the new exception.
        validate_agent_credentials(nodes)


def credential_errors_to_json(exc: Exception) -> dict:
    """Extract structured credential failure details from a CredentialError.

    Returns a dict suitable for JSON serialization with enough detail for
    the queen to report actionable guidance to the user. Falls back to
    ``str(exc)`` when rich metadata is not available.
    """
    result = getattr(exc, "validation_result", None)
    if result is None:
        return {
            "error": "credentials_required",
            "message": str(exc),
        }

    failed = result.failed
    missing = []
    for c in failed:
        if c.available:
            status = "invalid"
        elif c.aden_not_connected:
            status = "aden_not_connected"
        else:
            status = "missing"
        entry: dict = {
            "credential": c.credential_name,
            "env_var": c.env_var,
            "status": status,
        }
        if c.tools:
            entry["tools"] = c.tools
        if c.node_types:
            entry["node_types"] = c.node_types
        if c.help_url:
            entry["help_url"] = c.help_url
        if c.validation_message:
            entry["validation_message"] = c.validation_message
        missing.append(entry)

    return {
        "error": "credentials_required",
        "message": str(exc),
        "missing_credentials": missing,
    }


def run_preload_validation(
    graph: GraphSpec,
    *,
    interactive: bool = True,
    skip_credential_validation: bool = False,
) -> PreloadResult:
    """Run all pre-load validations.

    Order:
    1. Graph structure (includes GCU subagent-only checks) — non-recoverable
    2. Credentials — potentially recoverable via interactive setup

    Raises PreloadValidationError for structural issues.
    Raises CredentialError for credential issues.
    """
    # 1. Structural validation (calls graph.validate() which includes GCU checks)
    graph_errors = validate_graph_structure(graph)
    if graph_errors:
        raise PreloadValidationError(graph_errors)

    # 2. Credential validation
    validate_credentials(
        graph.nodes,
        interactive=interactive,
        skip=skip_credential_validation,
    )

    return PreloadResult(valid=True)
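A call-site sketch wiring the helpers above together. All names are defined in this file or imported from paths shown in this diff; the surrounding application code is hypothetical:

from framework.credentials.models import CredentialError
from framework.runner.preload_validation import (
    PreloadValidationError,
    credential_errors_to_json,
    run_preload_validation,
)


def preload_or_report(graph) -> dict:
    """Return a JSON-serializable status for a pre-load attempt."""
    try:
        result = run_preload_validation(graph, interactive=False)
    except PreloadValidationError as e:
        return {"error": "invalid_graph", "details": e.errors}
    except CredentialError as e:
        return credential_errors_to_json(e)
    return {"valid": result.valid, "warnings": result.warnings}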
@@ -12,7 +12,6 @@ from typing import TYPE_CHECKING, Any
from framework.config import get_hive_config, get_preferred_model
from framework.credentials.validation import (
    ensure_credential_key_env as _ensure_credential_key_env,
    validate_agent_credentials,
)
from framework.graph import Goal
from framework.graph.edge import (
@@ -25,6 +24,7 @@ from framework.graph.edge import (
from framework.graph.executor import ExecutionResult
from framework.graph.node import NodeSpec
from framework.llm.provider import LLMProvider, Tool
from framework.runner.preload_validation import run_preload_validation
from framework.runner.tool_registry import ToolRegistry
from framework.runtime.agent_runtime import AgentRuntime, AgentRuntimeConfig, create_agent_runtime
from framework.runtime.execution_stream import EntryPointSpec
@@ -322,8 +322,9 @@ def _save_refreshed_codex_credentials(auth_data: dict, token_data: dict) -> None
        auth_data["tokens"] = tokens
        auth_data["last_refresh"] = datetime.now(UTC).isoformat()

        CODEX_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
        with open(CODEX_AUTH_FILE, "w") as f:
        CODEX_AUTH_FILE.parent.mkdir(parents=True, exist_ok=True, mode=0o700)
        fd = os.open(CODEX_AUTH_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            json.dump(auth_data, f, indent=2)
        logger.debug("Codex credentials refreshed successfully")
    except (OSError, KeyError) as exc:
@@ -678,68 +679,29 @@ class AgentRunner:
        self._agent_runtime: AgentRuntime | None = None
        self._uses_async_entry_points = self.graph.has_async_entry_points()

        # Validate credentials before spawning MCP servers.
        # Pre-load validation: structural checks + credentials.
        # Fails fast with actionable guidance — no MCP noise on screen.
        self._validate_credentials()
        run_preload_validation(
            self.graph,
            interactive=self._interactive,
            skip_credential_validation=self.skip_credential_validation,
        )

        # Auto-discover tools from tools.py
        tools_path = agent_path / "tools.py"
        if tools_path.exists():
            self._tool_registry.discover_from_module(tools_path)

        # Set environment variables for MCP subprocesses
        # These are inherited by MCP servers (e.g., GCU browser tools)
        os.environ["HIVE_AGENT_NAME"] = agent_path.name
        os.environ["HIVE_STORAGE_PATH"] = str(self._storage_path)

        # Auto-discover MCP servers from mcp_servers.json
        mcp_config_path = agent_path / "mcp_servers.json"
        if mcp_config_path.exists():
            self._load_mcp_servers_from_config(mcp_config_path)

    def _validate_credentials(self) -> None:
        """Check that required credentials are available before spawning MCP servers.

        If ``interactive`` is True and stdin is a TTY, automatically launches
        the interactive credential setup flow so the user can fix the issue
        in-place. Re-validates after setup succeeds.

        When ``interactive`` is False (e.g. TUI callers), the CredentialError
        propagates immediately so the caller can handle it with its own UI.
        """
        if self.skip_credential_validation:
            return

        if not self._interactive:
            # Let the CredentialError propagate — caller handles UI.
            validate_agent_credentials(self.graph.nodes)
            return

        import sys

        from framework.credentials.models import CredentialError

        try:
            validate_agent_credentials(self.graph.nodes)
            return  # All good
        except CredentialError as e:
            if not sys.stdin.isatty():
                raise

            # Interactive: show the error then enter credential setup
            print(f"\n{e}", file=sys.stderr)

            from framework.credentials.validation import build_setup_session_from_error

            session = build_setup_session_from_error(e, nodes=self.graph.nodes)
            if not session.missing:
                raise

            result = session.run_interactive()
            if not result.success:
                raise CredentialError(
                    "Credential setup incomplete. "
                    "Run again after configuring the required credentials."
                ) from None

            # Re-validate after setup
            validate_agent_credentials(self.graph.nodes)

    @staticmethod
    def _import_agent_module(agent_path: Path):
        """Import an agent package from its directory path.
@@ -770,7 +732,8 @@ class AgentRunner:
            # deep_research_agent.agent) so the top-level reload picks up
            # changes in the entire package — not just __init__.py.
            stale = [
                name for name in sys.modules
                name
                for name in sys.modules
                if name == package_name or name.startswith(f"{package_name}.")
            ]
            for name in stale:
@@ -786,6 +749,7 @@ class AgentRunner:
        storage_path: Path | None = None,
        model: str | None = None,
        interactive: bool = True,
        skip_credential_validation: bool | None = None,
    ) -> "AgentRunner":
        """
        Load an agent from an export folder.
@@ -801,6 +765,8 @@ class AgentRunner:
            model: LLM model to use (reads from agent's default_config if None)
            interactive: If True (default), offer interactive credential setup.
                Set to False from TUI callers that handle setup via their own UI.
            skip_credential_validation: If True, skip credential checks at load time.
                When None (default), uses the agent module's setting.

        Returns:
            AgentRunner instance ready to run
@@ -870,6 +836,8 @@ class AgentRunner:

        # Read pre-run hooks (e.g., credential_tester needs account selection)
        skip_cred = getattr(agent_module, "skip_credential_validation", False)
        if skip_credential_validation is not None:
            skip_cred = skip_credential_validation
        needs_acct = getattr(agent_module, "requires_account_selection", False)
        configure_fn = getattr(agent_module, "configure_for_account", None)
        list_accts_fn = getattr(agent_module, "list_connected_accounts", None)
@@ -906,6 +874,7 @@ class AgentRunner:
            storage_path=storage_path,
            model=model,
            interactive=interactive,
            skip_credential_validation=skip_credential_validation or False,
        )

    def register_tool(
@@ -1111,7 +1080,9 @@ class AgentRunner:

        # Fail fast if the agent needs an LLM but none was configured
        if self._llm is None:
            has_llm_nodes = any(node.node_type == "event_loop" for node in self.graph.nodes)
            has_llm_nodes = any(
                node.node_type in ("event_loop", "gcu") for node in self.graph.nodes
            )
            if has_llm_nodes:
                from framework.credentials.models import CredentialError

@@ -1129,6 +1100,53 @@ class AgentRunner:
                    )
                raise CredentialError(f"LLM API key not found for model '{self.model}'. {hint}")

        # For GCU nodes: auto-register GCU MCP server if needed, then expand tool lists
        has_gcu_nodes = any(node.node_type == "gcu" for node in self.graph.nodes)
        if has_gcu_nodes:
            from framework.graph.gcu import GCU_MCP_SERVER_CONFIG, GCU_SERVER_NAME

            # Auto-register GCU MCP server if tools aren't loaded yet
            gcu_tool_names = self._tool_registry.get_server_tool_names(GCU_SERVER_NAME)
            if not gcu_tool_names:
                # Resolve relative cwd against agent path
                gcu_config = dict(GCU_MCP_SERVER_CONFIG)
                cwd = gcu_config.get("cwd")
                if cwd and not Path(cwd).is_absolute():
                    gcu_config["cwd"] = str((self.agent_path / cwd).resolve())
                self._tool_registry.register_mcp_server(gcu_config)
                gcu_tool_names = self._tool_registry.get_server_tool_names(GCU_SERVER_NAME)

            # Expand each GCU node's tools list to include all GCU server tools
            if gcu_tool_names:
                for node in self.graph.nodes:
                    if node.node_type == "gcu":
                        existing = set(node.tools)
                        for tool_name in sorted(gcu_tool_names):
                            if tool_name not in existing:
                                node.tools.append(tool_name)

        # For event_loop/gcu nodes: auto-register file tools MCP server, then expand tool lists
        has_loop_nodes = any(node.node_type in ("event_loop", "gcu") for node in self.graph.nodes)
        if has_loop_nodes:
            from framework.graph.files import FILES_MCP_SERVER_CONFIG, FILES_MCP_SERVER_NAME

            files_tool_names = self._tool_registry.get_server_tool_names(FILES_MCP_SERVER_NAME)
            if not files_tool_names:
                files_config = dict(FILES_MCP_SERVER_CONFIG)
                cwd = files_config.get("cwd")
                if cwd and not Path(cwd).is_absolute():
                    files_config["cwd"] = str((self.agent_path / cwd).resolve())
                self._tool_registry.register_mcp_server(files_config)
                files_tool_names = self._tool_registry.get_server_tool_names(FILES_MCP_SERVER_NAME)

            if files_tool_names:
                for node in self.graph.nodes:
                    if node.node_type in ("event_loop", "gcu"):
                        existing = set(node.tools)
                        for tool_name in sorted(files_tool_names):
                            if tool_name not in existing:
                                node.tools.append(tool_name)

        # Get tools for runtime
        tools = list(self._tool_registry.get_tools().values())
        tool_executor = self._tool_registry.get_executor()
@@ -1256,6 +1274,7 @@ class AgentRunner:
                    isolation_level=async_ep.isolation_level,
                    priority=async_ep.priority,
                    max_concurrent=async_ep.max_concurrent,
                    max_resurrections=async_ep.max_resurrections,
                )
                entry_points.append(ep)

@@ -1665,7 +1684,9 @@ class AgentRunner:
                warnings.append(warning_msg)
        except ImportError:
            # aden_tools not installed - fall back to direct check
            has_llm_nodes = any(node.node_type == "event_loop" for node in self.graph.nodes)
            has_llm_nodes = any(
                node.node_type in ("event_loop", "gcu") for node in self.graph.nodes
            )
            if has_llm_nodes:
                api_key_env = self._get_api_key_env_var(self.model)
                if api_key_env and not os.environ.get(api_key_env):

@@ -6,6 +6,7 @@ import importlib.util
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -47,11 +48,20 @@ class ToolRegistry:
|
||||
# and auto-injected at call time for tools that accept them.
|
||||
CONTEXT_PARAMS = frozenset({"workspace_id", "agent_id", "session_id", "data_dir"})
|
||||
|
||||
# Credential directory used for change detection
|
||||
_CREDENTIAL_DIR = Path("~/.hive/credentials/credentials").expanduser()
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, RegisteredTool] = {}
|
||||
self._mcp_clients: list[Any] = [] # List of MCPClient instances
|
||||
self._session_context: dict[str, Any] = {} # Auto-injected context for tools
|
||||
self._provider_index: dict[str, set[str]] = {} # provider -> tool names
|
||||
# MCP resync tracking
|
||||
self._mcp_config_path: Path | None = None # Path used for initial load
|
||||
self._mcp_tool_names: set[str] = set() # Tool names registered from MCP
|
||||
self._mcp_cred_snapshot: set[str] = set() # Credential filenames at MCP load time
|
||||
self._mcp_aden_key_snapshot: str | None = None # ADEN_API_KEY value at MCP load time
|
||||
self._mcp_server_tools: dict[str, set[str]] = {} # server name -> tool names
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -285,6 +295,10 @@ class ToolRegistry:
|
||||
"""Check if a tool is registered."""
|
||||
return name in self._tools
|
||||
|
||||
def get_server_tool_names(self, server_name: str) -> set[str]:
|
||||
"""Return tool names registered from a specific MCP server."""
|
||||
return set(self._mcp_server_tools.get(server_name, set()))
|
||||
|
||||
def set_session_context(self, **context) -> None:
|
||||
"""
|
||||
Set session context to auto-inject into tool calls.
|
||||
@@ -322,6 +336,9 @@ class ToolRegistry:
|
||||
Args:
|
||||
config_path: Path to an ``mcp_servers.json`` file.
|
||||
"""
|
||||
# Remember config path for potential resync later
|
||||
self._mcp_config_path = Path(config_path)
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
@@ -349,6 +366,10 @@ class ToolRegistry:
|
||||
name = server_config.get("name", "unknown")
|
||||
logger.warning(f"Failed to register MCP server '{name}': {e}")
|
||||
|
||||
# Snapshot credential files and ADEN_API_KEY so we can detect mid-session changes
|
||||
self._mcp_cred_snapshot = self._snapshot_credentials()
|
||||
self._mcp_aden_key_snapshot = os.environ.get("ADEN_API_KEY")
|
||||
|
||||
def register_mcp_server(
|
||||
self,
|
||||
server_config: dict[str, Any],
|
||||
@@ -395,6 +416,9 @@ class ToolRegistry:
|
||||
self._mcp_clients.append(client)
|
||||
|
||||
# Register each tool
|
||||
server_name = server_config["name"]
|
||||
if server_name not in self._mcp_server_tools:
|
||||
self._mcp_server_tools[server_name] = set()
|
||||
count = 0
|
||||
for mcp_tool in client.list_tools():
|
||||
# Convert MCP tool to framework Tool (strips context params from LLM schema)
|
||||
@@ -419,7 +443,15 @@ class ToolRegistry:
|
||||
filtered_context = {
|
||||
k: v for k, v in base_context.items() if k in tool_params
|
||||
}
|
||||
merged_inputs = {**filtered_context, **inputs}
|
||||
# Strip context params from LLM inputs — the framework
|
||||
# values are authoritative (prevents the LLM from passing
|
||||
# e.g. data_dir="/data" and overriding the real path).
|
||||
clean_inputs = {
|
||||
k: v
|
||||
for k, v in inputs.items()
|
||||
if k not in registry_ref.CONTEXT_PARAMS
|
||||
}
|
||||
merged_inputs = {**clean_inputs, **filtered_context}
|
||||
result = client_ref.call_tool(tool_name, merged_inputs)
|
||||
# MCP tools return content array, extract the result
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
@@ -439,6 +471,8 @@ class ToolRegistry:
|
||||
tool,
|
||||
make_mcp_executor(client, mcp_tool.name, self, tool_params),
|
||||
)
|
||||
self._mcp_tool_names.add(mcp_tool.name)
|
||||
self._mcp_server_tools[server_name].add(mcp_tool.name)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Registered {count} tools from MCP server '{config.name}'")
|
||||
@@ -531,6 +565,67 @@ class ToolRegistry:
|
||||
all_names.update(names)
|
||||
return sorted(name for name in self._tools if name in all_names)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# MCP credential resync
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _snapshot_credentials(self) -> set[str]:
|
||||
"""Return the set of credential filenames currently on disk."""
|
||||
try:
|
||||
return set(self._CREDENTIAL_DIR.iterdir()) if self._CREDENTIAL_DIR.is_dir() else set()
|
||||
except OSError:
|
||||
return set()
|
||||
|
||||
def resync_mcp_servers_if_needed(self) -> bool:
|
||||
"""Restart MCP servers if credential files changed since last load.
|
||||
|
||||
Compares the current credential directory listing against the snapshot
|
||||
taken when MCP servers were first loaded. If new files appeared (e.g.
|
||||
user connected an OAuth account mid-session), disconnects all MCP
|
||||
clients and re-loads them so the new subprocess picks up the fresh
|
||||
credentials.
|
||||
|
||||
Returns True if a resync was performed, False otherwise.
|
||||
"""
|
||||
if not self._mcp_clients or self._mcp_config_path is None:
|
||||
return False
|
||||
|
||||
current = self._snapshot_credentials()
|
||||
current_aden_key = os.environ.get("ADEN_API_KEY")
|
||||
files_changed = current != self._mcp_cred_snapshot
|
||||
aden_key_changed = current_aden_key != self._mcp_aden_key_snapshot
|
||||
|
||||
if not files_changed and not aden_key_changed:
|
||||
return False
|
||||
|
||||
reason = (
|
||||
"Credential files and ADEN_API_KEY changed"
|
||||
if files_changed and aden_key_changed
|
||||
else "ADEN_API_KEY changed"
|
||||
if aden_key_changed
|
||||
else "Credential files changed"
|
||||
)
|
||||
logger.info("%s — resyncing MCP servers", reason)
|
||||
|
||||
# 1. Disconnect existing MCP clients
|
||||
for client in self._mcp_clients:
|
||||
try:
|
||||
client.disconnect()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error disconnecting MCP client during resync: {e}")
|
||||
self._mcp_clients.clear()
|
||||
|
||||
# 2. Remove MCP-registered tools
|
||||
for name in self._mcp_tool_names:
|
||||
self._tools.pop(name, None)
|
||||
self._mcp_tool_names.clear()
|
||||
|
||||
# 3. Re-load MCP servers (spawns fresh subprocesses with new credentials)
|
||||
self.load_mcp_config(self._mcp_config_path)
|
||||
|
||||
logger.info("MCP server resync complete")
|
||||
return True
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all MCP client connections."""
|
||||
for client in self._mcp_clients:
|
||||
|
||||
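A resync is only triggered when a snapshot mismatch is detected, so the hook is cheap to call opportunistically. A minimal sketch of how a host process might drive it, assuming an `mcp_servers.json` in the working directory (the config location is an assumption, not part of this diff):

```python
from pathlib import Path

registry = ToolRegistry()
registry.load_mcp_config(Path("mcp_servers.json"))  # assumed config location

# ... later, after the user connects an OAuth account mid-session ...
if registry.resync_mcp_servers_if_needed():
    print("MCP servers restarted with fresh credentials")
```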
@@ -411,7 +411,12 @@ class AgentRuntime:
                )
                continue

        def _make_cron_timer(entry_point_id: str, expr: str, immediate: bool):
        def _make_cron_timer(
            entry_point_id: str,
            expr: str,
            immediate: bool,
            idle_timeout: float = 300,
        ):
            async def _cron_loop():
                from croniter import croniter

@@ -442,11 +447,28 @@ class AgentRuntime:
                    await asyncio.sleep(max(0, sleep_secs))
                    continue

                # Gate: skip tick if previous execution still running
                _stream = self._streams.get(entry_point_id)
                if _stream and _stream.active_execution_ids:
                    logger.debug(
                        "Cron '%s': execution already in progress, skipping tick",
                # Gate: skip tick if ANY stream is actively working.
                # If the execution is idle (no LLM/tool activity
                # beyond idle_timeout) let the timer proceed —
                # execute() will cancel the stale execution.
                _any_active = False
                _min_idle = float("inf")
                for _s in self._streams.values():
                    if _s.active_execution_ids:
                        _any_active = True
                        _idle = _s.agent_idle_seconds
                        if _idle < _min_idle:
                            _min_idle = _idle
                logger.info(
                    "Cron '%s': gate — active=%s, idle=%.1fs, timeout=%ds",
                    entry_point_id,
                    _any_active,
                    _min_idle,
                    idle_timeout,
                )
                if _any_active and _min_idle < idle_timeout:
                    logger.info(
                        "Cron '%s': agent actively working, skipping tick",
                        entry_point_id,
                    )
                    self._timer_next_fire[entry_point_id] = (
@@ -517,7 +539,12 @@ class AgentRuntime:
            return _cron_loop

        task = asyncio.create_task(
            _make_cron_timer(ep_id, cron_expr, run_immediately)()
            _make_cron_timer(
                ep_id,
                cron_expr,
                run_immediately,
                idle_timeout=tc.get("idle_timeout_seconds", 300),
            )()
        )
        self._timer_tasks.append(task)
        logger.info(
@@ -529,7 +556,12 @@ class AgentRuntime:

        elif interval and interval > 0:
            # Fixed interval mode (original behavior)
            def _make_timer(entry_point_id: str, mins: float, immediate: bool):
            def _make_timer(
                entry_point_id: str,
                mins: float,
                immediate: bool,
                idle_timeout: float = 300,
            ):
                async def _timer_loop():
                    interval_secs = mins * 60
                    _persistent_session_id: str | None = None
@@ -551,11 +583,26 @@ class AgentRuntime:
                        await asyncio.sleep(interval_secs)
                        continue

                    # Gate: skip tick if previous execution still running
                    _stream = self._streams.get(entry_point_id)
                    if _stream and _stream.active_execution_ids:
                        logger.debug(
                            "Timer '%s': execution already in progress, skipping tick",
                    # Gate: skip tick if agent is actively working.
                    # Gate: skip tick if ANY stream is actively working.
                    _any_active = False
                    _min_idle = float("inf")
                    for _s in self._streams.values():
                        if _s.active_execution_ids:
                            _any_active = True
                            _idle = _s.agent_idle_seconds
                            if _idle < _min_idle:
                                _min_idle = _idle
                    logger.info(
                        "Timer '%s': gate — active=%s, idle=%.1fs, timeout=%ds",
                        entry_point_id,
                        _any_active,
                        _min_idle,
                        idle_timeout,
                    )
                    if _any_active and _min_idle < idle_timeout:
                        logger.info(
                            "Timer '%s': agent actively working, skipping tick",
                            entry_point_id,
                        )
                        self._timer_next_fire[entry_point_id] = (
@@ -621,7 +668,14 @@ class AgentRuntime:

                return _timer_loop

            task = asyncio.create_task(_make_timer(ep_id, interval, run_immediately)())
            task = asyncio.create_task(
                _make_timer(
                    ep_id,
                    interval,
                    run_immediately,
                    idle_timeout=tc.get("idle_timeout_seconds", 300),
                )()
            )
            self._timer_tasks.append(task)
            logger.info(
                "Started timer for entry point '%s' every %s min%s",
@@ -961,6 +1015,7 @@ class AgentRuntime:
            local_ep: str,
            mins: float,
            immediate: bool,
            idle_timeout: float = 300,
        ):
            async def _timer_loop():
                interval_secs = mins * 60
@@ -990,12 +1045,28 @@ class AgentRuntime:
                    await asyncio.sleep(interval_secs)
                    continue

                # Gate: skip tick if previous execution still running
                # Gate: skip tick if ANY stream in this graph is actively working.
                _reg = self._graphs.get(gid)
                _stream = _reg.streams.get(local_ep) if _reg else None
                if _stream and _stream.active_execution_ids:
                    logger.debug(
                        "Timer '%s::%s': execution already in progress, skipping tick",
                _any_active = False
                _min_idle = float("inf")
                if _reg:
                    for _sid, _s in _reg.streams.items():
                        if _s.active_execution_ids:
                            _any_active = True
                            _idle = _s.agent_idle_seconds
                            if _idle < _min_idle:
                                _min_idle = _idle
                logger.info(
                    "Timer '%s::%s': gate — active=%s, idle=%.1fs, timeout=%ds",
                    gid,
                    local_ep,
                    _any_active,
                    _min_idle,
                    idle_timeout,
                )
                if _any_active and _min_idle < idle_timeout:
                    logger.info(
                        "Timer '%s::%s': agent actively working, skipping tick",
                        gid,
                        local_ep,
                    )
@@ -1066,7 +1137,13 @@ class AgentRuntime:
            return _timer_loop

        task = asyncio.create_task(
            _make_timer(graph_id, ep_id, interval, run_immediately)()
            _make_timer(
                graph_id,
                ep_id,
                interval,
                run_immediately,
                idle_timeout=tc.get("idle_timeout_seconds", 300),
            )()
        )
        timer_tasks.append(task)
        logger.info("Timer task created for '%s::%s': %s", graph_id, ep_id, task)
@@ -1174,6 +1251,21 @@ class AgentRuntime:
            return float("inf")
        return time.monotonic() - self._last_user_input_time

    @property
    def agent_idle_seconds(self) -> float:
        """Seconds since any stream last had activity (LLM call, tool call, etc.).

        Returns the *minimum* idle time across all streams with active
        executions. Returns ``float('inf')`` if nothing is running.
        """
        min_idle = float("inf")
        for reg in self._graphs.values():
            for stream in reg.streams.values():
                idle = stream.agent_idle_seconds
                if idle < min_idle:
                    min_idle = idle
        return min_idle

    def get_graph_registration(self, graph_id: str) -> _GraphRegistration | None:
        """Get the registration for a specific graph (or None)."""
        return self._graphs.get(graph_id)
@@ -1368,6 +1460,23 @@ class AgentRuntime:
        # Fallback: primary graph
        return list(self._entry_points.values())

    def get_timer_next_fire_in(self, entry_point_id: str) -> float | None:
        """Return seconds until the next timer fire for *entry_point_id*.

        Checks the primary graph's ``_timer_next_fire`` dict as well as
        all registered secondary graphs. Returns ``None`` when no fire
        time is recorded (e.g. the timer is currently executing or the
        entry point is not a timer).
        """
        mono = self._timer_next_fire.get(entry_point_id)
        if mono is not None:
            return max(0.0, mono - time.monotonic())
        for reg in self._graphs.values():
            mono = reg.timer_next_fire.get(entry_point_id)
            if mono is not None:
                return max(0.0, mono - time.monotonic())
        return None

    def get_stream(self, entry_point_id: str) -> ExecutionStream | None:
        """Get a specific execution stream."""
        return self._streams.get(entry_point_id)
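All three timer variants now share the same gating rule: a tick is skipped only when some stream is active *and* the least-idle stream is still under the idle timeout. A standalone sketch of that predicate (the function name is invented; the stream attributes come from this diff):

```python
def should_skip_tick(streams, idle_timeout: float = 300.0) -> bool:
    """Skip a timer tick only while some execution is genuinely busy."""
    any_active = False
    min_idle = float("inf")
    for stream in streams:  # objects with active_execution_ids / agent_idle_seconds
        if stream.active_execution_ids:
            any_active = True
            min_idle = min(min_idle, stream.agent_idle_seconds)
    # Idle-but-active streams do not block the tick; execute() cancels them.
    return any_active and min_idle < idle_timeout
```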
@@ -88,6 +88,7 @@ class EventType(StrEnum):
    # LLM streaming observability
    LLM_TEXT_DELTA = "llm_text_delta"
    LLM_REASONING_DELTA = "llm_reasoning_delta"
    LLM_TURN_COMPLETE = "llm_turn_complete"

    # Tool lifecycle
    TOOL_CALL_STARTED = "tool_call_started"
@@ -129,8 +130,15 @@ class EventType(StrEnum):
    WORKER_ESCALATION_TICKET = "worker_escalation_ticket"
    QUEEN_INTERVENTION_REQUESTED = "queen_intervention_requested"

    # Execution resurrection (auto-restart on non-fatal failure)
    EXECUTION_RESURRECTED = "execution_resurrected"

    # Worker lifecycle (session manager → frontend)
    WORKER_LOADED = "worker_loaded"
    CREDENTIALS_REQUIRED = "credentials_required"

    # Subagent reports (one-way progress updates from sub-agents)
    SUBAGENT_REPORT = "subagent_report"


@dataclass
@@ -594,6 +602,36 @@ class EventBus:
            )
        )

    async def emit_llm_turn_complete(
        self,
        stream_id: str,
        node_id: str,
        stop_reason: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        execution_id: str | None = None,
        iteration: int | None = None,
    ) -> None:
        """Emit LLM turn completion with stop reason and model metadata."""
        data: dict = {
            "stop_reason": stop_reason,
            "model": model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }
        if iteration is not None:
            data["iteration"] = iteration
        await self.publish(
            AgentEvent(
                type=EventType.LLM_TURN_COMPLETE,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data=data,
            )
        )

    # === TOOL LIFECYCLE PUBLISHERS ===

    async def emit_tool_call_started(
@@ -655,15 +693,19 @@ class EventBus:
        content: str,
        snapshot: str,
        execution_id: str | None = None,
        iteration: int | None = None,
    ) -> None:
        """Emit client output delta event (client_facing=True nodes)."""
        data: dict = {"content": content, "snapshot": snapshot}
        if iteration is not None:
            data["iteration"] = iteration
        await self.publish(
            AgentEvent(
                type=EventType.CLIENT_OUTPUT_DELTA,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={"content": content, "snapshot": snapshot},
                data=data,
            )
        )

@@ -976,6 +1018,30 @@ class EventBus:
            )
        )

    async def emit_subagent_report(
        self,
        stream_id: str,
        node_id: str,
        subagent_id: str,
        message: str,
        data: dict[str, Any] | None = None,
        execution_id: str | None = None,
    ) -> None:
        """Emit a one-way progress report from a sub-agent."""
        await self.publish(
            AgentEvent(
                type=EventType.SUBAGENT_REPORT,
                stream_id=stream_id,
                node_id=node_id,
                execution_id=execution_id,
                data={
                    "subagent_id": subagent_id,
                    "message": message,
                    "data": data,
                },
            )
        )

    # === QUERY OPERATIONS ===

    def get_history(
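The turn-complete payload makes per-turn token accounting a one-liner on the consumer side. A sketch of a handler (the subscription mechanism is not shown in this diff and is assumed; only the `data` fields come from `emit_llm_turn_complete` above):

```python
async def track_token_spend(event) -> None:
    # Assumes an AgentEvent-like object with .type and .data, as published above.
    if event.type == "llm_turn_complete":
        d = event.data
        total = d["input_tokens"] + d["output_tokens"]
        print(f"{d['model']}: {total} tokens (stop={d['stop_reason']})")
```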
@@ -32,6 +32,19 @@ if TYPE_CHECKING:
    from framework.storage.concurrent import ConcurrentStorage
    from framework.storage.session_store import SessionStore


class ExecutionAlreadyRunningError(RuntimeError):
    """Raised when attempting to start an execution on a stream that already has one running."""

    def __init__(self, stream_id: str, active_ids: list[str]):
        self.stream_id = stream_id
        self.active_ids = active_ids
        super().__init__(
            f"Stream '{stream_id}' already has an active execution: {active_ids}. "
            "Concurrent executions on the same stream are not allowed."
        )


logger = logging.getLogger(__name__)


@@ -56,9 +69,11 @@ class GraphScopedEventBus(EventBus):
        # (subscriptions, history, semaphore, etc.) to the real bus.
        self._real_bus = bus
        self._scope_graph_id = graph_id
        self.last_activity_time: float = time.monotonic()

    async def publish(self, event: "AgentEvent") -> None:  # type: ignore[override]
        event.graph_id = self._scope_graph_id
        self.last_activity_time = time.monotonic()
        await self._real_bus.publish(event)

    # --- Delegate state-reading methods to the real bus ---
@@ -93,6 +108,7 @@ class EntryPointSpec:
    isolation_level: str = "shared"  # "isolated" | "shared" | "synchronized"
    priority: int = 0
    max_concurrent: int = 10  # Max concurrent executions for this entry point
    max_resurrections: int = 3  # Auto-restart on non-fatal failure (0 to disable)

    def get_isolation_level(self) -> IsolationLevel:
        """Convert string isolation level to enum."""
@@ -233,9 +249,11 @@ class ExecutionStream:
        self._lock = asyncio.Lock()

        # Graph-scoped event bus (stamps graph_id on published events)
        self._scoped_event_bus = self._event_bus
        if self._event_bus and self.graph_id:
            self._scoped_event_bus = GraphScopedEventBus(self._event_bus, self.graph_id)
        # Always wrap in GraphScopedEventBus so we can track last_activity_time.
        if self._event_bus:
            self._scoped_event_bus = GraphScopedEventBus(self._event_bus, self.graph_id or "")
        else:
            self._scoped_event_bus = None

        # State
        self._running = False
@@ -265,6 +283,21 @@ class ExecutionStream:
        """Return IDs of all currently active executions."""
        return list(self._active_executions.keys())

    @property
    def agent_idle_seconds(self) -> float:
        """Seconds since the last agent activity (LLM call, tool call, node transition).

        Returns ``float('inf')`` if no event bus is attached or no events have
        been published yet. When there are no active executions, also returns
        ``float('inf')`` (nothing to be idle *about*).
        """
        if not self._active_executions:
            return float("inf")
        bus = self._scoped_event_bus
        if isinstance(bus, GraphScopedEventBus):
            return time.monotonic() - bus.last_activity_time
        return float("inf")

    @property
    def is_awaiting_input(self) -> bool:
        """True when an active execution is blocked waiting for client input."""
@@ -292,13 +325,21 @@ class ExecutionStream:
        """Return nodes that support message injection (have ``inject_event``).

        Each entry is ``{"node_id": ..., "execution_id": ...}``.
        The currently executing node is placed first so that
        ``inject_worker_message`` targets the active node, not a stale one.
        """
        injectable: list[dict[str, str]] = []
        current_first: list[dict[str, str]] = []
        for exec_id, executor in self._active_executors.items():
            current = getattr(executor, "current_node_id", None)
            for node_id, node in executor.node_registry.items():
                if hasattr(node, "inject_event"):
                    injectable.append({"node_id": node_id, "execution_id": exec_id})
        return injectable
                    entry = {"node_id": node_id, "execution_id": exec_id}
                    if node_id == current:
                        current_first.append(entry)
                    else:
                        injectable.append(entry)
        return current_first + injectable

    def _record_execution_result(self, execution_id: str, result: ExecutionResult) -> None:
        """Record a completed execution result with retention pruning."""
@@ -329,20 +370,21 @@ class ExecutionStream:
        self._running = False

        # Cancel all active executions
        tasks_to_wait = []
        for _, task in self._execution_tasks.items():
            if not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
                except RuntimeError as e:
                    # Task may be attached to a different event loop (e.g., when TUI
                    # uses a separate loop). Log and continue cleanup.
                    if "attached to a different loop" in str(e):
                        logger.warning(f"Task cleanup skipped (different event loop): {e}")
                    else:
                        raise
                tasks_to_wait.append(task)

        if tasks_to_wait:
            # Wait briefly — don't block indefinitely if tasks are stuck
            # in long-running operations (LLM calls, tool executions).
            _, pending = await asyncio.wait(tasks_to_wait, timeout=5.0)
            if pending:
                logger.warning(
                    "%d execution task(s) did not finish within 5s after cancellation",
                    len(pending),
                )

        self._execution_tasks.clear()
        self._active_executions.clear()
@@ -403,6 +445,27 @@ class ExecutionStream:
        if not self._running:
            raise RuntimeError(f"ExecutionStream '{self.stream_id}' is not running")

        # Only one execution may run on a stream at a time — concurrent
        # executions corrupt shared session state. Cancel any running
        # execution before starting the new one. The cancelled execution
        # writes its state to disk before cleanup, and the new execution
        # runs in the same session directory (via resume_session_id).
        active = self.active_execution_ids
        for eid in active:
            logger.info(
                "Cancelling running execution %s on stream '%s' before starting new one",
                eid,
                self.stream_id,
            )
            executor = self._active_executors.get(eid)
            if executor:
                for node in executor.node_registry.values():
                    if hasattr(node, "signal_shutdown"):
                        node.signal_shutdown()
                    if hasattr(node, "cancel_current_turn"):
                        node.cancel_current_turn()
            await self.cancel_execution(eid)

        # When resuming, reuse the original session ID so the execution
        # continues in the same session directory instead of creating a new one.
        resume_session_id = session_state.get("resume_session_id") if session_state else None
@@ -448,8 +511,37 @@ class ExecutionStream:
        logger.debug(f"Queued execution {execution_id} for stream {self.stream_id}")
        return execution_id

    # Errors that indicate a fundamental configuration or environment problem.
    # Resurrecting after these is pointless — the same error will recur.
    _FATAL_ERROR_PATTERNS: tuple[str, ...] = (
        "credential",
        "authentication",
        "unauthorized",
        "forbidden",
        "api key",
        "import error",
        "module not found",
        "no module named",
        "permission denied",
        "invalid api",
        "configuration error",
    )

    @classmethod
    def _is_fatal_error(cls, error: str | None) -> bool:
        """Return True if the error is fatal (no point resurrecting)."""
        if not error:
            return False
        error_lower = error.lower()
        return any(pat in error_lower for pat in cls._FATAL_ERROR_PATTERNS)

    async def _run_execution(self, ctx: ExecutionContext) -> None:
        """Run a single execution within the stream."""
        """Run a single execution within the stream.

        Supports automatic resurrection: when the execution fails with a
        non-fatal error, it restarts from the failed node up to
        ``entry_spec.max_resurrections`` times (default 3).
        """
        execution_id = ctx.id

        # When sharing a session with another entry point (resume_session_id),
@@ -457,6 +549,11 @@ class ExecutionStream:
        # owns the state.json and _write_progress() keeps memory up-to-date.
        _is_shared_session = bool(ctx.session_state and ctx.session_state.get("resume_session_id"))

        max_resurrections = self.entry_spec.max_resurrections
        _resurrection_count = 0
        _current_session_state = ctx.session_state
        _current_input_data = ctx.input_data

        # Acquire semaphore to limit concurrency
        async with self._semaphore:
            ctx.status = "running"
@@ -497,12 +594,6 @@ class ExecutionStream:
                store=self._runtime_log_store, agent_id=self.graph.id
            )

            # Create executor for this execution.
            # Each execution gets its own storage under sessions/{exec_id}/
            # so conversations, spillover, and data files are all scoped
            # to this execution. The executor sets data_dir via execution
            # context (contextvars) so data tools and spillover share the
            # same session-scoped directory.
            # Derive storage from session_store (graph-specific for secondary
            # graphs) so that all files — conversations, state, checkpoints,
            # data — land under the graph's own sessions/ directory, not the
@@ -511,43 +602,106 @@ class ExecutionStream:
                exec_storage = self._session_store.sessions_dir / execution_id
            else:
                exec_storage = self._storage.base_path / "sessions" / execution_id
            executor = GraphExecutor(
                runtime=runtime_adapter,
                llm=self._llm,
                tools=self._tools,
                tool_executor=self._tool_executor,
                event_bus=self._scoped_event_bus,
                stream_id=self.stream_id,
                execution_id=execution_id,
                storage_path=exec_storage,
                runtime_logger=runtime_logger,
                loop_config=self.graph.loop_config,
                accounts_prompt=self._accounts_prompt,
                accounts_data=self._accounts_data,
                tool_provider_map=self._tool_provider_map,
            )
            # Track executor so inject_input() can reach EventLoopNode instances
            self._active_executors[execution_id] = executor

            # Write initial session state
            if not _is_shared_session:
                await self._write_session_state(execution_id, ctx)

            # Create modified graph with entry point
            # We need to override the entry_node to use our entry point
            modified_graph = self._create_modified_graph()

            # Execute
            result = await executor.execute(
                graph=modified_graph,
                goal=self.goal,
                input_data=ctx.input_data,
                session_state=ctx.session_state,
                checkpoint_config=self._checkpoint_config,
            )
            # Write initial session state
            if not _is_shared_session:
                await self._write_session_state(execution_id, ctx)

            # Clean up executor reference
            self._active_executors.pop(execution_id, None)
            # --- Resurrection loop ---
            # Each iteration creates a fresh executor. On non-fatal failure,
            # the executor's session_state (memory + resume_from) carries
            # forward so the next attempt resumes at the failed node.
            while True:
                # Create executor for this execution.
                # Each execution gets its own storage under sessions/{exec_id}/
                # so conversations, spillover, and data files are all scoped
                # to this execution. The executor sets data_dir via execution
                # context (contextvars) so data tools and spillover share the
                # same session-scoped directory.
                executor = GraphExecutor(
                    runtime=runtime_adapter,
                    llm=self._llm,
                    tools=self._tools,
                    tool_executor=self._tool_executor,
                    event_bus=self._scoped_event_bus,
                    stream_id=self.stream_id,
                    execution_id=execution_id,
                    storage_path=exec_storage,
                    runtime_logger=runtime_logger,
                    loop_config=self.graph.loop_config,
                    accounts_prompt=self._accounts_prompt,
                    accounts_data=self._accounts_data,
                    tool_provider_map=self._tool_provider_map,
                )
                # Track executor so inject_input() can reach EventLoopNode instances
                self._active_executors[execution_id] = executor

                # Execute
                result = await executor.execute(
                    graph=modified_graph,
                    goal=self.goal,
                    input_data=_current_input_data,
                    session_state=_current_session_state,
                    checkpoint_config=self._checkpoint_config,
                )

                # Clean up executor reference
                self._active_executors.pop(execution_id, None)

                # Check if resurrection is appropriate
                if (
                    not result.success
                    and not result.paused_at
                    and _resurrection_count < max_resurrections
                    and result.session_state
                    and not self._is_fatal_error(result.error)
                ):
                    _resurrection_count += 1
                    logger.warning(
                        "Execution %s failed (%s) — resurrecting (%d/%d) from node '%s'",
                        execution_id,
                        (result.error or "unknown")[:200],
                        _resurrection_count,
                        max_resurrections,
                        result.session_state.get("resume_from", "?"),
                    )

                    # Emit resurrection event
                    if self._scoped_event_bus:
                        from framework.runtime.event_bus import AgentEvent, EventType

                        await self._scoped_event_bus.publish(
                            AgentEvent(
                                type=EventType.EXECUTION_RESURRECTED,
                                stream_id=self.stream_id,
                                execution_id=execution_id,
                                data={
                                    "attempt": _resurrection_count,
                                    "max_resurrections": max_resurrections,
                                    "error": (result.error or "")[:500],
                                    "resume_from": result.session_state.get("resume_from"),
                                },
                            )
                        )

                    # Resume from the failed node with preserved memory
                    _current_session_state = {
                        **result.session_state,
                        "resume_session_id": execution_id,
                    }
                    # On resurrection, input_data is already in memory —
                    # pass empty so we don't overwrite intermediate results.
                    _current_input_data = {}

                    # Brief cooldown before resurrection
                    await asyncio.sleep(2.0)
                    continue

                break  # success, fatal failure, or resurrections exhausted

            # Store result with retention
            self._record_execution_result(execution_id, result)
@@ -569,8 +723,7 @@ class ExecutionStream:
            if not _is_shared_session:
                await self._write_session_state(execution_id, ctx, result=result)

            # Emit completion/failure event
            # (skip for pauses — executor already emitted execution_paused)
            # Emit completion/failure/pause event
            if self._scoped_event_bus:
                if result.success:
                    await self._scoped_event_bus.emit_execution_completed(
@@ -579,7 +732,17 @@ class ExecutionStream:
                        output=result.output,
                        correlation_id=ctx.correlation_id,
                    )
                elif not result.paused_at:
                elif result.paused_at:
                    # The executor returns paused_at on CancelledError but
                    # does NOT emit execution_paused itself — we must emit
                    # it here so the frontend can transition out of "running".
                    await self._scoped_event_bus.emit_execution_paused(
                        stream_id=self.stream_id,
                        node_id=result.paused_at,
                        reason=result.error or "Execution paused",
                        execution_id=execution_id,
                    )
                else:
                    await self._scoped_event_bus.emit_execution_failed(
                        stream_id=self.stream_id,
                        execution_id=execution_id,
@@ -628,6 +791,25 @@ class ExecutionStream:
                execution_id, ctx, error="Execution cancelled"
            )

            # Emit SSE event so the frontend knows the execution stopped.
            # The executor does NOT emit on CancelledError, so there is no
            # risk of double-emitting.
            if self._scoped_event_bus:
                if has_result and result.paused_at:
                    await self._scoped_event_bus.emit_execution_paused(
                        stream_id=self.stream_id,
                        node_id=result.paused_at,
                        reason="Execution cancelled",
                        execution_id=execution_id,
                    )
                else:
                    await self._scoped_event_bus.emit_execution_failed(
                        stream_id=self.stream_id,
                        execution_id=execution_id,
                        error="Execution cancelled",
                        correlation_id=ctx.correlation_id,
                    )

            # Don't re-raise - we've handled it and saved state

        except Exception as e:
@@ -878,10 +1060,11 @@ class ExecutionStream:
        task = self._execution_tasks.get(execution_id)
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            # Wait briefly for the task to finish. Don't block indefinitely —
            # the task may be stuck in a long LLM API call that doesn't
            # respond to cancellation quickly. The cancellation is already
            # requested; the task will clean up in the background.
            done, _ = await asyncio.wait({task}, timeout=5.0)
            return True
        return False
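The resurrection gate packs five conditions into one `if`. Pulled out as a standalone predicate (the function name is invented; `result` mirrors the `ExecutionResult` fields used above), the policy reads as:

```python
def should_resurrect(result, attempt: int, max_resurrections: int,
                     fatal_patterns: tuple[str, ...]) -> bool:
    """Resurrect only non-fatal, non-paused failures that left resumable state."""
    if result.success or result.paused_at:
        return False
    if attempt >= max_resurrections or not result.session_state:
        return False
    error = (result.error or "").lower()
    return not any(pat in error for pat in fatal_patterns)
```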
@@ -0,0 +1,85 @@
"""HIVE_LLM_DEBUG — write every LLM turn to a JSONL file for replay/debugging.

Set the env var to enable:
    HIVE_LLM_DEBUG=1          → writes to ~/.hive/llm_logs/<ts>.jsonl
    HIVE_LLM_DEBUG=/some/path → writes to that directory

Each line is a JSON object with the full LLM turn: assistant text, tool calls,
tool results, and token counts. The file is opened lazily on first call and
flushed after every write. Errors are silently swallowed — this must never
break the agent.
"""

import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import IO, Any

logger = logging.getLogger(__name__)

_LLM_DEBUG_RAW = os.environ.get("HIVE_LLM_DEBUG", "").strip()
_LLM_DEBUG_ENABLED = _LLM_DEBUG_RAW.lower() in ("1", "true") or (
    bool(_LLM_DEBUG_RAW) and _LLM_DEBUG_RAW.lower() not in ("0", "false", "")
)

_log_file: IO[str] | None = None
_log_ready = False  # lazy init guard


def _open_log() -> IO[str] | None:
    """Open a JSONL log file. Returns None if disabled."""
    if not _LLM_DEBUG_ENABLED:
        return None
    raw = _LLM_DEBUG_RAW
    if raw.lower() in ("1", "true"):
        log_dir = Path.home() / ".hive" / "llm_logs"
    else:
        log_dir = Path(raw)
    log_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = log_dir / f"{ts}.jsonl"
    logger.info("LLM debug log → %s", path)
    return open(path, "a", encoding="utf-8")  # noqa: SIM115


def log_llm_turn(
    *,
    node_id: str,
    stream_id: str,
    execution_id: str,
    iteration: int,
    assistant_text: str,
    tool_calls: list[dict[str, Any]],
    tool_results: list[dict[str, Any]],
    token_counts: dict[str, Any],
) -> None:
    """Write one JSONL line capturing a complete LLM turn.

    No-op when HIVE_LLM_DEBUG is not set. Never raises.
    """
    if not _LLM_DEBUG_ENABLED:
        return
    try:
        global _log_file, _log_ready  # noqa: PLW0603
        if not _log_ready:
            _log_file = _open_log()
            _log_ready = True
        if _log_file is None:
            return
        record = {
            "timestamp": datetime.now().isoformat(),
            "node_id": node_id,
            "stream_id": stream_id,
            "execution_id": execution_id,
            "iteration": iteration,
            "assistant_text": assistant_text,
            "tool_calls": tool_calls,
            "tool_results": tool_results,
            "token_counts": token_counts,
        }
        _log_file.write(json.dumps(record, default=str) + "\n")
        _log_file.flush()
    except Exception:
        pass  # never break the agent
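Since each line is standalone JSON, the logs replay with a few lines of stdlib code. A minimal reader sketch, assuming the default `~/.hive/llm_logs/` location described in the module docstring:

```python
import json
from pathlib import Path

log_dir = Path.home() / ".hive" / "llm_logs"  # default when HIVE_LLM_DEBUG=1
for log_file in sorted(log_dir.glob("*.jsonl")):
    for line in log_file.read_text(encoding="utf-8").splitlines():
        turn = json.loads(line)
        print(turn["timestamp"], turn["node_id"], turn["token_counts"])
```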
@@ -24,6 +24,8 @@ class ToolCallLog(BaseModel):
    tool_input: dict[str, Any] = Field(default_factory=dict)
    result: str = ""
    is_error: bool = False
    start_timestamp: str = ""  # ISO 8601 timestamp when tool execution started
    duration_s: float = 0.0  # Wall-clock execution time in seconds


class NodeStepLog(BaseModel):
@@ -114,6 +114,8 @@ class RuntimeLogger:
                tool_input=tc.get("tool_input", {}),
                result=tc.get("content", ""),
                is_error=tc.get("is_error", False),
                start_timestamp=tc.get("start_timestamp", ""),
                duration_s=tc.get("duration_s", 0.0),
            )
        )
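The two new fields are presumably populated where the tool is actually executed, which is not part of this diff. A sketch of the bookkeeping, with `run_tool` and `call` as invented stand-ins:

```python
import time
from datetime import datetime, timezone

start_timestamp = datetime.now(timezone.utc).isoformat()  # ISO 8601, matches ToolCallLog
t0 = time.monotonic()
output = run_tool(call)  # hypothetical dispatch of a single tool call
tc = {
    "tool_input": call.inputs,
    "content": output.text,
    "is_error": output.is_error,
    "start_timestamp": start_timestamp,
    "duration_s": time.monotonic() - t0,  # wall-clock seconds
}
```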
@@ -1,6 +1,7 @@
"""aiohttp Application factory for the Hive HTTP API server."""

import logging
import os
from pathlib import Path

from aiohttp import web
@@ -10,6 +11,52 @@ from framework.server.session_manager import Session, SessionManager
logger = logging.getLogger(__name__)


# Anchor to the repository root so allowed roots are independent of CWD.
# app.py lives at core/framework/server/app.py, so four .parent calls
# reach the repo root where exports/ and examples/ live.
_REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent

_ALLOWED_AGENT_ROOTS: tuple[Path, ...] | None = None


def _get_allowed_agent_roots() -> tuple[Path, ...]:
    """Return resolved allowed root directories for agent loading.

    Roots are anchored to the repository root (derived from ``__file__``)
    so the allowlist is correct regardless of the process's working
    directory.
    """
    global _ALLOWED_AGENT_ROOTS
    if _ALLOWED_AGENT_ROOTS is None:
        _ALLOWED_AGENT_ROOTS = (
            (_REPO_ROOT / "exports").resolve(),
            (_REPO_ROOT / "examples").resolve(),
            (Path.home() / ".hive" / "agents").resolve(),
        )
    return _ALLOWED_AGENT_ROOTS


def validate_agent_path(agent_path: str | Path) -> Path:
    """Validate that an agent path resolves inside an allowed directory.

    Prevents arbitrary code execution via ``importlib.import_module`` by
    restricting agent loading to known safe directories: ``exports/``,
    ``examples/``, and ``~/.hive/agents/``.

    Returns the resolved ``Path`` on success.

    Raises:
        ValueError: If the path is outside all allowed roots.
    """
    resolved = Path(agent_path).expanduser().resolve()
    for root in _get_allowed_agent_roots():
        if resolved.is_relative_to(root) and resolved != root:
            return resolved
    raise ValueError(
        "agent_path must be inside an allowed directory (exports/, examples/, or ~/.hive/agents/)"
    )


def safe_path_segment(value: str) -> str:
    """Validate a URL path parameter is a safe filesystem name.

@@ -17,7 +64,7 @@ def safe_path_segment(value: str) -> str:
    traversal sequences. aiohttp decodes ``%2F`` inside route params,
    so a raw ``{session_id}`` can contain ``/`` or ``..`` after decoding.
    """
    if "/" in value or "\\" in value or ".." in value:
    if not value or value == "." or "/" in value or "\\" in value or ".." in value:
        raise web.HTTPBadRequest(reason="Invalid path parameter")
    return value

@@ -138,7 +185,21 @@ def create_app(model: str | None = None) -> web.Application:
    try:
        from framework.credentials.validation import ensure_credential_key_env

        # Load ALL credentials: HIVE_CREDENTIAL_KEY, ADEN_API_KEY, and LLM keys
        ensure_credential_key_env()

        # Auto-generate credential key for web-only users who never ran the TUI
        if not os.environ.get("HIVE_CREDENTIAL_KEY"):
            try:
                from framework.credentials.key_storage import generate_and_save_credential_key

                generate_and_save_credential_key()
                logger.info(
                    "Generated and persisted HIVE_CREDENTIAL_KEY to ~/.hive/secrets/credential_key"
                )
            except Exception as exc:
                logger.warning("Could not auto-persist HIVE_CREDENTIAL_KEY: %s", exc)

        app["credential_store"] = CredentialStore.with_aden_sync()
    except Exception:
        logger.debug("Encrypted credential store unavailable, using in-memory fallback")
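The allowlist check composes with route handlers as a plain try/except. A usage sketch (the agent path is illustrative, not from this diff):

```python
try:
    agent_dir = validate_agent_path("~/.hive/agents/my-agent")  # hypothetical agent
except ValueError as exc:
    print(f"rejected: {exc}")
else:
    print(f"loading agent from {agent_dir}")
```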
@@ -1,5 +1,6 @@
"""Credential CRUD routes."""

import asyncio
import logging

from aiohttp import web
@@ -7,6 +8,7 @@ from pydantic import SecretStr

from framework.credentials.models import CredentialKey, CredentialObject
from framework.credentials.store import CredentialStore
from framework.server.app import validate_agent_path

logger = logging.getLogger(__name__)

@@ -53,7 +55,6 @@ async def handle_save_credential(request: web.Request) -> web.Response:

    Body: {"credential_id": "...", "keys": {"key_name": "value", ...}}
    """
    store = _get_store(request)
    body = await request.json()

    credential_id = body.get("credential_id")
@@ -62,6 +63,31 @@ async def handle_save_credential(request: web.Request) -> web.Response:
    if not credential_id or not keys or not isinstance(keys, dict):
        return web.json_response({"error": "credential_id and keys are required"}, status=400)

    # ADEN_API_KEY is stored in the encrypted store via key_storage module
    if credential_id == "aden_api_key":
        key = keys.get("api_key", "").strip()
        if not key:
            return web.json_response({"error": "api_key is required"}, status=400)

        from framework.credentials.key_storage import save_aden_api_key

        save_aden_api_key(key)

        # Immediately sync OAuth tokens from Aden (runs in executor because
        # _presync_aden_tokens makes blocking HTTP calls to the Aden server).
        try:
            from aden_tools.credentials import CREDENTIAL_SPECS

            from framework.credentials.validation import _presync_aden_tokens

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, _presync_aden_tokens, CREDENTIAL_SPECS)
        except Exception as exc:
            logger.warning("Aden token sync after key save failed: %s", exc)

        return web.json_response({"saved": "aden_api_key"}, status=201)

    store = _get_store(request)
    cred = CredentialObject(
        id=credential_id,
        keys={k: CredentialKey(name=k, value=SecretStr(v)) for k, v in keys.items()},
@@ -73,6 +99,13 @@ async def handle_save_credential(request: web.Request) -> web.Response:
async def handle_delete_credential(request: web.Request) -> web.Response:
    """DELETE /api/credentials/{credential_id} — delete a credential."""
    credential_id = request.match_info["credential_id"]

    if credential_id == "aden_api_key":
        from framework.credentials.key_storage import delete_aden_api_key

        delete_aden_api_key()
        return web.json_response({"deleted": True})

    store = _get_store(request)
    deleted = store.delete_credential(credential_id)
    if not deleted:
@@ -96,17 +129,47 @@ async def handle_check_agent(request: web.Request) -> web.Response:
    if not agent_path:
        return web.json_response({"error": "agent_path is required"}, status=400)

    try:
        agent_path = str(validate_agent_path(agent_path))
    except ValueError as e:
        return web.json_response({"error": str(e)}, status=400)

    try:
        from framework.credentials.setup import load_agent_nodes
        from framework.credentials.validation import ensure_credential_key_env, validate_agent_credentials
        from framework.credentials.validation import (
            ensure_credential_key_env,
            validate_agent_credentials,
        )

        # Load env vars from shell config (same as runtime startup)
        ensure_credential_key_env()

        nodes = load_agent_nodes(agent_path)
        result = validate_agent_credentials(nodes, verify=verify, raise_on_error=False)
        result = validate_agent_credentials(
            nodes, verify=verify, raise_on_error=False, force_refresh=True
        )

        # If any credential needs Aden, include ADEN_API_KEY as a first-class row
        if any(c.aden_supported for c in result.credentials):
            aden_key_status = {
                "credential_name": "Aden Platform",
                "credential_id": "aden_api_key",
                "env_var": "ADEN_API_KEY",
                "description": "API key from the Developers tab in Settings",
                "help_url": "https://hive.adenhq.com/",
                "tools": [],
                "node_types": [],
                "available": result.has_aden_key,
                "valid": None,
                "validation_message": None,
                "direct_api_key_supported": True,
                "aden_supported": True,  # renders with "Authorize" button to open Aden
                "credential_key": "api_key",
            }
            required = [aden_key_status] + [_status_to_dict(c) for c in result.credentials]
        else:
            required = [_status_to_dict(c) for c in result.credentials]

        required = [_status_to_dict(c) for c in result.credentials]
        return web.json_response(
            {
                "required": required,
@@ -134,57 +197,14 @@ def _status_to_dict(c) -> dict:
        "credential_key": c.credential_key,
        "valid": c.valid,
        "validation_message": c.validation_message,
        "alternative_group": c.alternative_group,
    }


async def handle_save_aden_key(request: web.Request) -> web.Response:
    """POST /api/credentials/aden-key — save the user's ADEN_API_KEY.

    Sets the key in the current process environment and persists it to shell
    config so future terminals pick it up. Then triggers an Aden token sync
    so OAuth credentials resolve immediately.

    Body: {"key": "..."}
    """
    import os

    body = await request.json()
    key = body.get("key", "").strip()
    if not key:
        return web.json_response({"error": "key is required"}, status=400)

    os.environ["ADEN_API_KEY"] = key

    # Persist to shell config (best-effort, same pattern as TUI setup)
    try:
        from aden_tools.credentials.shell_config import add_env_var_to_shell_config

        add_env_var_to_shell_config(
            "ADEN_API_KEY",
            key,
            comment="Aden Platform API key",
        )
    except Exception as exc:
        logger.warning("Could not persist ADEN_API_KEY to shell config: %s", exc)

    # Immediately sync OAuth tokens from Aden
    try:
        from aden_tools.credentials import CREDENTIAL_SPECS

        from framework.credentials.validation import _presync_aden_tokens

        _presync_aden_tokens(CREDENTIAL_SPECS)
    except Exception as exc:
        logger.warning("Aden token sync after key save failed: %s", exc)

    return web.json_response({"saved": True}, status=201)


def register_routes(app: web.Application) -> None:
    """Register credential routes on the application."""
    # check-agent and aden-key must be registered BEFORE the {credential_id} wildcard
    # check-agent must be registered BEFORE the {credential_id} wildcard
    app.router.add_post("/api/credentials/check-agent", handle_check_agent)
    app.router.add_post("/api/credentials/aden-key", handle_save_aden_key)
    app.router.add_get("/api/credentials", handle_list_credentials)
    app.router.add_post("/api/credentials", handle_save_credential)
    app.router.add_get("/api/credentials/{credential_id}", handle_get_credential)
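With the dedicated aden-key route removed, the Aden key now flows through the generic credential endpoint. A client sketch (the route and body shape come from `handle_save_credential` above; the base URL and key value are assumptions):

```python
import asyncio

import aiohttp


async def save_aden_key(base_url: str, key: str) -> None:
    async with aiohttp.ClientSession() as http:
        resp = await http.post(
            f"{base_url}/api/credentials",
            json={"credential_id": "aden_api_key", "keys": {"api_key": key}},
        )
        print(resp.status, await resp.json())


asyncio.run(save_aden_key("http://localhost:8080", "aden-key-..."))  # assumed host/port
```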
@@ -24,6 +24,7 @@ DEFAULT_EVENT_TYPES = [
    EventType.NODE_LOOP_STARTED,
    EventType.NODE_LOOP_ITERATION,
    EventType.NODE_LOOP_COMPLETED,
    EventType.LLM_TURN_COMPLETE,
    EventType.NODE_ACTION_PLAN,
    EventType.EDGE_TRAVERSED,
    EventType.GOAL_PROGRESS,
@@ -35,6 +36,8 @@ DEFAULT_EVENT_TYPES = [
    EventType.NODE_TOOL_DOOM_LOOP,
    EventType.CONTEXT_COMPACTED,
    EventType.WORKER_LOADED,
    EventType.CREDENTIALS_REQUIRED,
    EventType.SUBAGENT_REPORT,
]

# Keepalive interval in seconds
@@ -77,12 +80,41 @@ async def handle_events(request: web.Request) -> web.StreamResponse:
    # Per-client buffer queue
    queue: asyncio.Queue = asyncio.Queue(maxsize=1000)

    # Lifecycle events drive frontend state transitions and must never be lost.
    _CRITICAL_EVENTS = {
        "execution_started",
        "execution_completed",
        "execution_failed",
        "execution_paused",
        "client_input_requested",
        "node_loop_iteration",
        "node_loop_started",
        "credentials_required",
        "worker_loaded",
    }

    client_disconnected = asyncio.Event()

    async def on_event(event) -> None:
        """Push event dict into queue; drop if full."""
        try:
            queue.put_nowait(event.to_dict())
        except asyncio.QueueFull:
            pass  # Drop oldest-undelivered; client will catch up
        """Push event dict into queue; drop non-critical events if full."""
        if client_disconnected.is_set():
            return

        evt_dict = event.to_dict()
        if evt_dict.get("type") in _CRITICAL_EVENTS:
            try:
                queue.put_nowait(evt_dict)
            except asyncio.QueueFull:
                logger.warning(
                    "SSE client queue full on critical event; disconnecting session='%s'",
                    session.id,
                )
                client_disconnected.set()
        else:
            try:
                queue.put_nowait(evt_dict)
            except asyncio.QueueFull:
                pass  # high-frequency events can be dropped; client will catch up

    # Subscribe to EventBus
    from framework.server.sse import SSEResponse
@@ -94,25 +126,46 @@ async def handle_events(request: web.Request) -> web.StreamResponse:

    sse = SSEResponse()
    await sse.prepare(request)
    logger.info(
        "SSE connected: session='%s', sub_id='%s', types=%d", session.id, sub_id, len(event_types)
    )

    event_count = 0
    close_reason = "unknown"
    try:
        while True:
        while not client_disconnected.is_set():
            try:
                data = await asyncio.wait_for(queue.get(), timeout=KEEPALIVE_INTERVAL)
                await sse.send_event(data)
                event_count += 1
                if event_count == 1:
                    logger.info(
                        "SSE first event: session='%s', type='%s'", session.id, data.get("type")
                    )
            except TimeoutError:
                await sse.send_keepalive()
            except (ConnectionResetError, ConnectionError):
                close_reason = "client_disconnected"
                break
            except RuntimeError as exc:
                if "closing transport" in str(exc).lower():
                    break
                raise
            except Exception as exc:
                close_reason = f"error: {exc}"
                break

        if client_disconnected.is_set() and close_reason == "unknown":
            close_reason = "slow_client"
    except asyncio.CancelledError:
        pass
        close_reason = "cancelled"
    finally:
        event_bus.unsubscribe(sub_id)
        logger.debug("SSE client disconnected from session '%s'", session.id)
        try:
            event_bus.unsubscribe(sub_id)
        except Exception:
            pass
        logger.info(
            "SSE disconnected: session='%s', events_sent=%d, reason='%s'",
            session.id,
            event_count,
            close_reason,
        )

    return sse.response
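The new `on_event` encodes an asymmetric backpressure policy: a full queue silently drops high-frequency events but is treated as a fatal slow-client condition for lifecycle events. The policy in isolation, as a sketch (names invented for illustration):

```python
import asyncio


def enqueue(queue: asyncio.Queue, evt: dict, critical: set[str]) -> bool:
    """Return False when a critical event could not be queued.

    The caller treats False as a slow-client condition and disconnects;
    dropped non-critical events are tolerated because the client catches up.
    """
    try:
        queue.put_nowait(evt)
        return True
    except asyncio.QueueFull:
        return evt.get("type") not in critical
```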
@@ -1,11 +1,14 @@
"""Execution control routes — trigger, inject, chat, resume, stop, replay."""

import asyncio
import json
import logging

from aiohttp import web

from framework.credentials.validation import validate_agent_credentials
from framework.server.app import resolve_session, safe_path_segment, sessions_dir
from framework.server.routes_sessions import _credential_error_response

logger = logging.getLogger(__name__)

@@ -22,6 +25,30 @@ async def handle_trigger(request: web.Request) -> web.Response:
    if not session.worker_runtime:
        return web.json_response({"error": "No worker loaded in this session"}, status=503)

    # Validate credentials before running — deferred from load time to avoid
    # showing the modal before the user clicks Run. Runs in executor because
    # validate_agent_credentials makes blocking HTTP health-check calls.
    if session.runner:
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(
                None, lambda: validate_agent_credentials(session.runner.graph.nodes)
            )
        except Exception as e:
            agent_path = str(session.worker_path) if session.worker_path else ""
            resp = _credential_error_response(e, agent_path)
            if resp is not None:
                return resp

        # Resync MCP servers if credentials were added since the worker loaded
        # (e.g. user connected an OAuth account mid-session via Aden UI).
        try:
            await loop.run_in_executor(
                None, lambda: session.runner._tool_registry.resync_mcp_servers_if_needed()
            )
        except Exception as e:
            logger.warning("MCP resync failed: %s", e)

    body = await request.json()
    entry_point_id = body.get("entry_point_id", "default")
    input_data = body.get("input_data", {})
@@ -65,12 +92,10 @@ async def handle_inject(request: web.Request) -> web.Response:


async def handle_chat(request: web.Request) -> web.Response:
    """POST /api/sessions/{session_id}/chat — convenience endpoint.
    """POST /api/sessions/{session_id}/chat — send a message to the queen.

    Routing priority:
    1. Worker awaiting input → inject into worker node
    2. Queen active → inject into queen conversation
    3. Error — no handler available
    The input box is permanently connected to the queen agent.
    Worker input is handled separately via /worker-input.

    Body: {"message": "hello"}
    """
@@ -84,26 +109,6 @@ async def handle_chat(request: web.Request) -> web.Response:
    if not message:
        return web.json_response({"error": "message is required"}, status=400)

    # 1. Check if worker is awaiting input → inject to worker
    if session.worker_runtime:
        node_id, graph_id = session.worker_runtime.find_awaiting_node()

        if node_id:
            delivered = await session.worker_runtime.inject_input(
                node_id,
                message,
                graph_id=graph_id,
                is_client_input=True,
            )
            return web.json_response(
                {
                    "status": "injected",
                    "node_id": node_id,
                    "delivered": delivered,
                }
            )

    # 2. Queen active → inject into queen conversation
    queen_executor = session.queen_executor
    if queen_executor is not None:
        node = queen_executor.node_registry.get("queen")
@@ -116,8 +121,47 @@ async def handle_chat(request: web.Request) -> web.Response:
            }
        )

    # 3. No queen or worker available
    return web.json_response({"error": "No worker or queen available"}, status=503)
    return web.json_response({"error": "Queen not available"}, status=503)


async def handle_worker_input(request: web.Request) -> web.Response:
    """POST /api/sessions/{session_id}/worker-input — send input to waiting worker node.

    Auto-discovers the worker node currently awaiting input and injects the message.
    Returns 404 if no worker node is awaiting input.

    Body: {"message": "..."}
    """
    session, err = resolve_session(request)
    if err:
        return err

    body = await request.json()
    message = body.get("message", "")

    if not message:
        return web.json_response({"error": "message is required"}, status=400)

    if not session.worker_runtime:
        return web.json_response({"error": "No worker loaded"}, status=503)

    node_id, graph_id = session.worker_runtime.find_awaiting_node()
    if not node_id:
        return web.json_response({"error": "No worker node awaiting input"}, status=404)

    delivered = await session.worker_runtime.inject_input(
        node_id,
        message,
        graph_id=graph_id,
        is_client_input=True,
    )
    return web.json_response(
        {
            "status": "injected",
            "node_id": node_id,
            "delivered": delivered,
        }
    )


async def handle_goal_progress(request: web.Request) -> web.Response:
@@ -163,7 +207,7 @@ async def handle_resume(request: web.Request) -> web.Response:
        return web.json_response({"error": "Session not found"}, status=404)

    try:
        state = json.loads(state_path.read_text())
        state = json.loads(state_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError) as e:
        return web.json_response({"error": f"Failed to read session: {e}"}, status=500)

@@ -228,6 +272,14 @@ async def handle_stop(request: web.Request) -> web.Response:
        if reg is None:
            continue
        for _ep_id, stream in reg.streams.items():
            # Signal shutdown on active nodes to abort in-flight LLM streams
            for executor in stream._active_executors.values():
                for node in executor.node_registry.values():
                    if hasattr(node, "signal_shutdown"):
                        node.signal_shutdown()
                    if hasattr(node, "cancel_current_turn"):
                        node.cancel_current_turn()

            cancelled = await stream.cancel_execution(execution_id)
            if cancelled:
                return web.json_response(
@@ -292,14 +344,31 @@ async def handle_replay(request: web.Request) -> web.Response:
    )


async def handle_cancel_queen(request: web.Request) -> web.Response:
    """POST /api/sessions/{session_id}/cancel-queen — cancel the queen's current LLM turn."""
    session, err = resolve_session(request)
    if err:
        return err
    queen_executor = session.queen_executor
    if queen_executor is None:
        return web.json_response({"cancelled": False, "error": "Queen not active"}, status=404)
    node = queen_executor.node_registry.get("queen")
    if node is None or not hasattr(node, "cancel_current_turn"):
        return web.json_response({"cancelled": False, "error": "Queen node not found"}, status=404)
    node.cancel_current_turn()
    return web.json_response({"cancelled": True})


def register_routes(app: web.Application) -> None:
    """Register execution control routes."""
    # Session-primary routes
    app.router.add_post("/api/sessions/{session_id}/trigger", handle_trigger)
    app.router.add_post("/api/sessions/{session_id}/inject", handle_inject)
    app.router.add_post("/api/sessions/{session_id}/chat", handle_chat)
    app.router.add_post("/api/sessions/{session_id}/worker-input", handle_worker_input)
    app.router.add_post("/api/sessions/{session_id}/pause", handle_stop)
    app.router.add_post("/api/sessions/{session_id}/resume", handle_resume)
    app.router.add_post("/api/sessions/{session_id}/stop", handle_stop)
    app.router.add_post("/api/sessions/{session_id}/cancel-queen", handle_cancel_queen)
    app.router.add_post("/api/sessions/{session_id}/replay", handle_replay)
    app.router.add_get("/api/sessions/{session_id}/goal-progress", handle_goal_progress)
@@ -10,13 +10,21 @@ from framework.server.app import resolve_session, safe_path_segment
logger = logging.getLogger(__name__)


-def _get_graph_spec(session, graph_id: str):
-    """Get GraphSpec for a graph_id. Returns (graph_spec, None) or (None, error_response)."""
+def _get_graph_registration(session, graph_id: str):
+    """Get _GraphRegistration for a graph_id. Returns (reg, None) or (None, error_response)."""
    if not session.worker_runtime:
        return None, web.json_response({"error": "No worker loaded in this session"}, status=503)
    reg = session.worker_runtime.get_graph_registration(graph_id)
    if reg is None:
        return None, web.json_response({"error": f"Graph '{graph_id}' not found"}, status=404)
    return reg, None


+def _get_graph_spec(session, graph_id: str):
+    """Get GraphSpec for a graph_id. Returns (graph_spec, None) or (None, error_response)."""
+    reg, err = _get_graph_registration(session, graph_id)
+    if err:
+        return None, err
+    return reg.graph, None

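Both helpers follow the same (value, error_response) tuple convention as resolve_session, so handlers can bail out with a one-line guard. A sketch of the calling pattern, with a hypothetical handler name and assuming the module's existing imports:

    # Hypothetical caller illustrating the (value, error) convention
    async def handle_example(request: web.Request) -> web.Response:
        session, err = resolve_session(request)
        if err:
            return err
        reg, err = _get_graph_registration(session, request.match_info["graph_id"])
        if err:
            return err
        return web.json_response({"entry_node": reg.graph.entry_node})
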
@@ -37,6 +45,7 @@ def _node_to_dict(node) -> dict:
        "client_facing": node.client_facing,
        "success_criteria": node.success_criteria,
        "system_prompt": node.system_prompt or "",
+       "sub_agents": node.sub_agents,
    }


@@ -47,10 +56,11 @@ async def handle_list_nodes(request: web.Request) -> web.Response:
        return err

    graph_id = request.match_info["graph_id"]
-    graph, err = _get_graph_spec(session, graph_id)
+    reg, err = _get_graph_registration(session, graph_id)
    if err:
        return err

+    graph = reg.graph
    nodes = [_node_to_dict(n) for n in graph.nodes]

    # Optionally enrich with session progress
@@ -70,7 +80,7 @@ async def handle_list_nodes(request: web.Request) -> web.Response:
            )
        if state_path.exists():
            try:
-                state = json.loads(state_path.read_text())
+                state = json.loads(state_path.read_text(encoding="utf-8"))
                progress = state.get("progress", {})
                visit_counts = progress.get("node_visit_counts", {})
                failures = progress.get("nodes_with_failures", [])
@@ -90,11 +100,28 @@ async def handle_list_nodes(request: web.Request) -> web.Response:
        {"source": e.source, "target": e.target, "condition": e.condition, "priority": e.priority}
        for e in graph.edges
    ]
+    rt = session.worker_runtime
+    entry_points = [
+        {
+            "id": ep.id,
+            "name": ep.name,
+            "entry_node": ep.entry_node,
+            "trigger_type": ep.trigger_type,
+            "trigger_config": ep.trigger_config,
+            **(
+                {"next_fire_in": nf}
+                if rt and (nf := rt.get_timer_next_fire_in(ep.id)) is not None
+                else {}
+            ),
+        }
+        for ep in reg.entry_points.values()
+    ]
    return web.json_response(
        {
            "nodes": nodes,
            "edges": edges,
            "entry_node": graph.entry_node,
+            "entry_points": entry_points,
        }
    )

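The entry-point payload uses conditional dict unpacking with the walrus operator so next_fire_in only appears for timer entry points. The idiom in isolation, with made-up values:

    # Sketch of the conditional-spread idiom used above
    def timer_next_fire_in(ep_id: str) -> float | None:
        return 42.0 if ep_id == "tick" else None  # stand-in for rt.get_timer_next_fire_in

    payload = {
        "id": "tick",
        **({"next_fire_in": nf} if (nf := timer_next_fire_in("tick")) is not None else {}),
    }
    assert payload == {"id": "tick", "next_fire_in": 42.0}
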
@@ -30,7 +30,12 @@ from pathlib import Path

from aiohttp import web

-from framework.server.app import resolve_session, safe_path_segment, sessions_dir
+from framework.server.app import (
+    resolve_session,
+    safe_path_segment,
+    sessions_dir,
+    validate_agent_path,
+)
from framework.server.session_manager import SessionManager

logger = logging.getLogger(__name__)
@@ -118,6 +123,12 @@ async def handle_create_session(request: web.Request) -> web.Response:
    model = body.get("model")
    initial_prompt = body.get("initial_prompt")

+    if agent_path:
+        try:
+            agent_path = str(validate_agent_path(agent_path))
+        except ValueError as e:
+            return web.json_response({"error": str(e)}, status=400)
+
    try:
        if agent_path:
            # One-step: create session + load worker
@@ -143,14 +154,17 @@ async def handle_create_session(request: web.Request) -> web.Response:
                status=409,
            )
        return web.json_response({"error": msg}, status=409)
-    except FileNotFoundError as e:
-        return web.json_response({"error": str(e)}, status=404)
+    except FileNotFoundError:
+        return web.json_response(
+            {"error": f"Agent not found: {agent_path or 'no path'}"},
+            status=404,
+        )
    except Exception as e:
        resp = _credential_error_response(e, agent_path)
        if resp is not None:
            return resp
        logger.exception("Error creating session: %s", e)
-        return web.json_response({"error": str(e)}, status=500)
+        return web.json_response({"error": "Internal server error"}, status=500)

    return web.json_response(_session_to_live_dict(session), status=201)

@@ -182,14 +196,21 @@ async def handle_get_live_session(request: web.Request) -> web.Response:
    data = _session_to_live_dict(session)

    if session.worker_runtime:
+        rt = session.worker_runtime
        data["entry_points"] = [
            {
                "id": ep.id,
                "name": ep.name,
                "entry_node": ep.entry_node,
                "trigger_type": ep.trigger_type,
                "trigger_config": ep.trigger_config,
+                **(
+                    {"next_fire_in": nf}
+                    if (nf := rt.get_timer_next_fire_in(ep.id)) is not None
+                    else {}
+                ),
            }
-            for ep in session.worker_runtime.get_entry_points()
+            for ep in rt.get_entry_points()
        ]
        data["graphs"] = session.worker_runtime.list_graphs()

@@ -229,6 +250,11 @@ async def handle_load_worker(request: web.Request) -> web.Response:
    if not agent_path:
        return web.json_response({"error": "agent_path is required"}, status=400)

+    try:
+        agent_path = str(validate_agent_path(agent_path))
+    except ValueError as e:
+        return web.json_response({"error": str(e)}, status=400)
+
    worker_id = body.get("worker_id")
    model = body.get("model")

@@ -241,14 +267,14 @@ async def handle_load_worker(request: web.Request) -> web.Response:
        )
    except ValueError as e:
        return web.json_response({"error": str(e)}, status=409)
-    except FileNotFoundError as e:
-        return web.json_response({"error": str(e)}, status=404)
+    except FileNotFoundError:
+        return web.json_response({"error": f"Agent not found: {agent_path}"}, status=404)
    except Exception as e:
        resp = _credential_error_response(e, agent_path)
        if resp is not None:
            return resp
        logger.exception("Error loading worker: %s", e)
-        return web.json_response({"error": str(e)}, status=500)
+        return web.json_response({"error": "Internal server error"}, status=500)

    return web.json_response(_session_to_live_dict(session))

@@ -307,7 +333,8 @@ async def handle_session_entry_points(request: web.Request) -> web.Response:
            status=404,
        )

-    eps = session.worker_runtime.get_entry_points() if session.worker_runtime else []
+    rt = session.worker_runtime
+    eps = rt.get_entry_points() if rt else []
    return web.json_response(
        {
            "entry_points": [
@@ -316,6 +343,12 @@ async def handle_session_entry_points(request: web.Request) -> web.Response:
                    "name": ep.name,
                    "entry_node": ep.entry_node,
                    "trigger_type": ep.trigger_type,
                    "trigger_config": ep.trigger_config,
+                    **(
+                        {"next_fire_in": nf}
+                        if rt and (nf := rt.get_timer_next_fire_in(ep.id)) is not None
+                        else {}
+                    ),
                }
                for ep in eps
            ]
@@ -367,7 +400,7 @@ async def handle_list_worker_sessions(request: web.Request) -> web.Response:
        state_path = d / "state.json"
        if state_path.exists():
            try:
-                state = json.loads(state_path.read_text())
+                state = json.loads(state_path.read_text(encoding="utf-8"))
                entry["status"] = state.get("status", "unknown")
                entry["started_at"] = state.get("started_at")
                entry["completed_at"] = state.get("completed_at")
@@ -406,7 +439,7 @@ async def handle_get_worker_session(request: web.Request) -> web.Response:
        return web.json_response({"error": "Session not found"}, status=404)

    try:
-        state = json.loads(state_path.read_text())
+        state = json.loads(state_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, OSError) as e:
        return web.json_response({"error": f"Failed to read session: {e}"}, status=500)

@@ -434,7 +467,7 @@ async def handle_list_checkpoints(request: web.Request) -> web.Response:
        if f.suffix != ".json":
            continue
        try:
-            data = json.loads(f.read_text())
+            data = json.loads(f.read_text(encoding="utf-8"))
            checkpoints.append(
                {
                    "checkpoint_id": f.stem,
@@ -544,13 +577,14 @@ async def handle_messages(request: web.Request) -> web.Response:
            if part_file.suffix != ".json":
                continue
            try:
-                part = json.loads(part_file.read_text())
+                part = json.loads(part_file.read_text(encoding="utf-8"))
                part["_node_id"] = node_dir.name
+                part.setdefault("created_at", part_file.stat().st_mtime)
                all_messages.append(part)
            except (json.JSONDecodeError, OSError):
                continue

-    all_messages.sort(key=lambda m: m.get("seq", 0))
+    all_messages.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))

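Messages from different node conversations carry independent seq counters, so the new sort key falls back to file mtime (stamped above as created_at) to interleave them chronologically. The fallback in isolation:

    # Sketch: created_at (epoch seconds) wins when present, else seq
    msgs = [
        {"seq": 3, "created_at": 100.0},
        {"seq": 1, "created_at": 200.0},
        {"seq": 2},  # no created_at, so the key falls back to seq
    ]
    msgs.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))
    assert [m["seq"] for m in msgs] == [2, 3, 1]  # bare seq values sort below epoch timestamps
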
    client_only = request.query.get("client_only", "").lower() in ("true", "1")
    if client_only:
@@ -598,13 +632,16 @@ async def handle_queen_messages(request: web.Request) -> web.Response:
            if part_file.suffix != ".json":
                continue
            try:
-                part = json.loads(part_file.read_text())
+                part = json.loads(part_file.read_text(encoding="utf-8"))
                part["_node_id"] = node_dir.name
+                # Use file mtime as created_at so frontend can order
+                # queen and worker messages chronologically.
+                part.setdefault("created_at", part_file.stat().st_mtime)
                all_messages.append(part)
            except (json.JSONDecodeError, OSError):
                continue

-    all_messages.sort(key=lambda m: m.get("seq", 0))
+    all_messages.sort(key=lambda m: m.get("created_at", m.get("seq", 0)))

    # Filter to client-facing messages only
    all_messages = [

@@ -159,13 +159,17 @@ class SessionManager:

            # Start queen with worker profile + lifecycle + monitoring tools
            worker_identity = (
-                build_worker_profile(session.worker_runtime) if session.worker_runtime else None
+                build_worker_profile(session.worker_runtime, agent_path=agent_path)
+                if session.worker_runtime
+                else None
            )
-            await self._start_queen(
-                session, worker_identity=worker_identity, initial_prompt=initial_prompt
-            )
+            await self._start_queen(session, worker_identity=worker_identity, initial_prompt=initial_prompt)

            # Start health judge
-            if agent_path.name != "hive_coder" and session.worker_runtime:
-                await self._start_judge(session, session.runner._storage_path)
+            # Health judge disabled for simplicity.
+            # if agent_path.name != "hive_coder" and session.worker_runtime:
+            #     await self._start_judge(session, session.runner._storage_path)

        except Exception:
            # If anything fails, tear down the session
@@ -212,6 +216,7 @@ class SessionManager:
                agent_path,
                model=resolved_model,
                interactive=False,
+               skip_credential_validation=True,
            ),
        )

@@ -228,6 +233,9 @@ class SessionManager:
        if runtime and not runtime.is_running:
            await runtime.start()

+        # Clean up stale "active" sessions from previous (dead) processes
+        self._cleanup_stale_active_sessions(agent_path)
+
        info = runner.info()

        # Update session
@@ -251,6 +259,37 @@ class SessionManager:
            self._loading.discard(session.id)
            raise

+    def _cleanup_stale_active_sessions(self, agent_path: Path) -> None:
+        """Mark stale 'active' sessions on disk as 'cancelled'.
+
+        When a new runtime starts, any on-disk session still marked 'active'
+        is from a process that no longer exists. 'Paused' sessions are left
+        intact so they remain resumable.
+        """
+        sessions_path = Path.home() / ".hive" / "agents" / agent_path.name / "sessions"
+        if not sessions_path.exists():
+            return
+
+        for d in sessions_path.iterdir():
+            if not d.is_dir() or not d.name.startswith("session_"):
+                continue
+            state_path = d / "state.json"
+            if not state_path.exists():
+                continue
+            try:
+                state = json.loads(state_path.read_text(encoding="utf-8"))
+                if state.get("status") != "active":
+                    continue
+                state["status"] = "cancelled"
+                state.setdefault("result", {})["error"] = "Stale session: runtime restarted"
+                state.setdefault("timestamps", {})["updated_at"] = datetime.now().isoformat()
+                state_path.write_text(json.dumps(state, indent=2), encoding="utf-8")
+                logger.info(
+                    "Marked stale session '%s' as cancelled for agent '%s'", d.name, agent_path.name
+                )
+            except (json.JSONDecodeError, OSError) as e:
+                logger.warning("Failed to clean up stale session %s: %s", d.name, e)

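A quick sanity check of what the stale-session sweep does to one state file, reduced to essentials (the /tmp path is a stand-in for the real ~/.hive/agents/<name>/sessions layout):

    # Sketch: the state.json mutation performed by _cleanup_stale_active_sessions
    import json
    from pathlib import Path

    state_path = Path("/tmp/session_demo/state.json")  # hypothetical location
    state_path.parent.mkdir(parents=True, exist_ok=True)
    state_path.write_text(json.dumps({"status": "active"}), encoding="utf-8")

    state = json.loads(state_path.read_text(encoding="utf-8"))
    if state.get("status") == "active":  # a "paused" session would be left untouched
        state["status"] = "cancelled"
        state.setdefault("result", {})["error"] = "Stale session: runtime restarted"
        state_path.write_text(json.dumps(state, indent=2), encoding="utf-8")
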
    async def load_worker(
        self,
        session_id: str,
@@ -275,9 +314,10 @@ class SessionManager:
            model=model,
        )

-        # Start judge + notify queen (skip for hive_coder itself)
+        # Notify queen about the loaded worker (skip for hive_coder itself).
+        # Health judge disabled for simplicity.
        if agent_path.name != "hive_coder" and session.worker_runtime:
-            await self._start_judge(session, session.runner._storage_path)
+            # await self._start_judge(session, session.runner._storage_path)
            await self._notify_queen_worker_loaded(session)

        # Emit SSE event so the frontend can update UI
@@ -450,6 +490,7 @@ class SessionManager:
            stream_id="queen",
            storage_path=queen_dir,
            loop_config=queen_graph.loop_config,
+            execution_id=session.id,
        )
        session.queen_executor = executor
        logger.info(
@@ -457,13 +498,19 @@ class SessionManager:
            len(queen_tools),
            [t.name for t in queen_tools],
        )
-        await executor.execute(
+        result = await executor.execute(
            graph=queen_graph,
            goal=queen_goal,
            input_data={"greeting": initial_prompt or "Session started."},
            session_state={"resume_session_id": session.id},
        )
-        logger.warning("Queen executor returned (should be forever-alive)")
+        if result.success:
+            logger.warning("Queen executor returned (should be forever-alive)")
+        else:
+            logger.error(
+                "Queen executor failed: %s",
+                result.error or "(no error message)",
+            )
    except Exception:
        logger.error("Queen conversation crashed", exc_info=True)
    finally:
@@ -596,7 +643,7 @@ class SessionManager:
        if node is None or not hasattr(node, "inject_event"):
            return

-        profile = build_worker_profile(session.worker_runtime)
+        profile = build_worker_profile(session.worker_runtime, agent_path=session.worker_path)
        await node.inject_event(f"[SYSTEM] Worker loaded.{profile}")

    async def _emit_worker_loaded(self, session: Session) -> None:

@@ -95,7 +95,7 @@ class CheckpointStore:
            return None

        try:
-            return Checkpoint.model_validate_json(checkpoint_path.read_text())
+            return Checkpoint.model_validate_json(checkpoint_path.read_text(encoding="utf-8"))
        except Exception as e:
            logger.error(f"Failed to load checkpoint {checkpoint_id}: {e}")
            return None
@@ -123,7 +123,9 @@ class CheckpointStore:
            return None

        try:
-            return CheckpointIndex.model_validate_json(self.index_path.read_text())
+            return CheckpointIndex.model_validate_json(
+                self.index_path.read_text(encoding="utf-8")
+            )
        except Exception as e:
            logger.error(f"Failed to load checkpoint index: {e}")
            return None

@@ -114,7 +114,7 @@ class SessionStore:
            if not state_path.exists():
                return None

-            return SessionState.model_validate_json(state_path.read_text())
+            return SessionState.model_validate_json(state_path.read_text(encoding="utf-8"))

        return await asyncio.to_thread(_read)

@@ -151,7 +151,7 @@ class SessionStore:
                continue

            try:
-                state = SessionState.model_validate_json(state_path.read_text())
+                state = SessionState.model_validate_json(state_path.read_text(encoding="utf-8"))

                # Apply filters
                if status and state.status != status:

@@ -11,10 +11,35 @@ Provides commands:
import argparse
import ast
import os
+import shutil
+import subprocess
import sys
from pathlib import Path


+def _check_pytest_available() -> bool:
+    """Check if pytest is available as a runnable command.
+
+    Returns True if pytest is found, otherwise prints an error message
+    with install instructions and returns False.
+    """
+    if shutil.which("pytest") is None:
+        print(
+            "Error: pytest is not installed or not on PATH.\n"
+            "Hive's testing commands require pytest at runtime.\n"
+            "Install it with:\n"
+            "\n"
+            "  pip install 'framework[testing]'\n"
+            "\n"
+            "or if using uv:\n"
+            "\n"
+            "  uv pip install 'framework[testing]'",
+            file=sys.stderr,
+        )
+        return False
+    return True

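The pattern here: probe PATH once with shutil.which before shelling out, so users get install instructions instead of a raw FileNotFoundError from subprocess. A minimal sketch of a command body using the guard (command name hypothetical):

    # Hypothetical command showing the guard before subprocess.run
    import subprocess

    def cmd_example(agent_path: str) -> int:
        if not _check_pytest_available():
            return 1
        proc = subprocess.run(["pytest", agent_path], check=False)
        return proc.returncode
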
def register_testing_commands(subparsers: argparse._SubParsersAction) -> None:
    """Register testing CLI commands."""

@@ -105,6 +130,9 @@ def register_testing_commands(subparsers: argparse._SubParsersAction) -> None:

def cmd_test_run(args: argparse.Namespace) -> int:
    """Run tests for an agent using pytest subprocess."""
+    if not _check_pytest_available():
+        return 1
+
    agent_path = Path(args.agent_path)
    tests_dir = agent_path / "tests"

@@ -177,7 +205,8 @@ def cmd_test_run(args: argparse.Namespace) -> int:

def cmd_test_debug(args: argparse.Namespace) -> int:
    """Debug a failed test by re-running with verbose output."""
-    import subprocess
+    if not _check_pytest_available():
+        return 1

    agent_path = Path(args.agent_path)
    test_name = args.test_name
@@ -190,7 +219,7 @@ def cmd_test_debug(args: argparse.Namespace) -> int:
    # Find which file contains the test
    test_file = None
    for py_file in tests_dir.glob("test_*.py"):
-        content = py_file.read_text()
+        content = py_file.read_text(encoding="utf-8")
        if f"def {test_name}" in content or f"async def {test_name}" in content:
            test_file = py_file
            break
@@ -238,7 +267,7 @@ def _scan_test_files(tests_dir: Path) -> list[dict]:

    for test_file in sorted(tests_dir.glob("test_*.py")):
        try:
-            content = test_file.read_text()
+            content = test_file.read_text(encoding="utf-8")
            tree = ast.parse(content)

            for node in ast.walk(tree):

@@ -33,12 +33,18 @@ Usage::

from __future__ import annotations

+import asyncio
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any

+from framework.credentials.models import CredentialError
+from framework.runner.preload_validation import credential_errors_to_json, validate_credentials
from framework.runtime.event_bus import AgentEvent, EventType
+from framework.server.app import validate_agent_path

if TYPE_CHECKING:
    from framework.runner.tool_registry import ToolRegistry
    from framework.runtime.agent_runtime import AgentRuntime
@@ -60,7 +66,7 @@ class WorkerSessionAdapter:
    worker_path: Path | None = None


-def build_worker_profile(runtime: AgentRuntime) -> str:
+def build_worker_profile(runtime: AgentRuntime, agent_path: Path | str | None = None) -> str:
    """Build a worker capability profile from its graph/goal definition.

    Injected into the queen's system prompt so it knows what the worker
@@ -71,6 +77,8 @@ def build_worker_profile(runtime: AgentRuntime) -> str:

    lines = ["\n\n# Worker Profile"]
    lines.append(f"Agent: {runtime.graph_id}")
+    if agent_path:
+        lines.append(f"Path: {agent_path}")
    lines.append(f"Goal: {goal.name}")
    if goal.description:
        lines.append(f"Description: {goal.description}")

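For reference, the profile is plain text appended to the queen's system prompt; with the new agent_path argument the header section would render roughly like this (all values illustrative):

    # Worker Profile
    Agent: research_agent
    Path: /home/user/agents/research_agent
    Goal: Deep research
    Description: ...
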
@@ -151,6 +159,11 @@ def register_queen_lifecycle_tools(

    # --- start_worker ---------------------------------------------------------

+    # How long to wait for credential validation + MCP resync before
+    # proceeding with trigger anyway. These are pre-flight checks that
+    # should not block the queen indefinitely.
+    _START_PREFLIGHT_TIMEOUT = 15  # seconds
+
    async def start_worker(task: str) -> str:
        """Start the worker agent with a task description.

@@ -162,6 +175,51 @@ def register_queen_lifecycle_tools(
            return json.dumps({"error": "No worker loaded in this session."})

        try:
+            # Pre-flight: validate credentials and resync MCP servers.
+            # Both are blocking I/O (HTTP health-checks, subprocess spawns)
+            # so they run in a thread-pool executor. We cap the total
+            # preflight time so the queen never hangs waiting.
+            loop = asyncio.get_running_loop()
+
+            async def _preflight():
+                cred_error: CredentialError | None = None
+                try:
+                    await loop.run_in_executor(
+                        None,
+                        lambda: validate_credentials(
+                            runtime.graph.nodes,
+                            interactive=False,
+                            skip=False,
+                        ),
+                    )
+                except CredentialError as e:
+                    cred_error = e
+
+                runner = getattr(session, "runner", None)
+                if runner:
+                    try:
+                        await loop.run_in_executor(
+                            None,
+                            lambda: runner._tool_registry.resync_mcp_servers_if_needed(),
+                        )
+                    except Exception as e:
+                        logger.warning("MCP resync failed: %s", e)
+
+                # Re-raise CredentialError after MCP resync so both steps
+                # get a chance to run before we bail.
+                if cred_error is not None:
+                    raise cred_error
+
+            try:
+                await asyncio.wait_for(_preflight(), timeout=_START_PREFLIGHT_TIMEOUT)
+            except TimeoutError:
+                logger.warning(
+                    "start_worker preflight timed out after %ds — proceeding with trigger",
+                    _START_PREFLIGHT_TIMEOUT,
+                )
+            except CredentialError:
+                raise  # handled below
+
            # Resume timers in case they were paused by a previous stop_worker
            runtime.resume_timers()

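The shape of that preflight is a reusable asyncio pattern: wrap blocking calls in run_in_executor, then bound the whole thing with wait_for. In isolation, as a runnable sketch with stand-in names:

    # Sketch of the capped-preflight pattern
    import asyncio, time

    async def main():
        loop = asyncio.get_running_loop()

        async def preflight():
            # blocking check runs on the default thread pool
            await loop.run_in_executor(None, lambda: time.sleep(1))

        try:
            await asyncio.wait_for(preflight(), timeout=15)
        except TimeoutError:  # asyncio.TimeoutError is an alias of this since Python 3.11
            print("preflight timed out; proceeding anyway")

    asyncio.run(main())
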
@@ -185,6 +243,23 @@ def register_queen_lifecycle_tools(
                    "task": task,
                }
            )
+        except CredentialError as e:
+            # Build structured error with per-credential details so the
+            # queen can report exactly what's missing and how to fix it.
+            error_payload = credential_errors_to_json(e)
+            error_payload["agent_path"] = str(getattr(session, "worker_path", "") or "")
+
+            # Emit SSE event so the frontend opens the credentials modal
+            bus = getattr(session, "event_bus", None)
+            if bus is not None:
+                await bus.publish(
+                    AgentEvent(
+                        type=EventType.CREDENTIALS_REQUIRED,
+                        stream_id="queen",
+                        data=error_payload,
+                    )
+                )
+            return json.dumps(error_payload)
        except Exception as e:
            return json.dumps({"error": f"Failed to start worker: {e}"})

@@ -211,30 +286,40 @@ def register_queen_lifecycle_tools(
    # --- stop_worker ----------------------------------------------------------

    async def stop_worker() -> str:
-        """Cancel all active worker executions.
+        """Cancel all active worker executions across all graphs.

-        Stops the worker gracefully. Returns the IDs of cancelled executions.
+        Stops the worker immediately. Returns the IDs of cancelled executions.
        """
        runtime = _get_runtime()
        if runtime is None:
            return json.dumps({"error": "No worker loaded in this session."})

        cancelled = []
-        graph_id = runtime.graph_id
-
-        # Get the primary graph's streams
-        reg = runtime.get_graph_registration(graph_id)
-        if reg is None:
-            return json.dumps({"error": "Worker graph not found"})
-
-        for _ep_id, stream in reg.streams.items():
-            for exec_id in list(stream.active_execution_ids):
-                try:
-                    ok = await stream.cancel_execution(exec_id)
-                    if ok:
-                        cancelled.append(exec_id)
-                except Exception as e:
-                    logger.warning("Failed to cancel %s: %s", exec_id, e)
+        # Iterate ALL registered graphs — multiple entrypoint requests
+        # can spawn executions in different graphs within the same session.
+        for graph_id in runtime.list_graphs():
+            reg = runtime.get_graph_registration(graph_id)
+            if reg is None:
+                continue
+
+            for _ep_id, stream in reg.streams.items():
+                # Signal shutdown on all active EventLoopNodes first so they
+                # exit cleanly and cancel their in-flight LLM streams.
+                for executor in stream._active_executors.values():
+                    for node in executor.node_registry.values():
+                        if hasattr(node, "signal_shutdown"):
+                            node.signal_shutdown()
+                        if hasattr(node, "cancel_current_turn"):
+                            node.cancel_current_turn()
+
+                for exec_id in list(stream.active_execution_ids):
+                    try:
+                        ok = await stream.cancel_execution(exec_id)
+                        if ok:
+                            cancelled.append(exec_id)
+                    except Exception as e:
+                        logger.warning("Failed to cancel %s: %s", exec_id, e)

        # Pause timers so the next tick doesn't restart execution
        runtime.pause_timers()

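The shutdown signalling is duck-typed: nodes are probed with hasattr rather than an interface check, so non-loop nodes are skipped silently. The capability-probe idiom, reduced to essentials:

    # Sketch: duck-typed capability probe used by stop_worker and handle_stop
    class LoopNode:
        def signal_shutdown(self) -> None:
            print("shutting down")

    class PlainNode:
        pass

    for node in (LoopNode(), PlainNode()):
        if hasattr(node, "signal_shutdown"):
            node.signal_shutdown()  # only LoopNode responds
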
@@ -260,11 +345,46 @@ def register_queen_lifecycle_tools(

    # --- get_worker_status ----------------------------------------------------

-    async def get_worker_status() -> str:
-        """Check if the worker is idle, running, or waiting for user input.
-
-        Returns worker identity, execution state, active node, and iteration count.
-        """
+    def _get_event_bus():
+        """Get the session's event bus for querying history."""
+        return getattr(session, "event_bus", None)
+
+    _status_last_called: dict[str, float] = {}  # {"ts": monotonic time}
+    _STATUS_COOLDOWN = 30.0  # seconds between full status checks
+
+    async def get_worker_status(last_n: int = 20) -> str:
+        """Comprehensive worker status: state, execution details, and recent activity.
+
+        Returns everything the queen needs in a single call:
+        - Identity and high-level state (idle / running / waiting_for_input)
+        - Active execution details (elapsed time, current node, iteration)
+        - Running tool calls (started but not yet completed)
+        - Recent completed tool calls (name, success/error)
+        - Node transitions (execution path)
+        - Retries, stalls, and constraint violations
+        - Goal progress and token consumption
+
+        Args:
+            last_n: Number of recent events to include per category (default 20).
+        """
+        import time as _time
+
+        now = _time.monotonic()
+        last = _status_last_called.get("ts", 0.0)
+        if now - last < _STATUS_COOLDOWN:
+            remaining = int(_STATUS_COOLDOWN - (now - last))
+            return json.dumps(
+                {
+                    "status": "cooldown",
+                    "message": (
+                        f"Status was checked {int(now - last)}s ago. "
+                        f"Wait {remaining}s before checking again. "
+                        "Do NOT call this tool in a loop — wait for user input instead."
+                    ),
+                }
+            )
+        _status_last_called["ts"] = now
+
        runtime = _get_runtime()
        if runtime is None:
            return json.dumps({"status": "not_loaded", "message": "No worker loaded."})
@@ -275,55 +395,235 @@ def register_queen_lifecycle_tools(

        if reg is None:
            return json.dumps({"status": "not_loaded"})

-        base = {
+        result: dict[str, Any] = {
            "worker_graph_id": graph_id,
            "worker_goal": getattr(goal, "name", graph_id),
        }

+        # --- Execution state ---
        active_execs = []
        for ep_id, stream in reg.streams.items():
            for exec_id in stream.active_execution_ids:
-                active_execs.append(
-                    {
-                        "execution_id": exec_id,
-                        "entry_point": ep_id,
-                    }
-                )
+                exec_info: dict[str, Any] = {
+                    "execution_id": exec_id,
+                    "entry_point": ep_id,
+                }
+                ctx = stream.get_context(exec_id)
+                if ctx:
+                    from datetime import datetime
+
+                    elapsed = (datetime.now() - ctx.started_at).total_seconds()
+                    exec_info["elapsed_seconds"] = round(elapsed, 1)
+                    exec_info["exec_status"] = ctx.status
+                active_execs.append(exec_info)

        if not active_execs:
-            return json.dumps(
-                {
-                    **base,
-                    "status": "idle",
-                    "message": "Worker has no active executions.",
-                }
-            )
+            result["status"] = "idle"
+            result["message"] = "Worker has no active executions."
+        else:
+            waiting_nodes = []
+            for _ep_id, stream in reg.streams.items():
+                waiting_nodes.extend(stream.get_waiting_nodes())
+
+            result["status"] = "waiting_for_input" if waiting_nodes else "running"
+            result["active_executions"] = active_execs
+            if waiting_nodes:
+                result["waiting_node_id"] = waiting_nodes[0]["node_id"]
+
+        result["agent_idle_seconds"] = round(runtime.agent_idle_seconds, 1)
+
+        # --- EventBus enrichment ---
+        bus = _get_event_bus()
+        if not bus:
+            return json.dumps(result)
+
+        try:
+            # Pending user question (from ask_user tool)
+            if result.get("status") == "waiting_for_input":
+                input_events = bus.get_history(event_type=EventType.CLIENT_INPUT_REQUESTED, limit=1)
+                if input_events:
+                    prompt = input_events[0].data.get("prompt", "")
+                    if prompt:
+                        result["pending_question"] = prompt
+            # Current node
+            edge_events = bus.get_history(event_type=EventType.EDGE_TRAVERSED, limit=1)
+            if edge_events:
+                target = edge_events[0].data.get("target_node")
+                if target:
+                    result["current_node"] = target
+
+            # Current iteration
+            iter_events = bus.get_history(event_type=EventType.NODE_LOOP_ITERATION, limit=1)
+            if iter_events:
+                result["current_iteration"] = iter_events[0].data.get("iteration")
+
+            # Running tool calls (started but not yet completed)
+            tool_started = bus.get_history(event_type=EventType.TOOL_CALL_STARTED, limit=last_n * 2)
+            tool_completed = bus.get_history(
+                event_type=EventType.TOOL_CALL_COMPLETED, limit=last_n * 2
+            )
+            completed_ids = {
+                evt.data.get("tool_use_id") for evt in tool_completed if evt.data.get("tool_use_id")
+            }
+            running = [
+                evt
+                for evt in tool_started
+                if evt.data.get("tool_use_id") and evt.data.get("tool_use_id") not in completed_ids
+            ]
+            if running:
+                result["running_tools"] = [
+                    {
+                        "tool": evt.data.get("tool_name"),
+                        "node": evt.node_id,
+                        "started_at": evt.timestamp.isoformat(),
+                        "input_preview": str(evt.data.get("tool_input", ""))[:200],
+                    }
+                    for evt in running
+                ]

-        # Check if the worker is waiting for user input
-        waiting_nodes = []
-        for _ep_id, stream in reg.streams.items():
-            waiting_nodes.extend(stream.get_waiting_nodes())
-
-        status = "waiting_for_input" if waiting_nodes else "running"
-        result = {
-            **base,
-            "status": status,
-            "active_executions": active_execs,
-        }
-        if waiting_nodes:
-            result["waiting_node_id"] = waiting_nodes[0]["node_id"]
-        return json.dumps(result)
+            # Recent completed tool calls
+            if tool_completed:
+                result["recent_tool_calls"] = [
+                    {
+                        "tool": evt.data.get("tool_name"),
+                        "error": bool(evt.data.get("is_error")),
+                        "node": evt.node_id,
+                        "time": evt.timestamp.isoformat(),
+                    }
+                    for evt in tool_completed[:last_n]
+                ]
+
+            # Node transitions
+            edges = bus.get_history(event_type=EventType.EDGE_TRAVERSED, limit=last_n)
+            if edges:
+                result["node_transitions"] = [
+                    {
+                        "from": evt.data.get("source_node"),
+                        "to": evt.data.get("target_node"),
+                        "condition": evt.data.get("edge_condition"),
+                        "time": evt.timestamp.isoformat(),
+                    }
+                    for evt in edges
+                ]
+
+            # Retries
+            retries = bus.get_history(event_type=EventType.NODE_RETRY, limit=last_n)
+            if retries:
+                result["retries"] = [
+                    {
+                        "node": evt.node_id,
+                        "retry_count": evt.data.get("retry_count"),
+                        "error": evt.data.get("error", "")[:200],
+                        "time": evt.timestamp.isoformat(),
+                    }
+                    for evt in retries
+                ]
+
+            # Stalls and doom loops
+            stalls = bus.get_history(event_type=EventType.NODE_STALLED, limit=5)
+            doom_loops = bus.get_history(event_type=EventType.NODE_TOOL_DOOM_LOOP, limit=5)
+            issues = []
+            for evt in stalls:
+                issues.append(
+                    {
+                        "type": "stall",
+                        "node": evt.node_id,
+                        "reason": evt.data.get("reason", "")[:200],
+                        "time": evt.timestamp.isoformat(),
+                    }
+                )
+            for evt in doom_loops:
+                issues.append(
+                    {
+                        "type": "tool_doom_loop",
+                        "node": evt.node_id,
+                        "description": evt.data.get("description", "")[:200],
+                        "time": evt.timestamp.isoformat(),
+                    }
+                )
+            if issues:
+                result["issues"] = issues

+            # Constraint violations
+            violations = bus.get_history(event_type=EventType.CONSTRAINT_VIOLATION, limit=5)
+            if violations:
+                result["constraint_violations"] = [
+                    {
+                        "constraint": evt.data.get("constraint_id"),
+                        "description": evt.data.get("description", "")[:200],
+                        "time": evt.timestamp.isoformat(),
+                    }
+                    for evt in violations
+                ]
+
+            # Goal progress
+            try:
+                progress = await runtime.get_goal_progress()
+                if progress:
+                    result["goal_progress"] = progress
+            except Exception:
+                pass
+
+            # Token summary
+            llm_events = bus.get_history(event_type=EventType.LLM_TURN_COMPLETE, limit=200)
+            if llm_events:
+                total_in = sum(evt.data.get("input_tokens", 0) or 0 for evt in llm_events)
+                total_out = sum(evt.data.get("output_tokens", 0) or 0 for evt in llm_events)
+                result["token_summary"] = {
+                    "llm_turns": len(llm_events),
+                    "input_tokens": total_in,
+                    "output_tokens": total_out,
+                    "total_tokens": total_in + total_out,
+                }
+
+            # Execution completions/failures
+            exec_completed = bus.get_history(event_type=EventType.EXECUTION_COMPLETED, limit=5)
+            exec_failed = bus.get_history(event_type=EventType.EXECUTION_FAILED, limit=5)
+            if exec_completed or exec_failed:
+                result["execution_outcomes"] = []
+                for evt in exec_completed:
+                    result["execution_outcomes"].append(
+                        {
+                            "outcome": "completed",
+                            "execution_id": evt.execution_id,
+                            "time": evt.timestamp.isoformat(),
+                        }
+                    )
+                for evt in exec_failed:
+                    result["execution_outcomes"].append(
+                        {
+                            "outcome": "failed",
+                            "execution_id": evt.execution_id,
+                            "error": evt.data.get("error", "")[:200],
+                            "time": evt.timestamp.isoformat(),
+                        }
+                    )
+        except Exception:
+            pass  # Non-critical enrichment
+
+        return json.dumps(result, default=str, ensure_ascii=False)

    _status_tool = Tool(
        name="get_worker_status",
        description=(
-            "Check the worker agent's current state: idle (no execution), "
-            "running (actively processing), or waiting_for_input (blocked on "
-            "user response). Returns execution details."
+            "Get comprehensive worker status: state (idle/running/waiting_for_input), "
+            "execution details (elapsed time, current node, iteration), "
+            "recent tool calls, running tools, node transitions, retries, "
+            "stalls, constraint violations, goal progress, and token consumption. "
+            "One call gives the queen a complete picture."
        ),
-        parameters={"type": "object", "properties": {}},
+        parameters={
+            "type": "object",
+            "properties": {
+                "last_n": {
+                    "type": "integer",
+                    "description": "Number of recent events per category (default 20)",
+                },
+            },
+            "required": [],
+        },
    )
-    registry.register("get_worker_status", _status_tool, lambda inputs: get_worker_status())
+    registry.register("get_worker_status", _status_tool, lambda inputs: get_worker_status(**inputs))
    tools_registered += 1

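A sketch of what the queen sees when it calls the tool twice in quick succession (field values illustrative; the second call short-circuits on the 30-second cooldown instead of rescanning):

    # First call: full status payload
    {"worker_graph_id": "...", "status": "running", "active_executions": [...], ...}

    # Second call within 30s: cooldown notice
    {"status": "cooldown", "message": "Status was checked 5s ago. Wait 25s before checking again. ..."}
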
    # --- inject_worker_message ------------------------------------------------
@@ -387,6 +687,105 @@ def register_queen_lifecycle_tools(
    )
    tools_registered += 1

+    # --- list_credentials -----------------------------------------------------
+
+    async def list_credentials(credential_id: str = "") -> str:
+        """List all authorized credentials (Aden OAuth + local encrypted store).
+
+        Returns credential IDs, aliases, status, and identity metadata.
+        Never returns secret values. Optionally filter by credential_id.
+        """
+        try:
+            # Primary: CredentialStoreAdapter sees both Aden OAuth and local accounts
+            from aden_tools.credentials import CredentialStoreAdapter
+
+            store = CredentialStoreAdapter.default()
+            all_accounts = store.get_all_account_info()
+
+            # Filter by credential_id / provider if requested
+            if credential_id:
+                all_accounts = [
+                    a
+                    for a in all_accounts
+                    if a.get("credential_id", "").startswith(credential_id)
+                    or a.get("provider", "") == credential_id
+                ]
+
+            return json.dumps(
+                {
+                    "count": len(all_accounts),
+                    "credentials": all_accounts,
+                },
+                default=str,
+            )
+        except ImportError:
+            pass
+        except Exception as e:
+            return json.dumps({"error": f"Failed to list credentials: {e}"})
+
+        # Fallback: local encrypted store only
+        try:
+            from framework.credentials.local.registry import LocalCredentialRegistry
+
+            registry = LocalCredentialRegistry.default()
+            accounts = registry.list_accounts(
+                credential_id=credential_id or None,
+            )
+
+            credentials = []
+            for info in accounts:
+                entry: dict[str, Any] = {
+                    "credential_id": info.credential_id,
+                    "alias": info.alias,
+                    "storage_id": info.storage_id,
+                    "status": info.status,
+                    "created_at": info.created_at.isoformat() if info.created_at else None,
+                    "last_validated": (
+                        info.last_validated.isoformat() if info.last_validated else None
+                    ),
+                }
+                identity = info.identity.to_dict()
+                if identity:
+                    entry["identity"] = identity
+                credentials.append(entry)
+
+            return json.dumps(
+                {
+                    "count": len(credentials),
+                    "credentials": credentials,
+                    "location": "~/.hive/credentials",
+                },
+                default=str,
+            )
+        except Exception as e:
+            return json.dumps({"error": f"Failed to list credentials: {e}"})
+
+    _list_creds_tool = Tool(
+        name="list_credentials",
+        description=(
+            "List all authorized credentials in the local store. Returns credential IDs, "
+            "aliases, status (active/failed/unknown), and identity metadata — never secret "
+            "values. Optionally filter by credential_id (e.g. 'brave_search')."
+        ),
+        parameters={
+            "type": "object",
+            "properties": {
+                "credential_id": {
+                    "type": "string",
+                    "description": (
+                        "Filter to a specific credential type (e.g. 'brave_search'). "
+                        "Omit to list all credentials."
+                    ),
+                },
+            },
+            "required": [],
+        },
+    )
+    registry.register(
+        "list_credentials", _list_creds_tool, lambda inputs: list_credentials(**inputs)
+    )
+    tools_registered += 1

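One wrinkle worth noting: inside the fallback branch, registry = LocalCredentialRegistry.default() rebinds the name locally and shadows the enclosing ToolRegistry argument. That is harmless here because the tool registration below runs in the outer scope. A hedged sketch of invoking the handler directly:

    # Hypothetical direct invocation of the registered handler
    import asyncio

    payload = asyncio.run(list_credentials(credential_id="brave_search"))
    # payload is a JSON string: {"count": N, "credentials": [...]} with no secret values
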
    # --- load_built_agent (server context only) --------------------------------

    if session_manager is not None and manager_session_id is not None:
@@ -400,16 +799,18 @@ def register_queen_lifecycle_tools(
            """
            runtime = _get_runtime()
            if runtime is not None:
-                return json.dumps(
-                    {
-                        "error": "A worker is already loaded in this session. "
-                        "Unload it first or open a new tab."
-                    }
-                )
+                try:
+                    await session_manager.unload_worker(manager_session_id)
+                except Exception as e:
+                    logger.error("Failed to unload existing worker: %s", e, exc_info=True)
+                    return json.dumps({"error": f"Failed to unload existing worker: {e}"})

-            resolved_path = Path(agent_path).resolve()
+            try:
+                resolved_path = validate_agent_path(agent_path)
+            except ValueError as e:
+                return json.dumps({"error": str(e)})
            if not resolved_path.exists():
-                return json.dumps({"error": f"Agent path does not exist: {resolved_path}"})
+                return json.dumps({"error": f"Agent path does not exist: {agent_path}"})

            try:
                updated_session = await session_manager.load_worker(

@@ -18,7 +18,6 @@ from __future__ import annotations

import json
import logging
-from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
@@ -48,10 +47,14 @@ def register_graph_tools(registry: ToolRegistry, runtime: AgentRuntime) -> int:
    """
    from framework.runner.runner import AgentRunner
    from framework.runtime.execution_stream import EntryPointSpec
+    from framework.server.app import validate_agent_path

-    path = Path(agent_path).resolve()
+    try:
+        path = validate_agent_path(agent_path)
+    except ValueError as e:
+        return json.dumps({"error": str(e)})
    if not path.exists():
-        return json.dumps({"error": f"Agent path does not exist: {path}"})
+        return json.dumps({"error": f"Agent path does not exist: {agent_path}"})

    try:
        runner = AgentRunner.load(path)

@@ -473,14 +473,12 @@ class AdenTUI(App):
        from pathlib import Path

        from framework.graph.executor import GraphExecutor
-        from framework.monitoring import judge_goal, judge_graph
        from framework.runner.tool_registry import ToolRegistry
-        from framework.runtime.core import Runtime
        from framework.runtime.event_bus import EventType as _ET
        from framework.tools.queen_lifecycle_tools import register_queen_lifecycle_tools
        from framework.tools.worker_monitoring_tools import register_worker_monitoring_tools

-        log = logging.getLogger("tui.judge")
+        log = logging.getLogger("tui.queen")

        try:
            storage_path = Path(storage_path)
@@ -502,64 +500,16 @@ class AdenTUI(App):
            worker_graph_id=self.runtime._graph_id,
        )

-        # 2. Storage dirs — global, not per-agent. Queen and judge are
-        # supervisory components that outlive any single worker.
+        # 2. Storage dirs — global, not per-agent.
        hive_home = Path.home() / ".hive"
-        judge_dir = hive_home / "judge" / "session" / session_id
-        judge_dir.mkdir(parents=True, exist_ok=True)
        queen_dir = hive_home / "queen" / "session" / session_id
        queen_dir.mkdir(parents=True, exist_ok=True)

-        # ---------------------------------------------------------------
-        # 3. Health judge — background task, fires every 2 minutes.
-        # ---------------------------------------------------------------
-        judge_runtime = Runtime(hive_home / "judge")
        monitoring_tools = list(monitoring_registry.get_tools().values())
        monitoring_executor = monitoring_registry.get_executor()

-        # Scoped event buses — stamp graph_id on every event so
-        # downstream routing (queen-primary mode) can distinguish
-        # queen/judge/worker events.
+        # Health judge disabled for simplicity.
        from framework.runtime.execution_stream import GraphScopedEventBus

-        judge_event_bus = GraphScopedEventBus(event_bus, "judge")
        queen_event_bus = GraphScopedEventBus(event_bus, "queen")

-        async def _judge_loop():
-            interval = 120  # seconds
-            first = True
-            while True:
-                if not first:
-                    await asyncio.sleep(interval)
-                first = False
-                try:
-                    executor = GraphExecutor(
-                        runtime=judge_runtime,
-                        llm=llm,
-                        tools=monitoring_tools,
-                        tool_executor=monitoring_executor,
-                        event_bus=judge_event_bus,
-                        stream_id="judge",
-                        storage_path=judge_dir,
-                        loop_config=judge_graph.loop_config,
-                    )
-                    await executor.execute(
-                        graph=judge_graph,
-                        goal=judge_goal,
-                        input_data={
-                            "event": {"source": "timer", "reason": "scheduled"},
-                        },
-                        session_state={"resume_session_id": session_id},
-                    )
-                except Exception:
-                    log.error("Health judge tick failed", exc_info=True)
-
-        self._judge_task = asyncio.run_coroutine_threadsafe(
-            _judge_loop(),
-            agent_loop,
-        )
-        self._judge_graph_id = "judge"

        # ---------------------------------------------------------------
        # 4. Queen — persistent interactive conversation.
        #    Runs a continuous event_loop node that is the user's
@@ -606,7 +556,10 @@ class AdenTUI(App):
        # Build worker profile for queen's system prompt.
        from framework.tools.queen_lifecycle_tools import build_worker_profile

-        worker_identity = build_worker_profile(self.runtime)
+        worker_identity = build_worker_profile(
+            self.runtime,
+            agent_path=self._runner.agent_path if self._runner else None,
+        )

        # Adjust queen graph: filter tools to what's registered and
        # append worker identity to the system prompt.
@@ -687,31 +640,8 @@ class AdenTUI(App):

        self.chat_repl._queen_inject_callback = _inject_queen

-        # Judge escalation → inject into queen conversation as a message.
-        async def _on_escalation(event):
-            ticket = event.data.get("ticket", {})
-            executor = self._queen_executor
-            if executor is None:
-                log.warning("Escalation received but queen executor is None")
-                return
-            node = executor.node_registry.get("queen")
-            if node is not None and hasattr(node, "inject_event"):
-                import json as _json
-
-                msg = "[ESCALATION TICKET from Health Judge]\n" + _json.dumps(
-                    ticket, indent=2, ensure_ascii=False
-                )
-                await node.inject_event(msg)
-            else:
-                log.warning("Escalation received but queen node not ready for injection")
-
-        self._queen_escalation_sub = event_bus.subscribe(
-            event_types=[_ET.WORKER_ESCALATION_TICKET],
-            handler=_on_escalation,
-        )

        self.notify(
-            "Queen + health judge active",
+            "Queen active",
            severity="information",
            timeout=3,
        )

@@ -53,7 +53,7 @@ def _get_last_active(agent_name: str) -> str | None:
        if not state_file.exists():
            continue
        try:
-            data = json.loads(state_file.read_text())
+            data = json.loads(state_file.read_text(encoding="utf-8"))
            ts = data.get("timestamps", {}).get("updated_at")
            if ts and (latest is None or ts > latest):
                latest = ts
@@ -84,7 +84,7 @@ def _extract_agent_stats(agent_path: Path) -> tuple[int, int, list[str]]:
    agent_py = agent_path / "agent.py"
    if agent_py.exists():
        try:
-            tree = ast.parse(agent_py.read_text())
+            tree = ast.parse(agent_py.read_text(encoding="utf-8"))
            for node in ast.walk(tree):
                # Find `nodes = [...]` assignment
                if isinstance(node, ast.Assign):
@@ -99,7 +99,7 @@ def _extract_agent_stats(agent_path: Path) -> tuple[int, int, list[str]]:
    agent_json = agent_path / "agent.json"
    if agent_json.exists():
        try:
-            data = json.loads(agent_json.read_text())
+            data = json.loads(agent_json.read_text(encoding="utf-8"))
            json_nodes = data.get("nodes", [])
            if node_count == 0:
                node_count = len(json_nodes)
@@ -150,7 +150,7 @@ def discover_agents() -> dict[str, list[AgentEntry]]:
        agent_json = path / "agent.json"
        if agent_json.exists():
            try:
-                data = json.loads(agent_json.read_text())
+                data = json.loads(agent_json.read_text(encoding="utf-8"))
                meta = data.get("agent", {})
                name = meta.get("name", name)
                desc = meta.get("description", desc)

@@ -160,20 +160,9 @@ class CredentialSetupScreen(ModalScreen[bool | None]):
        aden_input = self.query_one("#key-aden", Input)
        aden_key = aden_input.value.strip()
        if aden_key:
            os.environ["ADEN_API_KEY"] = aden_key
-            # Persist to shell config
-            try:
-                from aden_tools.credentials.shell_config import (
-                    add_env_var_to_shell_config,
-                )
+            from framework.credentials.key_storage import save_aden_api_key

-                add_env_var_to_shell_config(
-                    "ADEN_API_KEY",
-                    aden_key,
-                    comment="Aden Platform API key",
-                )
-            except Exception:
-                pass
+            save_aden_api_key(aden_key)
            configured += 1  # ADEN_API_KEY itself counts as configured

        # Run Aden sync for all Aden-backed creds (best-effort)

@@ -1460,10 +1460,6 @@ class ChatRepl(Vertical):
            indicator.update("Preparing question...")
            return

-        if tool_name == "escalate_to_coder":
-            indicator.update("Escalating to coder...")
-            return
-
        # Update indicator to show tool activity
        indicator.update(f"Using tool: {tool_name}...")

@@ -1475,7 +1471,7 @@ class ChatRepl(Vertical):

    def handle_tool_completed(self, tool_name: str, result: str, is_error: bool) -> None:
        """Handle a tool call completing."""
-        if tool_name in ("ask_user", "escalate_to_coder"):
+        if tool_name == "ask_user":
            return

        result_str = str(result)

@@ -25,6 +25,7 @@ EVENT_FORMAT: dict[EventType, tuple[str, str]] = {
    EventType.NODE_LOOP_STARTED: ("@@", "cyan"),
    EventType.NODE_LOOP_ITERATION: ("..", "dim"),
    EventType.NODE_LOOP_COMPLETED: ("@@", "dim"),
+   EventType.LLM_TURN_COMPLETE: ("◆", "green"),
    EventType.NODE_STALLED: ("!!", "bold yellow"),
    EventType.NODE_INPUT_BLOCKED: ("!!", "yellow"),
    EventType.GOAL_PROGRESS: ("%%", "blue"),
@@ -87,6 +88,12 @@ def extract_event_text(event: AgentEvent) -> str:
        return f"State changed: {data.get('key', 'unknown')}"
    elif et == EventType.CLIENT_INPUT_REQUESTED:
        return "Waiting for user input"
+    elif et == EventType.LLM_TURN_COMPLETE:
+        stop = data.get("stop_reason", "?")
+        model = data.get("model", "?")
+        inp = data.get("input_tokens", 0)
+        out = data.get("output_tokens", 0)
+        return f"{model} → {stop} ({inp}+{out} tokens)"
    else:
        return f"{et.value}: {data}"

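With the new branch, an LLM_TURN_COMPLETE event renders as a one-line summary in the TUI event feed. Illustrative output (model name and token counts made up):

    ◆ claude-sonnet → end_turn (1523+284 tokens)
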
@@ -22,6 +22,7 @@ export interface AgentCredentialRequirement {
  direct_api_key_supported: boolean;
  aden_supported: boolean;
  credential_key: string;
+ alternative_group: string | null;
}

export const credentialsApi = {
@@ -45,7 +46,4 @@ export const credentialsApi = {
      "/credentials/check-agent",
      { agent_path: agentPath },
    ),
-
-  saveAdenKey: (key: string) =>
-    api.post<{ saved: boolean }>("/credentials/aden-key", { key }),
};

@@ -37,6 +37,9 @@ export const executionApi = {
  chat: (sessionId: string, message: string) =>
    api.post<ChatResult>(`/sessions/${sessionId}/chat`, { message }),

+  workerInput: (sessionId: string, message: string) =>
+    api.post<ChatResult>(`/sessions/${sessionId}/worker-input`, { message }),
+
  stop: (sessionId: string, executionId: string) =>
    api.post<StopResult>(`/sessions/${sessionId}/stop`, {
      execution_id: executionId,
@@ -47,6 +50,9 @@ export const executionApi = {
      execution_id: executionId,
    }),

+  cancelQueen: (sessionId: string) =>
+    api.post<{ cancelled: boolean }>(`/sessions/${sessionId}/cancel-queen`),
+
  resume: (sessionId: string, workerSessionId: string, checkpointId?: string) =>
    api.post<ResumeResult>(`/sessions/${sessionId}/resume`, {
      session_id: workerSessionId,

@@ -26,6 +26,9 @@ export interface EntryPoint {
  name: string;
  entry_node: string;
  trigger_type: string;
+ trigger_config?: Record<string, unknown>;
+ /** Seconds until the next timer fire (only present for timer entry points). */
+ next_fire_in?: number;
}

export interface DiscoverEntry {
@@ -130,6 +133,8 @@ export interface Message {
  is_transition_marker?: boolean;
  is_client_input?: boolean;
  tool_calls?: unknown[];
+ /** Epoch seconds from file mtime — used for cross-conversation ordering */
+ created_at?: number;
  [key: string]: unknown;
}

@@ -150,6 +155,7 @@ export interface NodeSpec {
  client_facing: boolean;
  success_criteria: string | null;
  system_prompt: string;
+ sub_agents?: string[];
  // Runtime enrichment (when session_id provided)
  visit_count?: number;
  has_failures?: boolean;
@@ -178,6 +184,7 @@ export interface GraphTopology {
  nodes: NodeSpec[];
  edges: GraphEdge[];
  entry_node: string;
+ entry_points?: EntryPoint[];
}

export interface NodeCriteria {
@@ -262,7 +269,9 @@ export type EventTypeName =
  | "webhook_received"
  | "custom"
  | "escalation_requested"
- | "worker_loaded";
+ | "worker_loaded"
+ | "credentials_required"
+ | "subagent_report";

export interface AgentEvent {
  type: EventTypeName;

@@ -3,10 +3,15 @@ import { Play, Pause, Loader2, CheckCircle2 } from "lucide-react";

export type NodeStatus = "running" | "complete" | "pending" | "error" | "looping";

export type NodeType = "execution" | "trigger";

export interface GraphNode {
id: string;
label: string;
status: NodeStatus;
nodeType?: NodeType;
triggerType?: string;
triggerConfig?: Record<string, unknown>;
next?: string[];
backEdges?: string[];
iterations?: number;
@@ -25,6 +30,7 @@ interface AgentGraphProps {
onPause?: () => void;
version?: string;
runState?: RunState;
building?: boolean;
}

// --- Extracted RunButton so hover state survives parent re-renders ---
@@ -116,14 +122,30 @@ const statusColors: Record<NodeStatus, { dot: string; bg: string; border: string
},
};

function formatLabel(id: string): string {
return id
.split("-")
.map((w) => w.charAt(0).toUpperCase() + w.slice(1))
.join(" ");
// Trigger node palette — cool blue-gray, visually distinct from amber execution nodes
const triggerColors = {
bg: "hsl(210,25%,14%)",
border: "hsl(210,30%,30%)",
text: "hsl(210,30%,65%)",
icon: "hsl(210,40%,55%)",
};

const triggerIcons: Record<string, string> = {
webhook: "\u26A1", // lightning bolt
timer: "\u23F1", // stopwatch
api: "\u2192", // right arrow
event: "\u223F", // sine wave
};
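// Illustrative note (not part of the diff): a node with triggerType "timer"
// renders the stopwatch glyph, and any type missing from this map falls back
// to the lightning bolt (see the icon lookup in renderTriggerNode below).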

/** Truncate label to fit within `availablePx` at the given fontSize. */
function truncateLabel(label: string, availablePx: number, fontSize: number): string {
const avgCharW = fontSize * 0.58;
const maxChars = Math.floor(availablePx / avgCharW);
if (label.length <= maxChars) return label;
return label.slice(0, Math.max(maxChars - 1, 1)) + "\u2026";
}
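// Illustrative example (not part of the diff), using the ~0.58em average glyph
// width assumed above: truncateLabel("Summarize Quarterly Report", 80, 11.5)
// allows floor(80 / 6.67) = 11 chars, so it returns "Summarize \u2026".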

export default function AgentGraph({ nodes, title: _title, onNodeClick, onRun, onPause, version, runState: externalRunState }: AgentGraphProps) {
export default function AgentGraph({ nodes, title: _title, onNodeClick, onRun, onPause, version, runState: externalRunState, building }: AgentGraphProps) {
const [localRunState, setLocalRunState] = useState<RunState>("idle");
const runState = externalRunState ?? localRunState;
const runBtnRef = useRef<HTMLButtonElement>(null);
@@ -258,7 +280,14 @@ export default function AgentGraph({ nodes, title: _title, onNodeClick, onRun, o
<RunButton runState={runState} disabled={nodes.length === 0} onRun={handleRun} onPause={onPause ?? (() => {})} btnRef={runBtnRef} />
</div>
<div className="flex-1 flex items-center justify-center px-5">
<p className="text-xs text-muted-foreground/60 text-center italic">No pipeline configured yet.<br/>Chat with the Queen to get started.</p>
{building ? (
<div className="flex flex-col items-center gap-3">
<Loader2 className="w-6 h-6 animate-spin text-primary/60" />
<p className="text-xs text-muted-foreground/80 text-center">Building agent...</p>
</div>
) : (
<p className="text-xs text-muted-foreground/60 text-center italic">No pipeline configured yet.<br/>Chat with the Queen to get started.</p>
)}
</div>
</div>
);
@@ -380,15 +409,89 @@ export default function AgentGraph({ nodes, title: _title, onNodeClick, onRun, o
);
};

const renderTriggerNode = (node: GraphNode, i: number) => {
const pos = nodePos(i);
const icon = triggerIcons[node.triggerType || ""] || "\u26A1";
const triggerFontSize = nodeW < 140 ? 10.5 : 11.5;
const triggerAvailW = nodeW - 38;
const triggerDisplayLabel = truncateLabel(node.label, triggerAvailW, triggerFontSize);
const nextFireIn = node.triggerConfig?.next_fire_in as number | undefined;

// Format countdown for display below node
let countdownLabel: string | null = null;
if (nextFireIn != null && nextFireIn > 0) {
const h = Math.floor(nextFireIn / 3600);
const m = Math.floor((nextFireIn % 3600) / 60);
const s = Math.floor(nextFireIn % 60);
countdownLabel = h > 0
? `next in ${h}h ${String(m).padStart(2, "0")}m`
: `next in ${m}m ${String(s).padStart(2, "0")}s`;
}
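// Illustrative example (not part of the diff): nextFireIn = 5400 formats as
// "next in 1h 30m", while nextFireIn = 90 formats as "next in 1m 30s".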

return (
<g key={node.id} onClick={() => onNodeClick?.(node)} style={{ cursor: onNodeClick ? "pointer" : "default" }}>
<title>{node.label}</title>
{/* Pill-shaped background with dashed border */}
<rect
x={pos.x} y={pos.y}
width={nodeW} height={NODE_H}
rx={NODE_H / 2}
fill={triggerColors.bg}
stroke={triggerColors.border}
strokeWidth={1}
strokeDasharray="4 2"
/>

{/* Trigger type icon */}
<text
x={pos.x + 18} y={pos.y + NODE_H / 2}
fill={triggerColors.icon} fontSize={13}
textAnchor="middle" dominantBaseline="middle"
>
{icon}
</text>

{/* Label */}
<text
x={pos.x + 32} y={pos.y + NODE_H / 2}
fill={triggerColors.text}
fontSize={triggerFontSize}
fontWeight={500}
dominantBaseline="middle"
letterSpacing="0.01em"
>
{triggerDisplayLabel}
</text>

{/* Countdown label below node */}
{countdownLabel && (
<text
x={pos.x + nodeW / 2} y={pos.y + NODE_H + 13}
fill="hsl(210,30%,50%)" fontSize={9.5}
textAnchor="middle" fontStyle="italic" opacity={0.7}
>
{countdownLabel}
</text>
)}
</g>
);
};

const renderNode = (node: GraphNode, i: number) => {
if (node.nodeType === "trigger") return renderTriggerNode(node, i);

const pos = nodePos(i);
const isActive = node.status === "running" || node.status === "looping";
const isDone = node.status === "complete";
const colors = statusColors[node.status];
const clipId = `clip-label-${node.id}`;

const fontSize = nodeW < 140 ? 10.5 : 12.5;
const labelAvailW = nodeW - 38;
const displayLabel = truncateLabel(node.label, labelAvailW, fontSize);

return (
<g key={node.id} onClick={() => onNodeClick?.(node)} style={{ cursor: onNodeClick ? "pointer" : "default" }}>
<title>{node.label}</title>
{/* Ambient glow for active nodes */}
{isActive && (
<>
@@ -436,20 +539,16 @@ export default function AgentGraph({ nodes, title: _title, onNodeClick, onRun, o
</text>
)}

{/* Label -- properly capitalized, clipped for narrow nodes */}
<clipPath id={clipId}>
<rect x={pos.x + 30} y={pos.y} width={nodeW - 38} height={NODE_H} />
</clipPath>
{/* Label -- truncated with ellipsis for narrow nodes */}
<text
x={pos.x + 32} y={pos.y + NODE_H / 2}
fill={isActive ? "hsl(45,90%,85%)" : isDone ? "hsl(40,20%,75%)" : "hsl(35,10%,45%)"}
fontSize={nodeW < 140 ? 10.5 : 12.5}
fontSize={fontSize}
fontWeight={isActive ? 600 : isDone ? 500 : 400}
dominantBaseline="middle"
letterSpacing="0.01em"
clipPath={`url(#${clipId})`}
>
{formatLabel(node.id)}
{displayLabel}
</text>

{/* Status label for active nodes */}
@@ -500,18 +599,26 @@ export default function AgentGraph({ nodes, title: _title, onNodeClick, onRun, o
</div>

{/* Graph */}
<div className="flex-1 overflow-y-auto overflow-x-hidden px-3 pb-5">
<div className="flex-1 overflow-y-auto overflow-x-hidden px-3 pb-5 relative">
<svg
width={svgWidth}
height={svgHeight}
viewBox={`0 0 ${svgWidth} ${svgHeight}`}
className="select-none"
className={`select-none${building ? " opacity-30" : ""}`}
style={{ fontFamily: "'Inter', system-ui, sans-serif" }}
>
{forwardEdges.map((e, i) => renderForwardEdge(e, i))}
{backEdges.map((e, i) => renderBackEdge(e, i))}
{nodes.map((n, i) => renderNode(n, i))}
</svg>
{building && (
<div className="absolute inset-0 flex items-center justify-center">
<div className="flex flex-col items-center gap-3">
<Loader2 className="w-6 h-6 animate-spin text-primary/60" />
<p className="text-xs text-muted-foreground/80">Rebuilding agent...</p>
</div>
</div>
)}
</div>
</div>
);

@@ -1,6 +1,5 @@
import { memo, useState, useRef, useEffect } from "react";
import { Send, Crown, Cpu } from "lucide-react";
import { formatAgentDisplayName } from "@/lib/chat-helpers";
import { Send, Square, Crown, Cpu, Check, Loader2, Reply } from "lucide-react";
import MarkdownContent from "@/components/MarkdownContent";

export interface ChatMessage {
@@ -9,10 +8,12 @@ export interface ChatMessage {
agentColor: string;
content: string;
timestamp: string;
type?: "system" | "agent" | "user";
type?: "system" | "agent" | "user" | "tool_status" | "worker_input_request";
role?: "queen" | "worker";
/** Which worker thread this message belongs to (worker agent name) */
thread?: string;
/** Epoch ms when this message was first created — used for ordering queen/worker interleaving */
createdAt?: number;
}

interface ChatPanelProps {
@@ -20,17 +21,185 @@ interface ChatPanelProps {
onSend: (message: string, thread: string) => void;
isWaiting?: boolean;
activeThread: string;
/** When true, the agent is waiting for user input — changes placeholder text */
awaitingInput?: boolean;
/** When true, the worker is waiting for user input — shows inline reply box */
workerAwaitingInput?: boolean;
/** When true, the input is disabled (e.g. during loading) */
disabled?: boolean;
/** Called when user clicks the stop button to cancel the queen's current turn */
onCancel?: () => void;
/** Called when user submits a reply to the worker's input request */
onWorkerReply?: (message: string) => void;
}

const queenColor = "hsl(45,95%,58%)";
const workerColor = "hsl(220,60%,55%)";

function getColor(_agent: string, role?: "queen" | "worker"): string {
if (role === "queen") return queenColor;
return "hsl(220,60%,55%)";
return workerColor;
}

// Honey-drizzle palette — based on color-hex.com/color-palette/80116
// #8e4200 · #db6f02 · #ff9624 · #ffb825 · #ffd69c + adjacent warm tones
const TOOL_HEX = [
"#db6f02", // rich orange
"#ffb825", // golden yellow
"#ff9624", // bright orange
"#c48820", // warm bronze
"#e89530", // honey
"#d4a040", // goldenrod
"#cc7a10", // caramel
"#e5a820", // sunflower
];

function toolHex(name: string): string {
let hash = 0;
for (let i = 0; i < name.length; i++) hash = (hash * 31 + name.charCodeAt(i)) | 0;
return TOOL_HEX[Math.abs(hash) % TOOL_HEX.length];
}
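// Illustrative note (not part of the diff): the 31-based rolling hash is
// deterministic, so a given tool name (e.g. "web_search") maps to the same
// palette entry on every render; `| 0` clamps to 32 bits so long names cannot overflow.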

function ToolActivityRow({ content }: { content: string }) {
let tools: { name: string; done: boolean }[] = [];
try {
const parsed = JSON.parse(content);
tools = parsed.tools || [];
} catch {
// Legacy plain-text fallback
return (
<div className="flex gap-3 pl-10">
<span className="text-[11px] text-muted-foreground bg-muted/40 px-3 py-1 rounded-full border border-border/40">
{content}
</span>
</div>
);
}

if (tools.length === 0) return null;

// Group by tool name → count done vs running
const grouped = new Map<string, { done: number; running: number }>();
for (const t of tools) {
const entry = grouped.get(t.name) || { done: 0, running: 0 };
if (t.done) entry.done++;
else entry.running++;
grouped.set(t.name, entry);
}

// Build pill list: running first, then done
const runningPills: { name: string; count: number }[] = [];
const donePills: { name: string; count: number }[] = [];
for (const [name, counts] of grouped) {
if (counts.running > 0) runningPills.push({ name, count: counts.running });
if (counts.done > 0) donePills.push({ name, count: counts.done });
}
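// Illustrative example (not part of the diff): tools =
// [{name:"fetch",done:true},{name:"fetch",done:false},{name:"search",done:true}]
// yields runningPills = [{fetch, 1}] and donePills = [{fetch, 1}, {search, 1}].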

return (
<div className="flex gap-3 pl-10">
<div className="flex flex-wrap items-center gap-1.5">
{runningPills.map((p) => {
const hex = toolHex(p.name);
return (
<span
key={`run-${p.name}`}
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
>
<Loader2 className="w-2.5 h-2.5 animate-spin" />
{p.name}
{p.count > 1 && (
<span className="text-[10px] font-medium opacity-70">×{p.count}</span>
)}
</span>
);
})}
{donePills.map((p) => {
const hex = toolHex(p.name);
return (
<span
key={`done-${p.name}`}
className="inline-flex items-center gap-1 text-[11px] px-2.5 py-0.5 rounded-full"
style={{ color: hex, backgroundColor: `${hex}18`, border: `1px solid ${hex}35` }}
>
<Check className="w-2.5 h-2.5" />
{p.name}
{p.count > 1 && (
<span className="text-[10px] opacity-80">×{p.count}</span>
)}
</span>
);
})}
</div>
</div>
);
}

/** Inline reply box that appears below a worker's input request in the chat thread. */
function WorkerInputReply({ onSubmit, disabled }: { onSubmit: (text: string) => void; disabled?: boolean }) {
const [value, setValue] = useState("");
const [sent, setSent] = useState(false);
const inputRef = useRef<HTMLTextAreaElement>(null);

useEffect(() => {
if (!disabled && !sent) inputRef.current?.focus();
}, [disabled, sent]);

const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
if (!value.trim() || sent) return;
onSubmit(value.trim());
setSent(true);
};

if (sent) {
return (
<div className="ml-10 flex items-center gap-1.5 text-[11px] text-muted-foreground py-1">
<Check className="w-3 h-3 text-emerald-500" />
<span>Response sent</span>
</div>
);
}

return (
<form onSubmit={handleSubmit} className="ml-10 mt-1">
<div
className="flex items-center gap-2 rounded-xl px-3 py-2 border transition-colors"
style={{
backgroundColor: `${workerColor}08`,
borderColor: `${workerColor}30`,
}}
>
<Reply className="w-3.5 h-3.5 flex-shrink-0" style={{ color: workerColor }} />
<textarea
ref={inputRef}
rows={1}
value={value}
onChange={(e) => {
setValue(e.target.value);
const ta = e.target;
ta.style.height = "auto";
ta.style.height = `${Math.min(ta.scrollHeight, 120)}px`;
}}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
handleSubmit(e);
}
}}
placeholder="Reply to worker..."
disabled={disabled}
className="flex-1 bg-transparent text-sm text-foreground outline-none placeholder:text-muted-foreground disabled:opacity-50 resize-none overflow-y-auto"
/>
<button
type="submit"
disabled={!value.trim() || disabled}
className="p-1.5 rounded-lg transition-opacity disabled:opacity-30 hover:opacity-90"
style={{ backgroundColor: workerColor, color: "white" }}
>
<Send className="w-3.5 h-3.5" />
</button>
</div>
</form>
);
}

const MessageBubble = memo(function MessageBubble({ msg }: { msg: ChatMessage }) {
@@ -48,6 +217,10 @@ const MessageBubble = memo(function MessageBubble({ msg }: { msg: ChatMessage })
);
}

if (msg.type === "tool_status") {
return <ToolActivityRow content={msg.content} />;
}

if (isUser) {
return (
<div className="flex justify-end">
@@ -99,16 +272,16 @@ const MessageBubble = memo(function MessageBubble({ msg }: { msg: ChatMessage })
);
}, (prev, next) => prev.msg.id === next.msg.id && prev.msg.content === next.msg.content);

export default function ChatPanel({ messages, onSend, isWaiting, activeThread, awaitingInput, disabled }: ChatPanelProps) {
export default function ChatPanel({ messages, onSend, isWaiting, activeThread, workerAwaitingInput, disabled, onCancel, onWorkerReply }: ChatPanelProps) {
const [input, setInput] = useState("");
const [readMap, setReadMap] = useState<Record<string, number>>({});
const bottomRef = useRef<HTMLDivElement>(null);
const textareaRef = useRef<HTMLTextAreaElement>(null);

const threadMessages = messages.filter((m) => {
if (m.type === "system" && !m.thread) return false;
return m.thread === activeThread;
});
console.log('[ChatPanel] render: messages:', messages.length, 'threadMessages:', threadMessages.length, 'activeThread:', activeThread, 'threads:', [...new Set(messages.map(m => m.thread))]);

// Mark current thread as read
useEffect(() => {
@@ -122,16 +295,26 @@ export default function ChatPanel({ messages, onSend, isWaiting, activeThread, a
const lastMsg = threadMessages[threadMessages.length - 1];
useEffect(() => {
bottomRef.current?.scrollIntoView({ behavior: "smooth" });
}, [threadMessages.length, lastMsg?.content]);
}, [threadMessages.length, lastMsg?.content, workerAwaitingInput]);

const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
if (!input.trim()) return;
onSend(input.trim(), activeThread);
setInput("");
if (textareaRef.current) textareaRef.current.style.height = "auto";
};

const activeWorkerLabel = formatAgentDisplayName(activeThread);
// Find the last worker message to attach the inline reply box below.
// For explicit ask_user, this will be the worker_input_request message.
// For auto-block, this will be the last client_output_delta streamed message.
const lastWorkerMsgIdx = workerAwaitingInput
? threadMessages.reduce(
(last, m, i) =>
m.role === "worker" && m.type !== "tool_status" && m.type !== "system" ? i : last,
-1,
)
: -1;
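// Illustrative example (not part of the diff): for threadMessages of
// [user, worker, tool_status, worker], the reduce returns index 3, so the
// inline reply box renders under the final worker bubble.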

return (
<div className="flex flex-col h-full min-w-0">
@@ -142,8 +325,13 @@ export default function ChatPanel({ messages, onSend, isWaiting, activeThread, a

{/* Messages */}
<div className="flex-1 overflow-auto px-5 py-4 space-y-3">
{threadMessages.map((msg) => (
<MessageBubble key={msg.id} msg={msg} />
{threadMessages.map((msg, idx) => (
<div key={msg.id}>
<MessageBubble msg={msg} />
{idx === lastWorkerMsgIdx && onWorkerReply && (
<WorkerInputReply onSubmit={onWorkerReply} />
)}
</div>
))}

{isWaiting && (
@@ -163,29 +351,46 @@ export default function ChatPanel({ messages, onSend, isWaiting, activeThread, a
<div ref={bottomRef} />
</div>

{/* Input */}
{/* Input — always connected to Queen */}
<form onSubmit={handleSubmit} className="p-4 border-t border-border">
<div className="flex items-center gap-3 bg-muted/40 rounded-xl px-4 py-2.5 border border-border focus-within:border-primary/40 transition-colors">
<input
<textarea
ref={textareaRef}
rows={1}
value={input}
onChange={(e) => setInput(e.target.value)}
placeholder={
disabled
? "Connecting to agent..."
: awaitingInput
? "Agent is waiting for your response..."
: `Message ${activeWorkerLabel}...`
}
onChange={(e) => {
setInput(e.target.value);
const ta = e.target;
ta.style.height = "auto";
ta.style.height = `${Math.min(ta.scrollHeight, 160)}px`;
}}
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
handleSubmit(e);
}
}}
placeholder={disabled ? "Connecting to agent..." : "Message Queen Bee..."}
disabled={disabled}
className="flex-1 bg-transparent text-sm text-foreground outline-none placeholder:text-muted-foreground disabled:opacity-50 disabled:cursor-not-allowed"
className="flex-1 bg-transparent text-sm text-foreground outline-none placeholder:text-muted-foreground disabled:opacity-50 disabled:cursor-not-allowed resize-none overflow-y-auto"
/>
<button
type="submit"
disabled={!input.trim() || disabled}
className="p-2 rounded-lg bg-primary text-primary-foreground disabled:opacity-30 hover:opacity-90 transition-opacity"
>
<Send className="w-4 h-4" />
</button>
{isWaiting && onCancel ? (
<button
type="button"
onClick={onCancel}
className="p-2 rounded-lg bg-destructive text-destructive-foreground hover:opacity-90 transition-opacity"
>
<Square className="w-4 h-4" />
</button>
) : (
<button
type="submit"
disabled={!input.trim() || disabled}
className="p-2 rounded-lg bg-primary text-primary-foreground disabled:opacity-30 hover:opacity-90 transition-opacity"
>
<Send className="w-4 h-4" />
</button>
)}
</div>
</form>
</div>

@@ -1,5 +1,5 @@
import { useState, useEffect, useCallback } from "react";
import { KeyRound, Check, AlertCircle, X, Shield, Loader2, Trash2, ExternalLink } from "lucide-react";
import { useState, useEffect, useCallback, useRef } from "react";
import { KeyRound, Check, AlertCircle, X, Shield, Loader2, Trash2, ExternalLink, Pencil } from "lucide-react";
import { credentialsApi, type AgentCredentialRequirement } from "@/api/credentials";

export interface Credential {
@@ -40,6 +40,7 @@ interface CredentialRow {
adenSupported: boolean; // whether this credential uses OAuth via Aden
valid: boolean | null; // true = health check passed, false = failed, null = not checked
validationMessage: string | null;
alternativeGroup: string | null; // non-null when multiple providers can satisfy a tool
}

function requirementToRow(r: AgentCredentialRequirement): CredentialRow {
@@ -54,6 +55,7 @@ function requirementToRow(r: AgentCredentialRequirement): CredentialRow {
adenSupported: r.aden_supported,
valid: r.valid,
validationMessage: r.validation_message,
alternativeGroup: r.alternative_group ?? null,
};
}

@@ -61,6 +63,16 @@ function requirementToRow(r: AgentCredentialRequirement): CredentialRow {
// Cleared on save/delete so the next fetch picks up updated availability.
const credentialCache = new Map<string, AgentCredentialRequirement[]>();

/** Clear cached credential requirements so the next modal open fetches fresh data.
* Call with a specific path to clear one entry, or no args to clear all. */
export function clearCredentialCache(agentPath?: string) {
if (agentPath) {
credentialCache.delete(agentPath);
} else {
credentialCache.clear();
}
}

interface CredentialsModalProps {
agentType: string;
agentLabel: string;
@@ -90,9 +102,8 @@ export default function CredentialsModal({
const [inputValue, setInputValue] = useState("");
const [saving, setSaving] = useState(false);
const [deletingId, setDeletingId] = useState<string | null>(null);
const [hasAdenKey, setHasAdenKey] = useState(true); // assume true until backend says otherwise
const [adenKeyInput, setAdenKeyInput] = useState("");
const [savingAdenKey, setSavingAdenKey] = useState(false);
const pendingAdenAuth = useRef(false);
const lastFocusFetch = useRef(0);

const fetchStatus = useCallback(async () => {
setError(null);
@@ -108,8 +119,7 @@ export default function CredentialsModal({

// Real agent — ask backend what credentials it actually needs
setLoading(true);
const { required, has_aden_key } = await credentialsApi.checkAgent(agentPath);
setHasAdenKey(has_aden_key);
const { required } = await credentialsApi.checkAgent(agentPath);
credentialCache.set(agentPath, required);
setRows(required.map(requirementToRow));
} else {
@@ -125,6 +135,7 @@ export default function CredentialsModal({
adenSupported: false,
valid: null,
validationMessage: null,
alternativeGroup: null,
})));
} else {
setRows([]);
@@ -140,56 +151,73 @@ export default function CredentialsModal({
fetchStatus();
setEditingId(null);
setInputValue("");
setAdenKeyInput("");
setDeletingId(null);
}
}, [open, fetchStatus]);

const handleSaveAdenKey = async () => {
if (!adenKeyInput.trim()) return;
setSavingAdenKey(true);
try {
await credentialsApi.saveAdenKey(adenKeyInput.trim());
setAdenKeyInput("");
// Re-fetch when user returns to window (e.g. after completing OAuth on Aden).
// Uses "focus" instead of "visibilitychange" because window.open("_blank")
// doesn't reliably trigger visibilitychange — the original tab may never
// lose visibility. "focus" fires reliably when the user clicks back.
useEffect(() => {
if (!open) return;
const handleFocus = () => {
// Debounce: skip if we fetched within the last 3 seconds
const now = Date.now();
if (now - lastFocusFetch.current < 3000) return;
lastFocusFetch.current = now;
if (agentPath) credentialCache.delete(agentPath);
onCredentialChange?.();
await fetchStatus();
} catch {
setError("Failed to save Aden API Key");
} finally {
setSavingAdenKey(false);
}
};
fetchStatus();
if (pendingAdenAuth.current) {
pendingAdenAuth.current = false;
setEditingId("aden_api_key");
setInputValue("");
}
};
window.addEventListener("focus", handleFocus);
return () => window.removeEventListener("focus", handleFocus);
}, [open, agentPath, fetchStatus]);

const handleConnect = async (row: CredentialRow) => {
if (editingId === row.id) {
if (inputValue.trim()) {
// Has input — save the key
setSaving(true);
try {
await credentialsApi.save(row.id, { [row.credentialKey]: inputValue.trim() });
setEditingId(null);
setInputValue("");
if (agentPath) credentialCache.delete(agentPath);
onCredentialChange?.();
await fetchStatus();
} catch {
setError(`Failed to save ${row.name}`);
} finally {
setSaving(false);
}
return;
}
// Empty input on aden_api_key — fall through to re-open Aden
if (row.id !== "aden_api_key") return;
}

if (row.id === "aden_api_key" && row.adenSupported) {
// Aden Platform key — open Aden so user can grab key from Developers tab
window.open("https://hive.adenhq.com/", "_blank", "noopener");
pendingAdenAuth.current = true;
return;
}

if (row.adenSupported) {
// OAuth credential — redirect to Aden platform
window.open("https://hive.adenhq.com/", "_blank", "noopener");
return;
}

if (editingId === row.id) {
// Already editing — save
if (!inputValue.trim()) return;
setSaving(true);
try {
await credentialsApi.save(row.id, { [row.credentialKey]: inputValue.trim() });
setEditingId(null);
setInputValue("");
if (agentPath) credentialCache.delete(agentPath);
onCredentialChange?.();
await fetchStatus();
} catch {
setError(`Failed to save ${row.name}`);
} finally {
setSaving(false);
}
} else {
// Start editing — show inline API key input
setEditingId(row.id);
setInputValue("");
setDeletingId(null);
}
// Start editing — show inline API key input
setEditingId(row.id);
setInputValue("");
setDeletingId(null);
};

const handleDisconnect = async (row: CredentialRow) => {
@@ -210,12 +238,29 @@ export default function CredentialsModal({
if (!open) return null;

const connectedCount = rows.filter(c => c.connected).length;
const requiredCount = rows.filter(c => c.required).length;
const requiredConnected = rows.filter(c => c.required && c.connected).length;
const invalidCount = rows.filter(c => c.valid === false).length;
const missingCount = requiredCount - requiredConnected;
const allRequiredMet = requiredConnected === requiredCount && invalidCount === 0;
const needsAdenKeyInput = !hasAdenKey && rows.some(r => r.adenSupported);

// Alternative groups (e.g. send_email → resend OR google): satisfied if ANY is connected & valid
const altGroups = new Map<string, boolean>();
for (const c of rows) {
if (!c.alternativeGroup) continue;
if (!altGroups.has(c.alternativeGroup)) altGroups.set(c.alternativeGroup, false);
if (c.connected && c.valid !== false) altGroups.set(c.alternativeGroup, true);
}
const altGroupsSatisfied = altGroups.size === 0 || [...altGroups.values()].every(Boolean);
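// Illustrative example (not part of the diff): rows
// [{alternativeGroup:"send_email", connected:true, valid:true},
//  {alternativeGroup:"send_email", connected:false}] satisfy the group,
// since any one connected-and-valid provider covers it.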

// Non-alternative required credentials
const nonAltRequired = rows.filter(c => c.required && !c.alternativeGroup);
const nonAltMet = nonAltRequired.every(c => c.connected && c.valid !== false);

const allRequiredMet = nonAltMet && altGroupsSatisfied;

// For status banner counts
const nonAltMissing = nonAltRequired.filter(c => !c.connected).length;
const altGroupsMissing = [...altGroups.values()].filter(v => !v).length;
const missingCount = nonAltMissing + altGroupsMissing;

const adenPlatformConnected = rows.find(r => r.id === "aden_api_key")?.connected ?? false;

return (
<>
@@ -280,50 +325,6 @@ export default function CredentialsModal({
</div>
)}

{/* Aden API Key section */}
{!loading && needsAdenKeyInput && (
<div className="mx-5 mt-4 px-3 py-3 rounded-lg border border-amber-500/30 bg-amber-500/5">
<div className="flex items-center gap-2 mb-1">
<KeyRound className="w-3.5 h-3.5 text-amber-600" />
<span className="text-sm font-medium text-foreground">Aden API Key</span>
<span className="text-[9px] font-semibold uppercase tracking-wider px-1.5 py-0.5 rounded text-destructive/70 bg-destructive/10">
Required
</span>
</div>
<p className="text-[11px] text-muted-foreground mb-2">
Required to connect OAuth integrations below.{" "}
<a
href="https://hive.adenhq.com/"
target="_blank"
rel="noopener noreferrer"
className="text-primary hover:underline inline-flex items-center gap-0.5"
>
Get your key at hive.adenhq.com
<ExternalLink className="w-2.5 h-2.5" />
</a>
</p>
<div className="flex gap-2">
<input
type="password"
value={adenKeyInput}
onChange={(e) => setAdenKeyInput(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") handleSaveAdenKey();
}}
placeholder="Paste your ADEN_API_KEY..."
className="flex-1 px-3 py-1.5 rounded-md border border-border bg-background text-xs text-foreground placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-primary/40"
/>
<button
onClick={handleSaveAdenKey}
disabled={savingAdenKey || !adenKeyInput.trim()}
className="px-3 py-1.5 rounded-md text-xs font-medium bg-primary text-primary-foreground hover:bg-primary/90 disabled:opacity-50 disabled:cursor-not-allowed transition-colors"
>
{savingAdenKey ? <Loader2 className="w-3 h-3 animate-spin" /> : "Save"}
</button>
</div>
</div>
)}

{/* Credential list */}
{!loading && (
<div className="p-5 space-y-2">
@@ -343,13 +344,23 @@ export default function CredentialsModal({
<div className="flex items-center gap-2">
<span className="text-sm font-medium text-foreground">{row.name}</span>
{row.required && (
<span className={`text-[9px] font-semibold uppercase tracking-wider px-1.5 py-0.5 rounded ${
row.connected
? "text-emerald-600/70 bg-emerald-500/10"
: "text-destructive/70 bg-destructive/10"
}`}>
Required
</span>
row.alternativeGroup ? (
<span className={`text-[9px] font-semibold uppercase tracking-wider px-1.5 py-0.5 rounded ${
row.connected
? "text-emerald-600/70 bg-emerald-500/10"
: "text-amber-600/70 bg-amber-500/10"
}`}>
Either
</span>
) : (
<span className={`text-[9px] font-semibold uppercase tracking-wider px-1.5 py-0.5 rounded ${
row.connected
? "text-emerald-600/70 bg-emerald-500/10"
: "text-destructive/70 bg-destructive/10"
}`}>
Required
</span>
)
)}
</div>
<p className="text-[11px] text-muted-foreground mt-0.5">{row.description}</p>
@@ -375,6 +386,20 @@ export default function CredentialsModal({
Connected
</span>
)}
{(row.id === "aden_api_key" || !row.adenSupported) && (
<button
onClick={() => {
setEditingId(editingId === row.id ? null : row.id);
setInputValue("");
setDeletingId(null);
}}
disabled={saving}
className="p-1.5 rounded-md text-muted-foreground hover:text-foreground hover:bg-muted/60 transition-colors"
title="Update key"
>
<Pencil className="w-3 h-3" />
</button>
)}
<button
onClick={() => {
setDeletingId(deletingId === row.id ? null : row.id);
@@ -387,6 +412,10 @@ export default function CredentialsModal({
<Trash2 className="w-3 h-3" />
</button>
</div>
) : row.adenSupported && !adenPlatformConnected && row.id !== "aden_api_key" ? (
<span className="text-[11px] text-muted-foreground italic flex-shrink-0">
Connect Aden Platform key first
</span>
) : (
<button
onClick={() => handleConnect(row)}
@@ -435,7 +464,7 @@ export default function CredentialsModal({
)}

{/* Inline API key input */}
{editingId === row.id && (!row.connected || row.valid === false) && (
{editingId === row.id && (
<div className="mt-1.5 flex gap-2 px-3">
<input
type="password"

@@ -20,9 +20,19 @@ interface ToolCredential {
value?: string;
}

export interface SubagentReport {
subagent_id: string;
message: string;
data?: Record<string, unknown>;
timestamp: string;
status?: "running" | "complete" | "error";
}

interface NodeDetailPanelProps {
node: GraphNode | null;
nodeSpec?: NodeSpec | null;
allNodeSpecs?: NodeSpec[];
subagentReports?: SubagentReport[];
sessionId?: string;
graphId?: string;
workerSessionId?: string | null;
@@ -195,10 +205,96 @@ function SystemPromptTab({ systemPrompt }: { systemPrompt?: string }) {
);
}

function SubagentsTab() {
function SubagentStatusBadge({ status }: { status?: "running" | "complete" | "error" }) {
if (!status) return null;
if (status === "running") {
return (
<span className="ml-auto flex items-center gap-1 text-[10px] font-medium flex-shrink-0" style={{ color: "hsl(45,95%,58%)" }}>
<span className="relative flex h-1.5 w-1.5">
<span className="animate-ping absolute inline-flex h-full w-full rounded-full opacity-75" style={{ backgroundColor: "hsl(45,95%,58%)" }} />
<span className="relative inline-flex rounded-full h-1.5 w-1.5" style={{ backgroundColor: "hsl(45,95%,58%)" }} />
</span>
Running
</span>
);
}
if (status === "complete") {
return (
<span className="ml-auto flex items-center gap-1 text-[10px] font-medium flex-shrink-0" style={{ color: "hsl(43,70%,45%)" }}>
<CheckCircle2 className="w-3 h-3" />
Complete
</span>
);
}
return (
<div className="flex-1 flex items-center justify-center">
<p className="text-xs text-muted-foreground/60 italic text-center">No subagents assigned to this node.</p>
<span className="ml-auto flex items-center gap-1 text-[10px] font-medium flex-shrink-0" style={{ color: "hsl(0,65%,55%)" }}>
<AlertCircle className="w-3 h-3" />
Failed
</span>
);
}

function SubagentsTab({ subAgentIds, allNodeSpecs, subagentReports }: { subAgentIds: string[]; allNodeSpecs: NodeSpec[]; subagentReports: SubagentReport[] }) {
if (subAgentIds.length === 0) {
return (
<div className="flex-1 flex items-center justify-center">
<p className="text-xs text-muted-foreground/60 italic text-center">No subagents assigned to this node.</p>
</div>
);
}

return (
<div className="space-y-3">
<p className="text-[10px] font-medium text-muted-foreground uppercase tracking-wider mb-1">Sub-agents ({subAgentIds.length})</p>
{subAgentIds.map(saId => {
const spec = allNodeSpecs.find(n => n.id === saId);
const reports = subagentReports.filter(r => r.subagent_id === saId);
// Derive status from latest report that has a status field
const latestStatus = [...reports].reverse().find(r => r.status)?.status;
// Progress messages are reports without a status field (from report_to_parent)
const progressReports = reports.filter(r => !r.status);
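// Illustrative example (not part of the diff): reports =
// [{message:"started"}, {status:"running"}, {status:"complete"}]
// gives latestStatus = "complete" and progressReports = [{message:"started"}].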

return (
<div key={saId} className="rounded-xl border border-border/20 overflow-hidden">
<div className="p-3 bg-muted/30">
<div className="flex items-center gap-2 mb-1">
<Bot className="w-3.5 h-3.5 text-primary/70 flex-shrink-0" />
<span className="text-xs font-medium text-foreground truncate">{spec?.name || saId}</span>
<SubagentStatusBadge status={latestStatus} />
</div>
{spec?.description && (
<p className="text-[11px] text-muted-foreground leading-relaxed mt-1">{spec.description}</p>
)}
</div>

{/* Static info: tools + output keys */}
<div className="px-3 py-2 border-t border-border/15 bg-muted/15">
{spec?.tools && spec.tools.length > 0 && (
<div className="mb-1.5">
<span className="text-[10px] text-muted-foreground font-medium">Tools: </span>
<span className="text-[10px] text-foreground/70">{spec.tools.join(", ")}</span>
</div>
)}
{spec?.output_keys && spec.output_keys.length > 0 && (
<div>
<span className="text-[10px] text-muted-foreground font-medium">Outputs: </span>
<span className="text-[10px] text-foreground/70 font-mono">{spec.output_keys.join(", ")}</span>
</div>
)}
</div>

{/* Live progress reports (from report_to_parent) */}
{progressReports.length > 0 && (
<div className="px-3 py-2 border-t border-border/15 bg-background/60">
<p className="text-[10px] text-muted-foreground font-medium mb-1">Reports ({progressReports.length})</p>
{progressReports.map((r, i) => (
<div key={i} className="text-[10.5px] text-foreground/70 leading-relaxed py-0.5">{r.message}</div>
))}
</div>
)}
</div>
);
})}
</div>
);
}
@@ -213,7 +309,7 @@ const tabs: { id: Tab; label: string; Icon: React.FC<{ className?: string }> }[]
{ id: "subagents", label: "Subagents", Icon: ({ className }) => <Bot className={className} /> },
];

export default function NodeDetailPanel({ node, nodeSpec, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
export default function NodeDetailPanel({ node, nodeSpec, allNodeSpecs, subagentReports, sessionId, graphId, workerSessionId, nodeLogs, actionPlan, onClose }: NodeDetailPanelProps) {
const [activeTab, setActiveTab] = useState<Tab>("overview");
const [realTools, setRealTools] = useState<ToolInfo[] | null>(null);
const [realCriteria, setRealCriteria] = useState<NodeCriteria | null>(null);
@@ -295,7 +391,7 @@ export default function NodeDetailPanel({ node, nodeSpec, sessionId, graphId, wo

{/* Tab bar */}
<div className="flex border-b border-border/30 flex-shrink-0 px-2 pt-1 overflow-x-auto scrollbar-hide">
{tabs.map(tab => (
{tabs.filter(t => t.id !== "subagents" || (nodeSpec?.sub_agents && nodeSpec.sub_agents.length > 0)).map(tab => (
<button
key={tab.id}
onClick={() => setActiveTab(tab.id)}
@@ -397,8 +493,12 @@ export default function NodeDetailPanel({ node, nodeSpec, sessionId, graphId, wo
<SystemPromptTab systemPrompt={nodeSpec?.system_prompt} />
)}

{activeTab === "subagents" && (
<SubagentsTab />
{activeTab === "subagents" && nodeSpec?.sub_agents && (
<SubagentsTab
subAgentIds={nodeSpec.sub_agents}
allNodeSpecs={allNodeSpecs || []}
subagentReports={subagentReports || []}
/>
)}
</div>
</div>

@@ -91,7 +91,7 @@ export default function TopBar({ tabs: tabsProp, onTabClick, onCloseTab, canClos
<div className="flex items-center gap-3 min-w-0">
<button onClick={() => navigate("/")} className="flex items-center gap-2 hover:opacity-80 transition-opacity flex-shrink-0">
<Crown className="w-4 h-4 text-primary" />
<span className="text-sm font-semibold text-primary">Hive</span>
<span className="text-sm font-semibold text-primary">Open Hive</span>
</button>

{tabs.length > 0 && (

@@ -80,28 +80,39 @@ export function useMultiSSE({ sessions, onEvent }: UseMultiSSEOptions) {
const onEventRef = useRef(onEvent);
onEventRef.current = onEvent;

const sourcesRef = useRef(new Map<string, EventSource>());
// Track both the EventSource and its session ID so we can detect session changes
const sourcesRef = useRef(new Map<string, { es: EventSource; sessionId: string }>());

// Diff-based open/close — runs on every `sessions` change
useEffect(() => {
const current = sourcesRef.current;
const desired = new Set(Object.keys(sessions));

// Close connections for sessions no longer in the map
for (const [agentType, es] of current) {
if (!desired.has(agentType)) {
es.close();
// Close connections for removed agents OR changed session IDs
for (const [agentType, entry] of current) {
if (!desired.has(agentType) || sessions[agentType] !== entry.sessionId) {
console.log('[SSE] closing:', agentType, entry.sessionId, desired.has(agentType) ? '(session changed)' : '(removed)');
entry.es.close();
current.delete(agentType);
}
}
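// Illustrative example (not part of the diff): if sessions changes from
// {queen: "s1"} to {queen: "s2"}, the s1 EventSource is closed here and a
// fresh connection for s2 is opened by the loop below.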

// Open connections for newly added sessions
// Open connections for new/changed sessions
for (const [agentType, sessionId] of Object.entries(sessions)) {
if (!sessionId || current.has(agentType)) continue;

const url = `/api/sessions/${sessionId}/events`;
console.log('[SSE] opening:', agentType, sessionId);
const es = new EventSource(url);

es.onopen = () => {
console.log('[SSE] connected:', agentType, sessionId);
};

es.onerror = () => {
console.error('[SSE] error:', agentType, sessionId, 'readyState:', es.readyState);
};

es.onmessage = (e: MessageEvent) => {
try {
const event: AgentEvent = JSON.parse(e.data);
@@ -112,14 +123,14 @@ export function useMultiSSE({ sessions, onEvent }: UseMultiSSEOptions) {
}
};

current.set(agentType, es);
current.set(agentType, { es, sessionId });
}
}, [sessions]);

// Close all on unmount only
useEffect(() => {
return () => {
for (const es of sourcesRef.current.values()) es.close();
for (const entry of sourcesRef.current.values()) entry.es.close();
sourcesRef.current.clear();
};
}, []);

@@ -87,6 +87,11 @@
button {
cursor: pointer;
}

textarea {
padding: 0;
margin: 0;
}
}

* {

@@ -160,7 +160,97 @@ describe("sseEventToChatMessage", () => {
expect(result!.id).toBe("stream-exec-123-chat");
});

it("falls back to '0' when both turnId and execution_id are null", () => {
it("combines execution_id and turnId to differentiate loop iterations", () => {
const event = makeEvent({
type: "client_output_delta",
node_id: "chat",
execution_id: "exec-1",
data: { snapshot: "hello" },
});
const r1 = sseEventToChatMessage(event, "t", undefined, 1);
const r2 = sseEventToChatMessage(event, "t", undefined, 2);
expect(r1!.id).toBe("stream-exec-1-1-chat");
expect(r2!.id).toBe("stream-exec-1-2-chat");
expect(r1!.id).not.toBe(r2!.id);
});

it("same execution_id + same turnId produces same ID (streaming upsert within iteration)", () => {
const e1 = makeEvent({
type: "client_output_delta",
node_id: "chat",
execution_id: "exec-1",
data: { snapshot: "partial" },
});
const e2 = makeEvent({
type: "client_output_delta",
node_id: "chat",
execution_id: "exec-1",
data: { snapshot: "partial response" },
});
expect(sseEventToChatMessage(e1, "t", undefined, 3)!.id).toBe(
sseEventToChatMessage(e2, "t", undefined, 3)!.id,
);
});

it("uses data.iteration over turnId when present", () => {
const event = makeEvent({
type: "client_output_delta",
node_id: "queen",
execution_id: null,
data: { snapshot: "hello", iteration: 5 },
});
const result = sseEventToChatMessage(event, "t", undefined, 2);
expect(result!.id).toBe("stream-5-queen");
});

it("falls back to turnId when data.iteration is absent", () => {
const event = makeEvent({
type: "client_output_delta",
node_id: "queen",
execution_id: null,
data: { snapshot: "hello" },
});
const result = sseEventToChatMessage(event, "t", undefined, 2);
expect(result!.id).toBe("stream-2-queen");
});

it("different iterations from same node produce different message IDs", () => {
const e1 = makeEvent({
type: "client_output_delta",
node_id: "queen",
execution_id: "",
data: { snapshot: "first response", iteration: 0 },
});
const e2 = makeEvent({
type: "client_output_delta",
node_id: "queen",
execution_id: "",
data: { snapshot: "second response", iteration: 3 },
});
const r1 = sseEventToChatMessage(e1, "t");
const r2 = sseEventToChatMessage(e2, "t");
expect(r1!.id).not.toBe(r2!.id);
});

it("same iteration produces same ID for streaming upsert", () => {
const e1 = makeEvent({
type: "client_output_delta",
node_id: "queen",
execution_id: "",
data: { snapshot: "partial", iteration: 2 },
});
const e2 = makeEvent({
type: "client_output_delta",
node_id: "queen",
execution_id: "",
data: { snapshot: "partial response", iteration: 2 },
});
expect(sseEventToChatMessage(e1, "t")!.id).toBe(
sseEventToChatMessage(e2, "t")!.id,
);
});

it("uses timestamp fallback when both turnId and execution_id are null", () => {
const event = makeEvent({
type: "client_output_delta",
node_id: "chat",
@@ -168,7 +258,7 @@ describe("sseEventToChatMessage", () => {
data: { snapshot: "hello" },
});
const result = sseEventToChatMessage(event, "t");
expect(result!.id).toBe("stream-0-chat");
expect(result!.id).toMatch(/^stream-t-\d+-chat$/);
});

it("converts client_input_requested with prompt to message", () => {

@@ -37,8 +37,11 @@ export function backendMessageToChatMessage(
thread: string,
agentDisplayName?: string,
): ChatMessage {
// Use file-mtime created_at (epoch seconds → ms) for cross-conversation
// ordering; fall back to seq for backwards compatibility.
const createdAt = msg.created_at ? msg.created_at * 1000 : msg.seq;
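// Illustrative example (not part of the diff): created_at = 1700000000 (seconds)
// becomes createdAt = 1700000000000 ms; a legacy message without created_at
// falls back to ordering by its seq number.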
return {
id: `backend-${msg.seq}`,
id: `backend-${msg._node_id}-${msg.seq}`,
agent: msg.role === "user" ? "You" : agentDisplayName || msg._node_id || "Agent",
agentColor: "",
content: msg.content,
@@ -46,6 +49,7 @@ export function backendMessageToChatMessage(
type: msg.role === "user" ? "user" : undefined,
role: msg.role === "user" ? undefined : "worker",
thread,
createdAt,
};
}

@@ -61,40 +65,41 @@ export function sseEventToChatMessage(
agentDisplayName?: string,
turnId?: number,
): ChatMessage | null {
// turnId disambiguates messages across response turns. Within a single
// turn the ID stays stable so the upsert logic can replace the previous
// snapshot (streaming). Across turns, different turnIds produce different
// IDs so each response gets its own bubble.
const idKey = turnId != null ? String(turnId) : (event.execution_id ?? "0");
// Combine execution_id (unique per execution) with turnId (increments per
// loop iteration) so each iteration gets its own bubble while streaming
// deltas within one iteration still share the same ID for upsert.
const eid = event.execution_id ?? "";
const tid = turnId != null ? String(turnId) : "";
const idKey = eid && tid ? `${eid}-${tid}` : eid || tid || `t-${Date.now()}`;
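// Illustrative example (not part of the diff): eid "exec-1" with tid "2"
// yields "exec-1-2"; eid alone yields "exec-1"; with neither present the
// `t-${Date.now()}` fallback guarantees a unique key.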
// Use the backend event timestamp for message ordering
const createdAt = event.timestamp ? new Date(event.timestamp).getTime() : Date.now();

switch (event.type) {
case "client_output_delta": {
// Prefer backend-provided iteration (reliable, embedded in event data)
// over frontend turnCounter (can desync when SSE queue drops events).
const iter = event.data?.iteration;
const iterTid = iter != null ? String(iter) : tid;
const iterIdKey = eid && iterTid ? `${eid}-${iterTid}` : eid || iterTid || `t-${Date.now()}`;
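// Illustrative example (not part of the diff): with execution_id null and
// data.iteration = 5 the resulting message ID is "stream-5-queen", matching
// the test expectations above regardless of the frontend turnId.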

const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
if (!snapshot) return null;
return {
id: `stream-${idKey}-${event.node_id}`,
id: `stream-${iterIdKey}-${event.node_id}`,
agent: agentDisplayName || event.node_id || "Agent",
agentColor: "",
content: snapshot,
timestamp: "",
role: "worker",
thread,
createdAt,
};
}

case "client_input_requested": {
const prompt = (event.data?.prompt as string) || "";
if (!prompt) return null;
return {
id: `input-req-${idKey}-${event.node_id}`,
agent: agentDisplayName || event.node_id || "Agent",
agentColor: "",
content: prompt,
timestamp: "",
role: "worker",
thread,
};
}
case "client_input_requested":
// Handled explicitly in handleSSEEvent (workspace.tsx) so it can
// create a worker_input_request message and set awaitingInput state.
return null;

case "llm_text_delta": {
const snapshot = (event.data?.snapshot as string) || (event.data?.content as string) || "";
@@ -107,6 +112,7 @@ export function sseEventToChatMessage(
timestamp: "",
role: "worker",
thread,
createdAt,
};
}

@@ -119,6 +125,7 @@ export function sseEventToChatMessage(
timestamp: "",
type: "system",
thread,
createdAt,
};
}

@@ -132,6 +139,7 @@ export function sseEventToChatMessage(
timestamp: "",
type: "system",
thread,
createdAt,
};
}

@@ -258,3 +258,118 @@ describe("node ordering", () => {
expect(result).toEqual([]);
});
});

// ---------------------------------------------------------------------------
// Trigger node synthesis from entry_points
// ---------------------------------------------------------------------------

describe("trigger node synthesis", () => {
it("single non-manual entry point: trigger node prepended before entry_node", () => {
const topology: GraphTopology = {
nodes: [makeNode("A"), makeNode("B")],
edges: [
{ source: "A", target: "B", condition: "on_success", priority: 0 },
],
entry_node: "A",
entry_points: [
{ id: "webhook", name: "Webhook Handler", entry_node: "A", trigger_type: "webhook", trigger_config: { url: "/hook" } },
],
};

const result = topologyToGraphNodes(topology);
expect(result).toHaveLength(3);

const trigger = result[0];
expect(trigger.id).toBe("__trigger_webhook");
expect(trigger.nodeType).toBe("trigger");
expect(trigger.triggerType).toBe("webhook");
expect(trigger.triggerConfig).toEqual({ url: "/hook" });
expect(trigger.label).toBe("Webhook Handler");
expect(trigger.status).toBe("pending");
expect(trigger.next).toEqual(["A"]);
});

it("trigger_config is threaded through for timer triggers", () => {
const topology: GraphTopology = {
nodes: [makeNode("A")],
edges: [],
entry_node: "A",
entry_points: [
{ id: "timer", name: "Daily Check", entry_node: "A", trigger_type: "timer", trigger_config: { cron: "0 9 * * *" } },
],
};

const result = topologyToGraphNodes(topology);
const trigger = result[0];
expect(trigger.triggerConfig).toEqual({ cron: "0 9 * * *" });
});

it("no entry_points: no trigger nodes added", () => {
const topology: GraphTopology = {
nodes: [makeNode("A")],
edges: [],
entry_node: "A",
};

const result = topologyToGraphNodes(topology);
expect(result).toHaveLength(1);
expect(result[0].nodeType).toBeUndefined();
});

it("only manual entry points: no trigger nodes added", () => {
const topology: GraphTopology = {
nodes: [makeNode("A")],
edges: [],
entry_node: "A",
entry_points: [
{ id: "main", name: "Main", entry_node: "A", trigger_type: "manual" },
],
};

const result = topologyToGraphNodes(topology);
expect(result).toHaveLength(1);
expect(result[0].id).toBe("A");
});

it("multiple non-manual entry points: multiple trigger nodes", () => {
const topology: GraphTopology = {
nodes: [makeNode("A"), makeNode("B"), makeNode("C")],
edges: [
{ source: "A", target: "C", condition: "on_success", priority: 0 },
{ source: "B", target: "C", condition: "on_success", priority: 0 },
],
entry_node: "A",
entry_points: [
{ id: "webhook", name: "Webhook", entry_node: "A", trigger_type: "webhook" },
{ id: "timer", name: "Daily Timer", entry_node: "B", trigger_type: "timer" },
],
};

const result = topologyToGraphNodes(topology);
expect(result).toHaveLength(5); // 2 triggers + 3 nodes
const triggers = result.filter((n) => n.nodeType === "trigger");
expect(triggers).toHaveLength(2);
expect(triggers[0].next).toEqual(["A"]);
expect(triggers[1].next).toEqual(["B"]);
});

it("mix of manual and non-manual: only non-manual become trigger nodes", () => {
const topology: GraphTopology = {
nodes: [makeNode("A"), makeNode("B")],
edges: [
{ source: "A", target: "B", condition: "on_success", priority: 0 },
],
entry_node: "A",
entry_points: [
{ id: "main", name: "Main", entry_node: "A", trigger_type: "manual" },
{ id: "webhook", name: "Webhook", entry_node: "A", trigger_type: "webhook" },
],
};

const result = topologyToGraphNodes(topology);
expect(result).toHaveLength(3); // 1 trigger + 2 nodes
const triggers = result.filter((n) => n.nodeType === "trigger");
expect(triggers).toHaveLength(1);
expect(triggers[0].triggerType).toBe("webhook");
});
});

@@ -5,31 +5,82 @@ import type { GraphNode, NodeStatus } from "@/components/AgentGraph";
* Convert a backend GraphTopology (nodes + edges + entry_node) into
* the GraphNode[] shape that AgentGraph renders.
*
* Three jobs:
* 1. Order nodes via BFS from entry_node
* 2. Classify edges as forward (next) or backward (backEdges)
* 3. Map session enrichment fields to NodeStatus
* Four jobs:
* 1. Synthesize trigger nodes from non-manual entry_points
* 2. Order nodes via BFS from trigger/entry_node
* 3. Classify edges as forward (next) or backward (backEdges)
* 4. Map session enrichment fields to NodeStatus
*/
export function topologyToGraphNodes(topology: GraphTopology): GraphNode[] {
const { nodes, edges, entry_node } = topology;
if (nodes.length === 0) return [];
const { nodes: allNodes, edges, entry_node, entry_points } = topology;
if (allNodes.length === 0) return [];

// Build adjacency list: source → [target, ...]
// Filter out subagent-only nodes (referenced in sub_agents but not in any edge)
const subagentIds = new Set<string>();
for (const n of allNodes) {
for (const sa of n.sub_agents ?? []) {
subagentIds.add(sa);
}
}
const edgeParticipants = new Set<string>();
for (const e of edges) {
edgeParticipants.add(e.source);
edgeParticipants.add(e.target);
}
const nodes = allNodes.filter(
(n) =>
!subagentIds.has(n.id) ||
edgeParticipants.has(n.id) ||
n.id === entry_node,
|
||||
);
|
||||
|
||||
// --- Synthesize trigger nodes for non-manual entry points ---
|
||||
const schedulerEntryPoints = (entry_points || []).filter(
|
||||
(ep) => ep.trigger_type !== "manual",
|
||||
);
|
||||
const triggerMap = new Map<string, GraphNode>();
|
||||
|
||||
for (const ep of schedulerEntryPoints) {
|
||||
const triggerId = `__trigger_${ep.id}`;
|
||||
triggerMap.set(triggerId, {
|
||||
id: triggerId,
|
||||
label: ep.name,
|
||||
status: "pending",
|
||||
nodeType: "trigger",
|
||||
triggerType: ep.trigger_type,
|
||||
triggerConfig: {
|
||||
...ep.trigger_config,
|
||||
...(ep.next_fire_in != null ? { next_fire_in: ep.next_fire_in } : {}),
|
||||
},
|
||||
next: [ep.entry_node],
|
||||
});
|
||||
}
|
||||
|
||||
// Build adjacency list: source → [target, ...] (includes trigger edges)
|
||||
const adj = new Map<string, string[]>();
|
||||
for (const e of edges) {
|
||||
const list = adj.get(e.source) || [];
|
||||
list.push(e.target);
|
||||
adj.set(e.source, list);
|
||||
}
|
||||
for (const [triggerId, triggerNode] of triggerMap) {
|
||||
adj.set(triggerId, triggerNode.next!);
|
||||
}
|
||||
|
||||
// BFS from entry_node to determine walk order + position map
|
||||
// BFS — start from trigger nodes (if any), then entry_node.
|
||||
// Always include entry_node so the DAG ordering stays correct
|
||||
// even when triggers target a node other than entry.
|
||||
const order: string[] = [];
|
||||
const position = new Map<string, number>();
|
||||
const visited = new Set<string>();
|
||||
|
||||
const start = entry_node || nodes[0].id;
|
||||
const queue = [start];
|
||||
visited.add(start);
|
||||
const entryStart = entry_node || nodes[0].id;
|
||||
const starts =
|
||||
triggerMap.size > 0
|
||||
? [...triggerMap.keys(), entryStart]
|
||||
: [entryStart];
|
||||
const queue = [...starts];
|
||||
for (const s of starts) visited.add(s);
|
||||
|
||||
while (queue.length > 0) {
|
||||
const id = queue.shift()!;
|
||||
@@ -91,6 +142,10 @@ export function topologyToGraphNodes(topology: GraphTopology): GraphNode[] {
|
||||
|
||||
// Build GraphNode[] in BFS order
|
||||
return order.map((id) => {
|
||||
// Synthetic trigger nodes are returned directly
|
||||
const trigger = triggerMap.get(id);
|
||||
if (trigger) return trigger;
|
||||
|
||||
const spec = nodeMap.get(id);
|
||||
const next = nextMap.get(id);
|
||||
const back = backMap.get(id);

@@ -9,7 +9,7 @@ import type { GraphNode } from "@/components/AgentGraph";
export const TAB_STORAGE_KEY = "hive:workspace-tabs";

export interface PersistedTabState {
  tabs: Array<{ id: string; agentType: string; label: string }>;
  tabs: Array<{ id: string; agentType: string; label: string; backendSessionId?: string }>;
  activeSessionByAgent: Record<string, string>;
  activeWorker: string;
  sessions?: Record<string, { messages: ChatMessage[]; graphNodes: GraphNode[] }>;

@@ -1,4 +1,4 @@
import { useState, useEffect } from "react";
import { useState, useEffect, useRef } from "react";
import { useNavigate } from "react-router-dom";
import { Crown, Mail, Briefcase, Shield, Search, Newspaper, ArrowRight, Hexagon, Send, Bot } from "lucide-react";
import TopBar from "@/components/TopBar";

@@ -40,6 +40,7 @@ const promptHints = [
export default function Home() {
  const navigate = useNavigate();
  const [inputValue, setInputValue] = useState("");
  const textareaRef = useRef<HTMLTextAreaElement>(null);
  const [showAgents, setShowAgents] = useState(false);
  const [agents, setAgents] = useState<DiscoverEntry[]>([]);
  const [loading, setLoading] = useState(false);

@@ -105,19 +106,26 @@ export default function Home() {
        {/* Chat input */}
        <form onSubmit={handleSubmit} className="mb-6">
          <div className="relative border border-border/60 rounded-xl bg-card/50 hover:border-primary/30 focus-within:border-primary/40 transition-colors shadow-sm">
            <input
            <textarea
              ref={textareaRef}
              rows={1}
              value={inputValue}
              onChange={(e) => setInputValue(e.target.value)}
              onChange={(e) => {
                setInputValue(e.target.value);
                const ta = e.target;
                ta.style.height = "auto";
                ta.style.height = `${Math.min(ta.scrollHeight, 160)}px`;
              }}
              onKeyDown={(e) => {
                if (e.key === "Enter") {
                if (e.key === "Enter" && !e.shiftKey) {
                  e.preventDefault();
                  handleSubmit(e);
                }
              }}
              placeholder="Describe a task for the hive..."
              className="w-full bg-transparent px-5 py-3 pr-12 text-sm text-foreground placeholder:text-muted-foreground/60 focus:outline-none rounded-xl"
              className="w-full bg-transparent px-5 py-4 pr-12 text-sm text-foreground placeholder:text-muted-foreground/60 focus:outline-none rounded-xl resize-none overflow-y-auto"
            />
            <div className="absolute right-3 top-1/2 -translate-y-1/2">
            <div className="absolute right-3 bottom-2.5">
              <button
                type="submit"
                disabled={!inputValue.trim()}

File diff suppressed because it is too large
@@ -12,9 +12,6 @@ dependencies = [
    "mcp>=1.0.0",
    "fastmcp>=2.0.0",
    "textual>=1.0.0",
    "pytest>=8.0",
    "pytest-asyncio>=0.23",
    "pytest-xdist>=3.0",
    "tools",
]

@@ -22,6 +19,11 @@ dependencies = [
tui = ["textual>=0.75.0"]
webhook = ["aiohttp>=3.9.0"]
server = ["aiohttp>=3.9.0"]
testing = [
    "pytest>=8.0",
    "pytest-asyncio>=0.23",
    "pytest-xdist>=3.0",
]

[project.scripts]
hive = "framework.cli:main"

@@ -63,4 +65,10 @@ lint.isort.section-order = [
]

[dependency-groups]
dev = ["ty>=0.0.13", "ruff>=0.14.14"]
dev = [
    "ty>=0.0.13",
    "ruff>=0.14.14",
    "pytest>=8.0",
    "pytest-asyncio>=0.23",
    "pytest-xdist>=3.0",
]
@@ -0,0 +1,224 @@
"""Diagnostic script to reproduce and trace Codex streaming errors.

Run: .venv/bin/python core/tests/debug_codex_stream.py
"""

import asyncio
import json
import sys
import traceback

sys.path.insert(0, "core")

import litellm  # noqa: E402

# Enable litellm debug logging to see the raw HTTP exchange
litellm._turn_on_debug()


async def test_codex_stream():
    """Minimal Codex streaming call via LiteLLMProvider (Responses API path)."""
    from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs
    from framework.llm.litellm import LiteLLMProvider

    api_key = get_api_key()
    api_base = get_api_base()
    extra_kwargs = get_llm_extra_kwargs()

    if not api_key or not api_base:
        print("ERROR: No Codex subscription configured in ~/.hive/configuration.json")
        return

    print(f"api_base: {api_base}")
    print(f"extra_kwargs keys: {list(extra_kwargs.keys())}")
    print(f"extra_headers: {list(extra_kwargs.get('extra_headers', {}).keys())}")

    model = "openai/gpt-5.3-codex"

    # Create the provider
    provider = LiteLLMProvider(
        model=model,
        api_key=api_key,
        api_base=api_base,
        **extra_kwargs,
    )
    print(f"_codex_backend: {provider._codex_backend}")

    # Verify mode is "responses" (the correct routing for Codex backend)
    _strip = model.removeprefix("openai/")
    mode = litellm.model_cost.get(_strip, {}).get("mode", "NOT SET")
    print(f"litellm.model_cost['{_strip}']['mode']: {mode}")
    if mode != "responses":
        print(" WARNING: Expected mode='responses' for Codex backend!")
    print()

    # -----------------------------------------------------------
    # Test 1: Stream via LiteLLMProvider.stream() (the real code path)
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 1: LiteLLMProvider.stream() — basic text")
    print("=" * 60)
    try:
        from framework.llm.stream_events import (
            FinishEvent,
            StreamErrorEvent,
            TextDeltaEvent,
            TextEndEvent,
            ToolCallEvent,
        )

        messages = [{"role": "user", "content": "Say hello in exactly 3 words."}]
        chunk_count = 0
        text = ""
        async for event in provider.stream(messages=messages):
            chunk_count += 1
            if isinstance(event, TextDeltaEvent):
                text = event.snapshot
            elif isinstance(event, TextEndEvent):
                print(f" TextEnd: {event.full_text!r}")
            elif isinstance(event, ToolCallEvent):
                print(f" ToolCall: {event.tool_name}({event.tool_input})")
            elif isinstance(event, FinishEvent):
                print(
                    f" Finish: stop={event.stop_reason} "
                    f"in={event.input_tokens} out={event.output_tokens}"
                )
            elif isinstance(event, StreamErrorEvent):
                print(f" StreamError: {event.error} (recoverable={event.recoverable})")
        print(f" Text: {text!r}")
        print(f" Total events: {chunk_count}")
        print(" RESULT: OK" if text else " RESULT: EMPTY")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 2: Stream via LiteLLMProvider.stream() with tools
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 2: LiteLLMProvider.stream() — with tools")
    print("=" * 60)
    try:
        from framework.llm.provider import Tool

        tools = [
            Tool(
                name="get_weather",
                description="Get weather for a city",
                parameters={
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            )
        ]
        messages = [{"role": "user", "content": "What is the weather in SF?"}]
        chunk_count = 0
        text = ""
        tool_calls = []
        async for event in provider.stream(messages=messages, tools=tools):
            chunk_count += 1
            if isinstance(event, TextDeltaEvent):
                text = event.snapshot
            elif isinstance(event, ToolCallEvent):
                tool_calls.append({"name": event.tool_name, "input": event.tool_input})
                print(f" ToolCall: {event.tool_name}({json.dumps(event.tool_input)})")
            elif isinstance(event, FinishEvent):
                print(
                    f" Finish: stop={event.stop_reason} "
                    f"in={event.input_tokens} out={event.output_tokens}"
                )
            elif isinstance(event, StreamErrorEvent):
                print(f" StreamError: {event.error} (recoverable={event.recoverable})")
        print(f" Text: {text!r}")
        print(f" Tool calls: {json.dumps(tool_calls, indent=2)}")
        print(f" Total events: {chunk_count}")
        status = "OK" if (text or tool_calls) else "EMPTY"
        print(f" RESULT: {status}")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 3: acomplete() via provider (uses stream + collect)
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 3: LiteLLMProvider.acomplete() — round-trip")
    print("=" * 60)
    try:
        messages = [{"role": "user", "content": "What is 2+2? Reply with just the number."}]
        response = await provider.acomplete(messages=messages)
        print(f" Content: {response.content!r}")
        print(f" Model: {response.model}")
        print(f" Tokens: in={response.input_tokens} out={response.output_tokens}")
        print(f" Stop: {response.stop_reason}")
        print(" RESULT: OK" if response.content else " RESULT: EMPTY")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 4: Direct litellm.acompletion with metadata fix
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 4: Direct litellm.acompletion (with metadata={})")
    print("=" * 60)
    try:
        direct_kwargs = {
            "model": model,
            "messages": [{"role": "user", "content": "Say hello in exactly 3 words."}],
            "stream": True,
            "api_key": api_key,
            "api_base": api_base,
            "metadata": {},  # Prevent NoneType masking in error handler
            **extra_kwargs,
        }
        response = await litellm.acompletion(**direct_kwargs)
        chunk_count = 0
        text = ""
        async for chunk in response:
            chunk_count += 1
            choices = chunk.choices if chunk.choices else []
            delta = choices[0].delta if choices else None
            content = delta.content if delta and delta.content else ""
            if content:
                text += content
            finish = choices[0].finish_reason if choices else None
            if finish:
                print(f" finish_reason: {finish}")
        print(f" Text: {text!r}")
        print(f" Total chunks: {chunk_count}")
        print(" RESULT: OK" if text else " RESULT: EMPTY")
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {e}")
        traceback.print_exc()
    print()

    # -----------------------------------------------------------
    # Test 5: Rapid-fire 3 calls via provider.stream()
    # -----------------------------------------------------------
    print("=" * 60)
    print("TEST 5: Rapid-fire 3 calls via provider.stream()")
    print("=" * 60)
    for i in range(3):
        try:
            messages = [{"role": "user", "content": f"Say the number {i + 1}."}]
            text = ""
            async for event in provider.stream(messages=messages):
                if isinstance(event, TextDeltaEvent):
                    text = event.snapshot
                elif isinstance(event, StreamErrorEvent):
                    print(f" Call {i + 1}: StreamError: {event.error}")
                    break
            status = f"OK ({len(text)} chars: {text!r})" if text else "EMPTY"
            print(f" Call {i + 1}: {status}")
        except Exception as e:
            print(f" Call {i + 1}: ERROR {type(e).__name__}: {e}")
    print()


if __name__ == "__main__":
    asyncio.run(test_codex_stream())
@@ -0,0 +1,69 @@
"""Run Codex stream with litellm debug logging enabled.

Run: .venv/bin/python core/tests/debug_codex_verbose.py
"""

import asyncio
import sys

sys.path.insert(0, "core")

import litellm  # noqa: E402

litellm._turn_on_debug()

from framework.config import get_api_base, get_api_key, get_llm_extra_kwargs  # noqa: E402
from framework.llm.litellm import LiteLLMProvider  # noqa: E402
from framework.llm.stream_events import (  # noqa: E402
    FinishEvent,
    StreamErrorEvent,
    TextDeltaEvent,
    TextEndEvent,
    ToolCallEvent,
)


async def main():
    api_key = get_api_key()
    api_base = get_api_base()
    extra_kwargs = get_llm_extra_kwargs()

    if not api_key or not api_base:
        print("ERROR: No Codex config in ~/.hive/configuration.json")
        return

    provider = LiteLLMProvider(
        model="openai/gpt-5.3-codex",
        api_key=api_key,
        api_base=api_base,
        **extra_kwargs,
    )

    print(f"_codex_backend={provider._codex_backend}")
    print()

    text = ""
    async for event in provider.stream(
        messages=[{"role": "user", "content": "What is 2+2? Reply with just the number."}],
        system="You are a helpful assistant.",
    ):
        if isinstance(event, TextDeltaEvent):
            text = event.snapshot
        elif isinstance(event, TextEndEvent):
            print(f"TextEnd: {event.full_text!r}")
        elif isinstance(event, ToolCallEvent):
            print(f"ToolCall: {event.tool_name}({event.tool_input})")
        elif isinstance(event, FinishEvent):
            print(
                f"Finish: stop={event.stop_reason} "
                f"in={event.input_tokens} out={event.output_tokens}"
            )
        elif isinstance(event, StreamErrorEvent):
            print(f"StreamError: {event.error} (recoverable={event.recoverable})")

    print(f"Text: {text!r}")
    print("OK" if text else "EMPTY")


if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,165 @@
"""Integration test: Run a real EventLoopNode against the Codex backend.

Run: .venv/bin/python core/tests/test_codex_eventloop.py
"""

import asyncio
import logging
import sys
from unittest.mock import MagicMock

sys.path.insert(0, "core")

logging.basicConfig(level=logging.WARNING, format="%(levelname)s %(name)s: %(message)s")
# Show our provider's retry/stream logs
logging.getLogger("framework.llm.litellm").setLevel(logging.DEBUG)

from framework.config import RuntimeConfig  # noqa: E402
from framework.graph.event_loop_node import EventLoopNode, LoopConfig  # noqa: E402
from framework.graph.node import NodeContext, NodeResult, NodeSpec, SharedMemory  # noqa: E402
from framework.llm.litellm import LiteLLMProvider  # noqa: E402


def make_provider() -> LiteLLMProvider:
    cfg = RuntimeConfig()
    if not cfg.api_key:
        print("ERROR: No API key configured in ~/.hive/configuration.json")
        sys.exit(1)
    print(f"Model : {cfg.model}")
    print(f"Base : {cfg.api_base}")
    print(f"Codex : {'chatgpt.com/backend-api/codex' in (cfg.api_base or '')}")
    return LiteLLMProvider(
        model=cfg.model,
        api_key=cfg.api_key,
        api_base=cfg.api_base,
        **cfg.extra_kwargs,
    )


def make_context(
    llm: LiteLLMProvider,
    *,
    node_id: str = "test",
    system_prompt: str = "You are a helpful assistant.",
    output_keys: list[str] | None = None,
) -> NodeContext:
    if output_keys is None:
        output_keys = ["answer"]

    spec = NodeSpec(
        id=node_id,
        name="Test Node",
        description="Integration test node",
        node_type="event_loop",
        output_keys=output_keys,
        system_prompt=system_prompt,
    )

    runtime = MagicMock()
    runtime.start_run = MagicMock(return_value="run-1")
    runtime.decide = MagicMock(return_value="dec-1")
    runtime.record_outcome = MagicMock()
    runtime.end_run = MagicMock()

    memory = SharedMemory()

    return NodeContext(
        runtime=runtime,
        node_id=node_id,
        node_spec=spec,
        memory=memory,
        input_data={},
        llm=llm,
        available_tools=[],
        max_tokens=4096,
    )


async def run_test(
    name: str, llm: LiteLLMProvider, system: str, output_keys: list[str]
) -> NodeResult:
    print(f"\n{'=' * 60}")
    print(f"TEST: {name}")
    print(f"{'=' * 60}")

    ctx = make_context(llm, system_prompt=system, output_keys=output_keys)
    node = EventLoopNode(config=LoopConfig(max_iterations=3))

    try:
        result = await node.execute(ctx)
        print(f" Success : {result.success}")
        print(f" Output : {result.output}")
        if result.error:
            print(f" Error : {result.error}")
        return result
    except Exception as e:
        print(f" EXCEPTION: {type(e).__name__}: {e}")
        import traceback

        traceback.print_exc()
        return NodeResult(success=False, error=str(e))


async def main():
    llm = make_provider()
    print()

    # Test 1: Simple text output — the node should call set_output to fill "answer"
    r1 = await run_test(
        name="Simple text generation",
        llm=llm,
        system=(
            "You are a helpful assistant. When asked a question, use the "
            "set_output tool to store your answer in the 'answer' key. "
            "Keep answers short (1-2 sentences)."
        ),
        output_keys=["answer"],
    )

    # Test 2: If test 1 failed, try bare stream() to isolate the issue
    if not r1.success:
        print(f"\n{'=' * 60}")
        print("FALLBACK: Testing bare provider.stream() directly")
        print(f"{'=' * 60}")
        try:
            from framework.llm.stream_events import (
                FinishEvent,
                StreamErrorEvent,
                TextDeltaEvent,
                ToolCallEvent,
            )

            text = ""
            events = []
            async for event in llm.stream(
                messages=[{"role": "user", "content": "Say hello in 3 words."}],
            ):
                events.append(type(event).__name__)
                if isinstance(event, TextDeltaEvent):
                    text = event.snapshot
                elif isinstance(event, FinishEvent):
                    print(
                        f" Finish: stop={event.stop_reason}"
                        f" in={event.input_tokens}"
                        f" out={event.output_tokens}"
                    )
                elif isinstance(event, StreamErrorEvent):
                    print(f" StreamError: {event.error} (recoverable={event.recoverable})")
                elif isinstance(event, ToolCallEvent):
                    print(f" ToolCall: {event.tool_name}")
            print(f" Text : {text!r}")
            print(f" Events : {events}")
            print(f" RESULT : {'OK' if text else 'EMPTY'}")
        except Exception as e:
            print(f" EXCEPTION: {type(e).__name__}: {e}")
            import traceback

            traceback.print_exc()

    print(f"\n{'=' * 60}")
    print("DONE")
    print(f"{'=' * 60}")


if __name__ == "__main__":
    asyncio.run(main())
@@ -36,16 +36,6 @@ class FailingLLMProvider(LLMProvider):
    def complete(self, messages: list[dict[str, Any]], **kwargs: Any) -> LLMResponse:
        raise RuntimeError("LLM unavailable")

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str,
        tools: list,
        tool_executor: Any,
        max_iterations: int = 10,
    ) -> LLMResponse:
        raise RuntimeError("LLM unavailable")


async def _build_conversation(*pairs: tuple[str, str]) -> NodeConversation:
    """Build a NodeConversation from (user, assistant) message pairs."""

@@ -62,9 +62,6 @@ class MockStreamingLLM(LLMProvider):
    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="Summary.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


# ---------------------------------------------------------------------------
# Helpers

@@ -64,9 +64,6 @@ class MockStreamingLLM(LLMProvider):
        self.complete_calls.append({"messages": messages, "system": system})
        return LLMResponse(content=self.complete_response, model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


# ---------------------------------------------------------------------------
# Helpers

@@ -95,17 +95,6 @@ class ScriptableMockLLMProvider(LLMProvider):
            output_tokens=10,
        )

    def complete_with_tools(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        tool_executor: Callable[[ToolUse], ToolResult] | None = None,
        max_iterations: int = 10,
        max_tokens: int = 1024,
    ) -> LLMResponse:
        return self.complete(messages, system, tools, max_tokens)

    async def stream(
        self,
        messages: list[dict[str, Any]],

@@ -68,9 +68,6 @@ class MockStreamingLLM(LLMProvider):
    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="Summary of conversation.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


# ---------------------------------------------------------------------------
# Helper: build a simple text-only scenario

@@ -114,7 +111,7 @@ def tool_call_scenario(
@pytest.fixture
def runtime():
    rt = MagicMock(spec=Runtime)
    rt.start_run = MagicMock(return_value="run_1")
    rt.start_run = MagicMock(return_value="session_20250101_000000_eventlp01")
    rt.decide = MagicMock(return_value="dec_1")
    rt.record_outcome = MagicMock()
    rt.end_run = MagicMock()
@@ -704,6 +701,259 @@ class TestClientFacingBlocking:
        assert "ask_user" not in tool_names


# ===========================================================================
# Client-facing: _cf_expecting_work state machine
#
# After user responds, text-only turns with missing required outputs should
# go through judge (RETRY) instead of auto-blocking. This prevents weak
# models from stalling when they output "Understood" without calling tools.
# ===========================================================================


class TestClientFacingExpectingWork:
    """Tests for _cf_expecting_work state machine in client-facing nodes."""

    @pytest.mark.asyncio
    async def test_text_after_user_input_goes_to_judge(self, runtime, memory):
        """After user responds, text-only with missing outputs gets judged (not auto-blocked).

        Simulates: findings-review asks user, user says "generate report",
        Codex replies "Understood" without tools -> judge should RETRY.
        """
        spec = NodeSpec(
            id="findings",
            name="Findings Review",
            description="review findings",
            node_type="event_loop",
            output_keys=["decision"],
            client_facing=True,
        )
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 0: ask user what to do
                tool_call_scenario(
                    "ask_user",
                    {"question": "Continue or generate report?"},
                    tool_use_id="ask_1",
                ),
                # Turn 1: after user responds, LLM outputs text-only (lazy)
                text_scenario("Understood, generating the report."),
                # Turn 2: after judge RETRY, LLM sets output
                tool_call_scenario(
                    "set_output",
                    {"key": "decision", "value": "generate"},
                ),
                # Turn 3: accept
                text_scenario("Done."),
            ]
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=10))
        ctx = build_ctx(runtime, spec, memory, llm)

        async def user_responds():
            await asyncio.sleep(0.05)
            await node.inject_event("Generate the report")

        task = asyncio.create_task(user_responds())
        result = await node.execute(ctx)
        await task

        assert result.success is True
        assert result.output["decision"] == "generate"
        # LLM should have been called at least 3 times (ask_user, text-only retried, set_output)
        assert llm._call_index >= 3

    @pytest.mark.asyncio
    async def test_auto_block_without_missing_outputs(self, runtime, memory):
        """Text-only with no missing outputs should still auto-block (queen monitoring).

        Simulates: queen node with no required outputs outputs "monitoring..."
        -> should auto-block and wait for event, not spin in judge loop.
        """
        spec = NodeSpec(
            id="queen",
            name="Queen",
            description="orchestrator",
            node_type="event_loop",
            output_keys=[],
            client_facing=True,
        )
        llm = MockStreamingLLM(
            scenarios=[
                # Turn 0: ask user for domain
                tool_call_scenario(
                    "ask_user",
                    {"question": "What domain?"},
                    tool_use_id="ask_1",
                ),
                # Turn 1: after user input, outputs monitoring text
                # No missing required outputs -> should auto-block
                text_scenario("Monitoring workers..."),
            ]
        )
        node = EventLoopNode(config=LoopConfig(max_iterations=10))
        ctx = build_ctx(runtime, spec, memory, llm)

        async def user_then_shutdown():
            await asyncio.sleep(0.05)
            await node.inject_event("furwise.app")
            # Node should auto-block on "Monitoring..." text.
            # Give it time to reach the block, then shutdown.
            await asyncio.sleep(0.1)
            node.signal_shutdown()

        task = asyncio.create_task(user_then_shutdown())
        result = await node.execute(ctx)
        await task

        assert result.success is True
        # LLM called exactly 2 times: ask_user + monitoring text.
        # If auto-block was skipped, judge would loop and call LLM more times.
        assert llm._call_index == 2

    @pytest.mark.asyncio
    async def test_tool_calls_reset_expecting_work(self, runtime, memory):
        """After LLM calls tools, next text-only turn should auto-block again.

        Simulates: user gives input -> LLM calls tools (work) -> LLM presents
        results as text -> should auto-block (presenting, not lazy).
        """
        spec = NodeSpec(
            id="report",
            name="Report",
            description="generate report",
            node_type="event_loop",
            output_keys=["status"],
            client_facing=True,
        )

        def my_executor(tool_use: ToolUse) -> ToolResult:
            return ToolResult(tool_use_id=tool_use.id, content="saved", is_error=False)

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 0: ask user
                tool_call_scenario(
                    "ask_user",
                    {"question": "Ready?"},
                    tool_use_id="ask_1",
                ),
                # Turn 1: after user responds, LLM does work (tool call)
                tool_call_scenario(
                    "save_data",
                    {"content": "report.html"},
                    tool_use_id="tool_1",
                ),
                # Turn 2: LLM presents results as text (no tools)
                # Tool calls reset _cf_expecting_work -> should auto-block
                text_scenario("Here is your report. Need changes?"),
                # Turn 3: after user responds, set output
                tool_call_scenario(
                    "set_output",
                    {"key": "status", "value": "complete"},
                ),
                # Turn 4: done
                text_scenario("All done."),
            ]
        )
        node = EventLoopNode(
            tool_executor=my_executor,
            config=LoopConfig(max_iterations=10),
        )
        ctx = build_ctx(
            runtime,
            spec,
            memory,
            llm,
            tools=[Tool(name="save_data", description="save", parameters={})],
        )

        async def interactions():
            await asyncio.sleep(0.05)
            await node.inject_event("Yes, go ahead")
            # After tool calls + text presentation, node should auto-block again.
            # Inject second user response.
            await asyncio.sleep(0.2)
            await node.inject_event("Looks good")

        task = asyncio.create_task(interactions())
        result = await node.execute(ctx)
        await task

        assert result.success is True
        assert result.output["status"] == "complete"

    @pytest.mark.asyncio
    async def test_judge_retry_enables_expecting_work(self, runtime, memory):
        """After judge RETRY, text-only with missing outputs goes to judge again.

        Simulates: LLM calls save_data but forgets set_output -> judge RETRY ->
        LLM outputs text -> should go to judge (not auto-block).
        """
        spec = NodeSpec(
            id="report",
            name="Report",
            description="generate report",
            node_type="event_loop",
            output_keys=["status"],
            client_facing=True,
        )

        def my_executor(tool_use: ToolUse) -> ToolResult:
            return ToolResult(tool_use_id=tool_use.id, content="saved", is_error=False)

        llm = MockStreamingLLM(
            scenarios=[
                # Turn 0: ask user
                tool_call_scenario(
                    "ask_user",
                    {"question": "Generate?"},
                    tool_use_id="ask_1",
                ),
                # Turn 1: LLM calls tool but doesn't set output
                tool_call_scenario(
                    "save_data",
                    {"content": "report"},
                    tool_use_id="tool_1",
                ),
                # Turn 2: judge RETRY (missing "status"). LLM outputs text.
                # _cf_expecting_work should be True from RETRY -> goes to judge
                text_scenario("Report generated successfully."),
                # Turn 3: after second RETRY, LLM finally sets output
                tool_call_scenario(
                    "set_output",
                    {"key": "status", "value": "done"},
                ),
                # Turn 4: accept
                text_scenario("Complete."),
            ]
        )
        node = EventLoopNode(
            tool_executor=my_executor,
            config=LoopConfig(max_iterations=10),
        )
        ctx = build_ctx(
            runtime,
            spec,
            memory,
            llm,
            tools=[Tool(name="save_data", description="save", parameters={})],
        )

        async def user_responds():
            await asyncio.sleep(0.05)
            await node.inject_event("Yes")

        task = asyncio.create_task(user_responds())
        result = await node.execute(ctx)
        await task

        assert result.success is True
        assert result.output["status"] == "done"
        # LLM called at least 4 times: ask_user, save_data, text(retried), set_output
        assert llm._call_index >= 4
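
The four tests above exercise one small routing rule for text-only turns in a client-facing node. As a rough sketch of that rule (illustrative only: the real flag lives on `EventLoopNode`, and this helper and its signature are assumptions, not the framework API):

```python
def route_text_only_turn(expecting_work: bool, missing_outputs: list[str]) -> str:
    """Hypothetical sketch of the _cf_expecting_work routing tested above.

    expecting_work is set after user input or a judge RETRY, and cleared
    once the LLM actually calls tools.
    """
    if expecting_work and missing_outputs:
        return "judge"  # likely a lazy "Understood" turn; the judge can RETRY
    return "auto_block"  # presenting results or monitoring; wait for events


# Mirrors the four scenarios tested above:
assert route_text_only_turn(True, ["decision"]) == "judge"      # lazy text after user input
assert route_text_only_turn(True, []) == "auto_block"           # queen with no outputs owed
assert route_text_only_turn(False, ["status"]) == "auto_block"  # tools ran; text presents results
```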

# ===========================================================================
# Tool execution
# ===========================================================================

@@ -1026,9 +1276,6 @@ class ErrorThenSuccessLLM(LLMProvider):
    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="ok", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")


class TestTransientErrorRetry:
    """Test retry-with-backoff for transient LLM errors in EventLoopNode."""

@@ -1131,20 +1378,6 @@ class TestTransientErrorRetry:
                stop_reason="stop",
            )

        def complete_with_tools(
            self,
            messages,
            system,
            tools,
            tool_executor,
            **kwargs,
        ):
            return LLMResponse(
                content="",
                model="mock",
                stop_reason="stop",
            )

        llm = StreamErrorThenSuccessLLM()
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(

@@ -1227,9 +1460,6 @@ class TestTransientErrorRetry:
            def complete(self, messages, system="", **kwargs):
                return LLMResponse(content="ok", model="mock", stop_reason="stop")

            def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs):
                return LLMResponse(content="", model="mock", stop_reason="stop")

        llm = RecoverableErrorThenSuccessLLM()
        ctx = build_ctx(runtime, node_spec, memory, llm)
        node = EventLoopNode(

@@ -1412,20 +1642,6 @@ class ToolRepeatLLM(LLMProvider):
            stop_reason="stop",
        )

    def complete_with_tools(
        self,
        messages,
        system,
        tools,
        tool_executor,
        **kwargs,
    ) -> LLMResponse:
        return LLMResponse(
            content="",
            model="mock",
            stop_reason="stop",
        )


class TestToolDoomLoopIntegration:
    """Integration tests for doom loop detection in execute().

@@ -1650,20 +1866,6 @@ class TestToolDoomLoopIntegration:
                stop_reason="stop",
            )

            def complete_with_tools(
                self,
                messages,
                system,
                tools,
                tool_executor,
                **kw,
            ):
                return LLMResponse(
                    content="",
                    model="mock",
                    stop_reason="stop",
                )

        llm = DiffArgsLLM()

        def tool_exec(tool_use: ToolUse) -> ToolResult:

@@ -1691,6 +1893,71 @@ class TestToolDoomLoopIntegration:
        result = await node.execute(ctx)
        assert result.success is True

    @pytest.mark.asyncio
    async def test_doom_loop_detects_repeated_failing_tool(
        self,
        runtime,
        node_spec,
        memory,
    ):
        """A tool that keeps failing with is_error=True should trigger doom loop.

        Regression test: previously, errored tool calls were excluded from
        doom loop fingerprinting (``not tc.get("is_error")``), so a tool
        failing with the same error every turn would never be detected.
        """
        node_spec.output_keys = []
        judge = AsyncMock(spec=JudgeProtocol)
        eval_count = 0

        async def judge_eval(*args, **kwargs):
            nonlocal eval_count
            eval_count += 1
            if eval_count >= 5:
                return JudgeVerdict(action="ACCEPT")
            return JudgeVerdict(action="RETRY")

        judge.evaluate = judge_eval

        # 4 turns of the same failing tool call, then text
        llm = ToolRepeatLLM("failing_tool", {}, tool_turns=4)
        bus = EventBus()
        doom_events: list = []
        bus.subscribe(
            event_types=[EventType.NODE_TOOL_DOOM_LOOP],
            handler=lambda e: doom_events.append(e),
        )

        def tool_exec(tool_use: ToolUse) -> ToolResult:
            return ToolResult(
                tool_use_id=tool_use.id,
                content="Error: accessibility tree unavailable",
                is_error=True,
            )

        ctx = build_ctx(
            runtime,
            node_spec,
            memory,
            llm,
            tools=[Tool(name="failing_tool", description="s", parameters={})],
        )
        node = EventLoopNode(
            judge=judge,
            tool_executor=tool_exec,
            event_bus=bus,
            config=LoopConfig(
                max_iterations=10,
                tool_doom_loop_threshold=3,
            ),
        )
        result = await node.execute(ctx)
        assert result.success is True
        # Doom loop MUST fire for repeatedly-failing tool calls
        assert len(doom_events) >= 1
        assert "failing_tool" in doom_events[0].data["description"]

# ===========================================================================
# execution_id plumbing

@@ -248,22 +248,3 @@ async def test_event_loop_max_retries_positive_logs_warning(runtime, caplog):

    # Custom nodes (not EventLoopNode instances) don't get override warning
    assert "Overriding to 0" not in caplog.text


# --- Existing node types unaffected ---


def test_existing_node_types_unchanged():
    """Only event_loop is a valid node type."""
    expected = {"event_loop"}
    assert expected == GraphExecutor.VALID_NODE_TYPES

    # Default node_type is event_loop
    spec = NodeSpec(id="x", name="X", description="x")
    assert spec.node_type == "event_loop"

    # Default max_retries is still 3
    assert spec.max_retries == 3

    # Default client_facing is False
    assert spec.client_facing is False

@@ -1,7 +1,7 @@
"""Tests for ExecutionStream retention behavior."""

import json
from collections.abc import AsyncIterator, Callable
from collections.abc import AsyncIterator
from typing import Any

import pytest

@@ -38,16 +38,6 @@ class DummyLLMProvider(LLMProvider):
    ) -> LLMResponse:
        return LLMResponse(content="Summary for compaction.", model="dummy")

    def complete_with_tools(
        self,
        messages: list[dict[str, object]],
        system: str,
        tools: list[Tool],
        tool_executor: Callable,
        max_iterations: int = 10,
    ) -> LLMResponse:
        return LLMResponse(content="Summary for compaction.", model="dummy")

    async def stream(
        self,
        messages: list[dict[str, Any]],

@@ -57,8 +47,11 @@ class DummyLLMProvider(LLMProvider):
    ) -> AsyncIterator[StreamEvent]:
        self._call_count += 1

        if self._call_count == 1:
            # First call: set the output via tool call
        # Each execution takes 2 LLM calls:
        # - Odd calls (1, 3, 5, ...): set output via tool call
        # - Even calls (2, 4, 6, ...): finish with text
        if self._call_count % 2 == 1:
            # First call of each execution: set the output via tool call
            yield ToolCallEvent(
                tool_use_id=f"tc_{self._call_count}",
                tool_name="set_output",

@@ -66,7 +59,7 @@ class DummyLLMProvider(LLMProvider):
            )
            yield FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=10)
        else:
            # Subsequent calls: just finish with text
            # Second call of each execution: finish with text
            yield TextDeltaEvent(content="Done.", snapshot="Done.")
            yield FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import pytest
|
||||
|
||||
from framework.llm.anthropic import AnthropicProvider
|
||||
from framework.llm.litellm import LiteLLMProvider, _compute_retry_delay
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool, ToolResult, ToolUse
|
||||
from framework.llm.provider import LLMProvider, LLMResponse, Tool
|
||||
|
||||
|
||||
class TestLiteLLMProviderInit:
|
||||
@@ -154,124 +154,6 @@ class TestLiteLLMProviderComplete:
|
||||
assert call_kwargs["tools"][0]["function"]["name"] == "get_weather"
|
||||
|
||||
|
||||
class TestLiteLLMProviderToolUse:
|
||||
"""Test LiteLLMProvider.complete_with_tools() method."""
|
||||
|
||||
@patch("litellm.completion")
|
||||
def test_complete_with_tools_single_iteration(self, mock_completion):
|
||||
"""Test tool use with single iteration."""
|
||||
# First response: tool call
|
||||
tool_call_response = MagicMock()
|
||||
tool_call_response.choices = [MagicMock()]
|
||||
tool_call_response.choices[0].message.content = None
|
||||
tool_call_response.choices[0].message.tool_calls = [MagicMock()]
|
||||
tool_call_response.choices[0].message.tool_calls[0].id = "call_123"
|
||||
tool_call_response.choices[0].message.tool_calls[0].function.name = "get_weather"
|
||||
tool_call_response.choices[0].message.tool_calls[
|
||||
0
|
||||
].function.arguments = '{"location": "London"}'
|
||||
tool_call_response.choices[0].finish_reason = "tool_calls"
|
||||
tool_call_response.model = "gpt-4o-mini"
|
||||
tool_call_response.usage.prompt_tokens = 20
|
||||
tool_call_response.usage.completion_tokens = 15
|
||||
|
||||
# Second response: final answer
|
||||
final_response = MagicMock()
|
||||
final_response.choices = [MagicMock()]
|
||||
final_response.choices[0].message.content = "The weather in London is sunny."
|
||||
final_response.choices[0].message.tool_calls = None
|
||||
final_response.choices[0].finish_reason = "stop"
|
||||
final_response.model = "gpt-4o-mini"
|
||||
final_response.usage.prompt_tokens = 30
|
||||
final_response.usage.completion_tokens = 10
|
||||
|
||||
mock_completion.side_effect = [tool_call_response, final_response]
|
||||
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
tools = [
|
||||
Tool(
|
||||
name="get_weather",
|
||||
description="Get the weather",
|
||||
parameters={
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
def tool_executor(tool_use: ToolUse) -> ToolResult:
|
||||
return ToolResult(tool_use_id=tool_use.id, content="Sunny, 22C", is_error=False)
|
||||
|
||||
result = provider.complete_with_tools(
|
||||
messages=[{"role": "user", "content": "What's the weather in London?"}],
|
||||
system="You are a weather assistant.",
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
assert result.content == "The weather in London is sunny."
|
||||
assert result.input_tokens == 50 # 20 + 30
|
||||
assert result.output_tokens == 25 # 15 + 10
|
||||
assert mock_completion.call_count == 2
|
||||
|
||||
@patch("litellm.completion")
|
||||
def test_complete_with_tools_invalid_json_arguments_are_handled(self, mock_completion):
|
||||
"""Test that invalid JSON tool arguments do not execute the tool."""
|
||||
# Mock response with invalid JSON arguments
|
||||
tool_call_response = MagicMock()
|
||||
tool_call_response.choices = [MagicMock()]
|
||||
tool_call_response.choices[0].message.content = None
|
||||
tool_call_response.choices[0].message.tool_calls = [MagicMock()]
|
||||
tool_call_response.choices[0].message.tool_calls[0].id = "call_123"
|
||||
tool_call_response.choices[0].message.tool_calls[0].function.name = "test_tool"
|
||||
tool_call_response.choices[0].message.tool_calls[0].function.arguments = "{invalid json"
|
||||
tool_call_response.choices[0].finish_reason = "tool_calls"
|
||||
tool_call_response.model = "gpt-4o-mini"
|
||||
tool_call_response.usage.prompt_tokens = 10
|
||||
tool_call_response.usage.completion_tokens = 5
|
||||
|
||||
# Final response (LLM continues after tool error)
|
||||
final_response = MagicMock()
|
||||
final_response.choices = [MagicMock()]
|
||||
final_response.choices[0].message.content = "Handled error"
|
||||
final_response.choices[0].message.tool_calls = None
|
||||
final_response.choices[0].finish_reason = "stop"
|
||||
final_response.model = "gpt-4o-mini"
|
||||
final_response.usage.prompt_tokens = 5
|
||||
final_response.usage.completion_tokens = 5
|
||||
|
||||
mock_completion.side_effect = [tool_call_response, final_response]
|
||||
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
|
||||
tools = [
|
||||
Tool(
|
||||
name="test_tool",
|
||||
description="Test tool",
|
||||
parameters={"properties": {}, "required": []},
|
||||
)
|
||||
]
|
||||
|
||||
called = {"value": False}
|
||||
|
||||
def tool_executor(tool_use: ToolUse) -> ToolResult:
|
||||
called["value"] = True
|
||||
return ToolResult(
|
||||
tool_use_id=tool_use.id, content="should not be called", is_error=False
|
||||
)
|
||||
|
||||
result = provider.complete_with_tools(
|
||||
messages=[{"role": "user", "content": "Run tool"}],
|
||||
system="You are a test assistant.",
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
assert called["value"] is False
|
||||
assert result.content == "Handled error"
|
||||
|
||||
|
||||
class TestToolConversion:
|
||||
"""Test tool format conversion."""
|
||||
|
||||
@@ -352,43 +234,6 @@ class TestAnthropicProviderBackwardCompatibility:
|
||||
assert call_kwargs["model"] == "claude-3-haiku-20240307"
|
||||
assert call_kwargs["api_key"] == "test-key"
|
||||
|
||||
@patch("litellm.completion")
|
||||
def test_anthropic_provider_complete_with_tools(self, mock_completion):
|
||||
"""Test AnthropicProvider.complete_with_tools() delegates to LiteLLM."""
|
||||
# Mock a simple response (no tool calls)
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "The time is 3:00 PM."
|
||||
mock_response.choices[0].message.tool_calls = None
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "claude-3-haiku-20240307"
|
||||
mock_response.usage.prompt_tokens = 20
|
||||
mock_response.usage.completion_tokens = 10
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
provider = AnthropicProvider(api_key="test-key", model="claude-3-haiku-20240307")
|
||||
|
||||
tools = [
|
||||
Tool(
|
||||
name="get_time",
|
||||
description="Get current time",
|
||||
parameters={"properties": {}, "required": []},
|
||||
)
|
||||
]
|
||||
|
||||
def tool_executor(tool_use: ToolUse) -> ToolResult:
|
||||
return ToolResult(tool_use_id=tool_use.id, content="3:00 PM", is_error=False)
|
||||
|
||||
result = provider.complete_with_tools(
|
||||
messages=[{"role": "user", "content": "What time is it?"}],
|
||||
system="You are a time assistant.",
|
||||
tools=tools,
|
||||
tool_executor=tool_executor,
|
||||
)
|
||||
|
||||
assert result.content == "The time is 3:00 PM."
|
||||
mock_completion.assert_called_once()
|
||||
|
||||
@patch("litellm.completion")
|
||||
def test_anthropic_provider_passes_response_format(self, mock_completion):
|
||||
"""Test that AnthropicProvider accepts and forwards response_format."""
|
||||
@@ -738,43 +583,6 @@ class TestAsyncComplete:
|
||||
f"Event loop was blocked — only {len(heartbeat_ticks)} heartbeat ticks"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("litellm.acompletion")
|
||||
async def test_acomplete_with_tools_uses_acompletion(self, mock_acompletion):
|
||||
"""acomplete_with_tools() should use litellm.acompletion."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "tool result"
|
||||
mock_response.choices[0].message.tool_calls = None
|
||||
mock_response.choices[0].finish_reason = "stop"
|
||||
mock_response.model = "gpt-4o-mini"
|
||||
mock_response.usage.prompt_tokens = 10
|
||||
mock_response.usage.completion_tokens = 5
|
||||
|
||||
async def async_return(*args, **kwargs):
|
||||
return mock_response
|
||||
|
||||
mock_acompletion.side_effect = async_return
|
||||
|
||||
provider = LiteLLMProvider(model="gpt-4o-mini", api_key="test-key")
|
||||
tools = [
|
||||
Tool(
|
||||
name="search",
|
||||
description="Search the web",
|
||||
parameters={"properties": {"q": {"type": "string"}}, "required": ["q"]},
|
||||
)
|
||||
]
|
||||
|
||||
result = await provider.acomplete_with_tools(
|
||||
messages=[{"role": "user", "content": "Search for cats"}],
|
||||
system="You are helpful.",
|
||||
tools=tools,
|
||||
tool_executor=lambda tu: ToolResult(tool_use_id=tu.id, content="cats"),
|
||||
)
|
||||
|
||||
assert result.content == "tool result"
|
||||
mock_acompletion.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mock_provider_acomplete(self):
|
||||
"""MockLLMProvider.acomplete() should work without blocking."""
|
||||
@@ -809,11 +617,6 @@ class TestAsyncComplete:
|
||||
time.sleep(0.1) # Sync blocking
|
||||
return LLMResponse(content="sync done", model="slow")
|
||||
|
||||
def complete_with_tools(
|
||||
self, messages, system, tools, tool_executor, max_iterations=10
|
||||
):
|
||||
return LLMResponse(content="sync tools done", model="slow")
|
||||
|
||||
provider = SlowSyncProvider()
|
||||
main_thread_id = threading.current_thread().ident
|
||||
|
||||
|
||||
@@ -52,9 +52,6 @@ class MockLLMProvider(LLMProvider):
|
||||
output_tokens=50,
|
||||
)
|
||||
|
||||
def complete_with_tools(self, messages, system, tools, tool_executor, max_iterations=10):
|
||||
raise NotImplementedError("Tool use not needed for judge tests")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LLMJudge Tests - Custom Provider
|
||||
|
||||
@@ -102,4 +102,3 @@ class TestOrchestratorLLMProviderType:
|
||||
|
||||
assert isinstance(orchestrator._llm, LLMProvider)
|
||||
assert hasattr(orchestrator._llm, "complete")
|
||||
assert hasattr(orchestrator._llm, "complete_with_tools")
|
||||
|
||||
@@ -21,22 +21,51 @@ from framework.runtime.runtime_log_schemas import (
|
||||
from framework.runtime.runtime_log_store import RuntimeLogStore
|
||||
from framework.runtime.runtime_logger import RuntimeLogger
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SESSION_PREFIX = "session_20250101_000000"
|
||||
|
||||
|
||||
def _sid(suffix: str) -> str:
|
||||
"""Build a deterministic session ID for tests."""
|
||||
return f"{_SESSION_PREFIX}_{suffix}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RuntimeLogStore tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _force_session_run_ids(monkeypatch):
|
||||
"""Use unified session_* IDs in tests to avoid deprecated run path warnings."""
|
||||
|
||||
original_start_run = RuntimeLogger.start_run
|
||||
counter = 0
|
||||
|
||||
def _patched_start_run(self, goal_id: str = "", session_id: str = "") -> str:
|
||||
nonlocal counter
|
||||
if not session_id:
|
||||
counter += 1
|
||||
session_id = _sid(f"{counter:08x}")
|
||||
return original_start_run(self, goal_id=goal_id, session_id=session_id)
|
||||
|
||||
monkeypatch.setattr(RuntimeLogger, "start_run", _patched_start_run)
|
||||
|
||||
|
||||
class TestRuntimeLogStore:
    @pytest.mark.asyncio
    async def test_ensure_run_dir_creates_directory(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")
-        store.ensure_run_dir("test_run_1")
-        assert (tmp_path / "logs" / "runs" / "test_run_1").is_dir()
+        store.ensure_run_dir(_sid("test0001"))
+        assert (tmp_path / "logs" / "sessions" / _sid("test0001") / "logs").is_dir()

    @pytest.mark.asyncio
    async def test_append_and_load_details(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")
-        store.ensure_run_dir("test_run_2")
+        store.ensure_run_dir(_sid("test0002"))

        detail1 = NodeDetail(
            node_id="node-1",
@@ -56,10 +85,10 @@ class TestRuntimeLogStore:
            total_steps=1,
        )

-        store.append_node_detail("test_run_2", detail1)
-        store.append_node_detail("test_run_2", detail2)
+        store.append_node_detail(_sid("test0002"), detail1)
+        store.append_node_detail(_sid("test0002"), detail2)

-        loaded = await store.load_details("test_run_2")
+        loaded = await store.load_details(_sid("test0002"))
        assert loaded is not None
        assert len(loaded.nodes) == 2
        assert loaded.nodes[0].node_id == "node-1"
@@ -69,7 +98,7 @@ class TestRuntimeLogStore:
    @pytest.mark.asyncio
    async def test_append_and_load_tool_logs(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")
-        store.ensure_run_dir("test_run_3")
+        store.ensure_run_dir(_sid("test0003"))

        step = NodeStepLog(
            node_id="node-1",
@@ -91,9 +120,9 @@ class TestRuntimeLogStore:
            verdict="CONTINUE",
        )

-        store.append_step("test_run_3", step)
+        store.append_step(_sid("test0003"), step)

-        loaded = await store.load_tool_logs("test_run_3")
+        loaded = await store.load_tool_logs(_sid("test0003"))
        assert loaded is not None
        assert len(loaded.steps) == 1
        assert loaded.steps[0].tool_calls[0].tool_name == "web_search"
@@ -104,7 +133,7 @@ class TestRuntimeLogStore:
    async def test_save_and_load_summary(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")
        summary = RunSummaryLog(
-            run_id="test_run_1",
+            run_id=_sid("test0001"),
            agent_id="agent-a",
            goal_id="goal-1",
            status="success",
@@ -115,11 +144,11 @@ class TestRuntimeLogStore:
            execution_quality="clean",
        )

-        await store.save_summary("test_run_1", summary)
+        await store.save_summary(_sid("test0001"), summary)

-        loaded = await store.load_summary("test_run_1")
+        loaded = await store.load_summary(_sid("test0001"))
        assert loaded is not None
-        assert loaded.run_id == "test_run_1"
+        assert loaded.run_id == _sid("test0001")
        assert loaded.status == "success"
        assert loaded.total_nodes_executed == 3
        assert loaded.goal_id == "goal-1"
@@ -128,9 +157,9 @@ class TestRuntimeLogStore:
    @pytest.mark.asyncio
    async def test_load_missing_run_returns_none(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")
-        assert await store.load_summary("nonexistent") is None
-        assert await store.load_details("nonexistent") is None
-        assert await store.load_tool_logs("nonexistent") is None
+        assert await store.load_summary(_sid("missing00")) is None
+        assert await store.load_details(_sid("missing00")) is None
+        assert await store.load_tool_logs(_sid("missing00")) is None

    @pytest.mark.asyncio
    async def test_list_runs_empty(self, tmp_path: Path):
@@ -143,21 +172,21 @@ class TestRuntimeLogStore:
        store = RuntimeLogStore(tmp_path / "logs")

        # Save a success run
-        store.ensure_run_dir("run_ok")
+        store.ensure_run_dir(_sid("runok000"))
        await store.save_summary(
-            "run_ok",
+            _sid("runok000"),
            RunSummaryLog(
-                run_id="run_ok",
+                run_id=_sid("runok000"),
                status="success",
                started_at="2025-01-01T00:00:01",
            ),
        )
        # Save a failure run
-        store.ensure_run_dir("run_fail")
+        store.ensure_run_dir(_sid("runfail0"))
        await store.save_summary(
-            "run_fail",
+            _sid("runfail0"),
            RunSummaryLog(
-                run_id="run_fail",
+                run_id=_sid("runfail0"),
                status="failure",
                needs_attention=True,
                started_at="2025-01-01T00:00:02",
@@ -171,19 +200,19 @@ class TestRuntimeLogStore:
        # Filter by status
        success_runs = await store.list_runs(status="success")
        assert len(success_runs) == 1
-        assert success_runs[0].run_id == "run_ok"
+        assert success_runs[0].run_id == _sid("runok000")

        # Filter by needs_attention
        attention_runs = await store.list_runs(status="needs_attention")
        assert len(attention_runs) == 1
-        assert attention_runs[0].run_id == "run_fail"
+        assert attention_runs[0].run_id == _sid("runfail0")

    @pytest.mark.asyncio
    async def test_list_runs_sorted_by_timestamp_desc(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")

        for i in range(5):
-            run_id = f"run_{i}"
+            run_id = f"session_20250101_0000{i:02d}_run{i:04d}"
            store.ensure_run_dir(run_id)
            await store.save_summary(
                run_id,
@@ -196,15 +225,15 @@ class TestRuntimeLogStore:

        runs = await store.list_runs()
        # Most recent first
-        assert runs[0].run_id == "run_4"
-        assert runs[-1].run_id == "run_0"
+        assert runs[0].run_id == "session_20250101_000004_run0004"
+        assert runs[-1].run_id == "session_20250101_000000_run0000"

    @pytest.mark.asyncio
    async def test_list_runs_limit(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")

        for i in range(10):
-            run_id = f"run_{i}"
+            run_id = f"session_20250101_0000{i:02d}_run{i:04d}"
            store.ensure_run_dir(run_id)
            await store.save_summary(
                run_id,
@@ -224,45 +253,45 @@ class TestRuntimeLogStore:
        store = RuntimeLogStore(tmp_path / "logs")

        # Completed run with summary
-        store.ensure_run_dir("run_done")
+        store.ensure_run_dir(_sid("rundone0"))
        await store.save_summary(
-            "run_done",
+            _sid("rundone0"),
            RunSummaryLog(
-                run_id="run_done",
+                run_id=_sid("rundone0"),
                status="success",
                started_at="2025-01-01T00:00:01",
            ),
        )

        # In-progress run: directory exists but no summary.json
-        store.ensure_run_dir("run_active")
+        store.ensure_run_dir(_sid("runactiv0"))

        all_runs = await store.list_runs()
        assert len(all_runs) == 2
        run_ids = {r.run_id for r in all_runs}
-        assert "run_done" in run_ids
-        assert "run_active" in run_ids
+        assert _sid("rundone0") in run_ids
+        assert _sid("runactiv0") in run_ids

-        active = next(r for r in all_runs if r.run_id == "run_active")
+        active = next(r for r in all_runs if r.run_id == _sid("runactiv0"))
        assert active.status == "in_progress"

    @pytest.mark.asyncio
    async def test_read_node_details_sync(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")
-        store.ensure_run_dir("test_run")
+        store.ensure_run_dir(_sid("testsync0"))

        store.append_node_detail(
-            "test_run",
+            _sid("testsync0"),
            NodeDetail(
                node_id="n1", node_name="A", success=True, input_tokens=100, output_tokens=50
            ),
        )
        store.append_node_detail(
-            "test_run",
+            _sid("testsync0"),
            NodeDetail(node_id="n2", node_name="B", success=False, error="oops"),
        )

-        details = store.read_node_details_sync("test_run")
+        details = store.read_node_details_sync(_sid("testsync0"))
        assert len(details) == 2
        assert details[0].node_id == "n1"
        assert details[1].error == "oops"
@@ -271,15 +300,15 @@ class TestRuntimeLogStore:
    async def test_corrupt_jsonl_line_skipped(self, tmp_path: Path):
        """A corrupt JSONL line should be skipped without breaking reads."""
        store = RuntimeLogStore(tmp_path / "logs")
-        store.ensure_run_dir("test_run")
+        store.ensure_run_dir(_sid("corrupt00"))

        # Write a valid line, a corrupt line, then another valid line
-        jsonl_path = tmp_path / "logs" / "runs" / "test_run" / "details.jsonl"
+        jsonl_path = tmp_path / "logs" / "sessions" / _sid("corrupt00") / "logs" / "details.jsonl"
        valid1 = json.dumps(NodeDetail(node_id="n1", node_name="A", success=True).model_dump())
        valid2 = json.dumps(NodeDetail(node_id="n2", node_name="B", success=True).model_dump())
        jsonl_path.write_text(f"{valid1}\n{{corrupt line\n{valid2}\n")

-        details = store.read_node_details_sync("test_run")
+        details = store.read_node_details_sync(_sid("corrupt00"))
        assert len(details) == 2
        assert details[0].node_id == "n1"
        assert details[1].node_id == "n2"
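        # (A tolerant reader consistent with this test wraps each json.loads
        # in try/except and skips lines that raise JSONDecodeError; a sketch of
        # the expected behavior, not necessarily the store's exact code.)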
@@ -297,14 +326,14 @@ class TestRuntimeLogger:
        rl = RuntimeLogger(store=store, agent_id="test-agent")
        run_id = rl.start_run("goal-1")
        assert run_id
-        assert len(run_id) > 10  # timestamp + uuid
+        assert run_id.startswith("session_")

    @pytest.mark.asyncio
    async def test_start_run_creates_directory(self, tmp_path: Path):
        store = RuntimeLogStore(tmp_path / "logs")
        rl = RuntimeLogger(store=store, agent_id="test-agent")
        run_id = rl.start_run("goal-1")
-        assert (tmp_path / "logs" / "runs" / run_id).is_dir()
+        assert (tmp_path / "logs" / "sessions" / run_id / "logs").is_dir()

    @pytest.mark.asyncio
    async def test_log_step_writes_to_disk_immediately(self, tmp_path: Path):
@@ -322,7 +351,7 @@ class TestRuntimeLogger:
        )

        # Verify the file exists and has one line
-        jsonl_path = tmp_path / "logs" / "runs" / run_id / "tool_logs.jsonl"
+        jsonl_path = tmp_path / "logs" / "sessions" / run_id / "logs" / "tool_logs.jsonl"
        assert jsonl_path.exists()
        lines = [line for line in jsonl_path.read_text().strip().split("\n") if line]
        assert len(lines) == 1
@@ -345,7 +374,7 @@ class TestRuntimeLogger:
            exit_status="success",
        )

-        jsonl_path = tmp_path / "logs" / "runs" / run_id / "details.jsonl"
+        jsonl_path = tmp_path / "logs" / "sessions" / run_id / "logs" / "details.jsonl"
        assert jsonl_path.exists()
        lines = [line for line in jsonl_path.read_text().strip().split("\n") if line]
        assert len(lines) == 1
@@ -789,10 +818,10 @@ class TestRuntimeLogger:
        # Make the store path unwritable to force an error
        import os

-        bad_path = tmp_path / "logs" / "runs"
+        bad_path = tmp_path / "logs" / "sessions"
        bad_path.mkdir(parents=True, exist_ok=True)
        # Create a file where directory should be
-        run_dir = bad_path / rt_logger._run_id
+        run_dir = bad_path / rt_logger._run_id / "logs"
        run_dir.mkdir(parents=True, exist_ok=True)
        blocker = run_dir / "summary.json"
        blocker.write_text("not json")

File diff suppressed because it is too large
@@ -0,0 +1,693 @@
"""End-to-end test for subagent escalation via report_to_parent(wait_for_response=True).

Tests the FULL routing chain:
    ExecutionStream → GraphExecutor → EventLoopNode → _execute_subagent
    → _report_callback registers _EscalationReceiver in executor.node_registry
    → emit CLIENT_INPUT_REQUESTED with escalation_id
    → subscriber calls stream.inject_input(escalation_id, "done")
    → ExecutionStream finds _EscalationReceiver in executor.node_registry
    → receiver.inject_event("done") unblocks the subagent
    → subagent continues and completes
"""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from typing import Any

import pytest

from framework.graph import Goal, NodeSpec, SuccessCriterion
from framework.graph.edge import GraphSpec
from framework.llm.provider import LLMProvider, LLMResponse, Tool
from framework.llm.stream_events import (
    FinishEvent,
    StreamEvent,
    TextDeltaEvent,
    ToolCallEvent,
)
from framework.runtime.event_bus import AgentEvent, EventBus, EventType
from framework.runtime.execution_stream import EntryPointSpec, ExecutionStream
from framework.runtime.outcome_aggregator import OutcomeAggregator
from framework.runtime.shared_state import SharedStateManager
from framework.storage.concurrent import ConcurrentStorage

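# On the subscriber side the whole chain reduces to one call (a sketch; the
# real handler is defined inside the first test below):
#
#     async def on_event(event: AgentEvent) -> None:
#         if event.type == EventType.CLIENT_INPUT_REQUESTED:
#             await stream.inject_input(event.node_id, "done")
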
# ---------------------------------------------------------------------------
# Sequenced mock LLM — returns different responses per call index
# ---------------------------------------------------------------------------


class SequencedLLM(LLMProvider):
    """Mock LLM that returns pre-programmed stream events per call.

    Each call to stream() pops the next scenario from the queue.
    Shared between parent and subagent (they use the same LLM instance).
    """

    def __init__(self, scenarios: list[list[StreamEvent]]):
        self._scenarios = list(scenarios)
        self._call_index = 0
        self.stream_calls: list[dict] = []

    async def stream(
        self,
        messages: list[dict[str, Any]],
        system: str = "",
        tools: list[Tool] | None = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        self.stream_calls.append(
            {
                "index": self._call_index,
                "system": system[:200],
                "tool_names": [t.name for t in (tools or [])],
            }
        )
        if self._call_index < len(self._scenarios):
            events = self._scenarios[self._call_index]
        else:
            # Fallback: just finish
            events = [
                TextDeltaEvent(content="Done.", snapshot="Done."),
                FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5),
            ]
        self._call_index += 1
        for event in events:
            yield event

    def complete(self, messages, system="", **kwargs) -> LLMResponse:
        return LLMResponse(content="Summary.", model="mock", stop_reason="stop")

    def complete_with_tools(self, messages, system, tools, tool_executor, **kwargs) -> LLMResponse:
        return LLMResponse(content="", model="mock", stop_reason="stop")

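# Usage sketch: each inner list is one complete stream() response, e.g.
#     SequencedLLM([[TextDeltaEvent(content="hi", snapshot="hi"),
#                    FinishEvent(stop_reason="end_turn", input_tokens=1, output_tokens=1)]])
# answers exactly one call, then falls back to the "Done." scenario.
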
# ---------------------------------------------------------------------------
# Test
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_escalation_e2e_through_execution_stream(tmp_path):
    """Full e2e: subagent escalation routed through ExecutionStream.inject_input().

    Scenario:
    1. Parent node delegates to "researcher" subagent
    2. Researcher calls report_to_parent(wait_for_response=True, message="Login required")
    3. A subscriber on CLIENT_INPUT_REQUESTED gets the escalation_id
    4. Subscriber calls stream.inject_input(escalation_id, "done logging in")
    5. Subagent unblocks, sets output, completes
    6. Parent receives subagent result, sets its own output, completes
    """

    # -- Graph setup --
    goal = Goal(
        id="escalation-test",
        name="Escalation Test",
        description="Test subagent escalation flow",
        success_criteria=[
            SuccessCriterion(
                id="result",
                description="Result present",
                metric="output_contains",
                target="result",
            )
        ],
        constraints=[],
    )

    parent_node = NodeSpec(
        id="parent",
        name="Parent",
        description="Parent that delegates to researcher",
        node_type="event_loop",
        input_keys=["query"],
        output_keys=["result"],
        sub_agents=["researcher"],
        system_prompt="You delegate research tasks to the researcher sub-agent.",
    )

    researcher_node = NodeSpec(
        id="researcher",
        name="Researcher",
        description="Researches by browsing, may need user help for login",
        node_type="event_loop",
        input_keys=["task"],
        output_keys=["findings"],
        system_prompt="You research topics. If you hit a login wall, ask for help.",
    )

    graph = GraphSpec(
        id="escalation-graph",
        goal_id=goal.id,
        version="1.0.0",
        entry_node="parent",
        entry_points={"start": "parent"},
        terminal_nodes=["parent"],
        pause_nodes=[],
        nodes=[parent_node, researcher_node],
        edges=[],
        default_model="mock",
        max_tokens=10,
    )

    # -- LLM scenarios --
    # The LLM is shared between parent and subagent. Calls happen in order:
    #
    # Call 0 (parent turn 1): delegate to researcher
    # Call 1 (subagent turn 1): report_to_parent(wait_for_response=True)
    #   → blocks here until inject_input()
    # Call 2 (subagent turn 2): set_output("findings", "...")
    # Call 3 (subagent turn 3): text finish (implicit judge accepts after output filled)
    # Call 4 (parent turn 2): set_output("result", "...")
    # Call 5 (parent turn 3): text finish

    scenarios: list[list[StreamEvent]] = [
        # Call 0: Parent delegates
        [
            ToolCallEvent(
                tool_name="delegate_to_sub_agent",
                tool_input={"agent_id": "researcher", "task": "Check LinkedIn profiles"},
                tool_use_id="delegate_1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Call 1: Subagent hits login wall, escalates
        [
            ToolCallEvent(
                tool_name="report_to_parent",
                tool_input={
                    "message": "Login required for LinkedIn. Please log in manually.",
                    "wait_for_response": True,
                },
                tool_use_id="report_1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Call 2: Subagent continues after user login, sets output
        [
            ToolCallEvent(
                tool_name="set_output",
                tool_input={"key": "findings", "value": "Profile data extracted after login"},
                tool_use_id="set_1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Call 3: Subagent finishes
        [
            TextDeltaEvent(content="Research complete.", snapshot="Research complete."),
            FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5, model="mock"),
        ],
        # Call 4: Parent uses subagent result
        [
            ToolCallEvent(
                tool_name="set_output",
                tool_input={"key": "result", "value": "LinkedIn profile data retrieved"},
                tool_use_id="set_2",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Call 5: Parent finishes
        [
            TextDeltaEvent(content="Task complete.", snapshot="Task complete."),
            FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5, model="mock"),
        ],
    ]

    llm = SequencedLLM(scenarios)

    # -- Event bus + subscriber that auto-responds to escalation --
    bus = EventBus()
    escalation_events: list[AgentEvent] = []
    all_events: list[AgentEvent] = []
    inject_called = asyncio.Event()

    # We need the stream reference for inject_input, so use a holder
    stream_holder: list[ExecutionStream] = []

    async def escalation_handler(event: AgentEvent):
        """Simulate a TUI/runner: when CLIENT_INPUT_REQUESTED arrives with
        an escalation node_id, inject the user's response via the stream."""
        all_events.append(event)
        if event.type == EventType.CLIENT_INPUT_REQUESTED:
            node_id = event.node_id
            if ":escalation:" in node_id:
                escalation_events.append(event)
                # Small delay to simulate user typing
                await asyncio.sleep(0.05)
                # Route through the REAL inject_input chain
                stream = stream_holder[0]
                success = await stream.inject_input(node_id, "done logging in")
                assert success, (
                    f"inject_input({node_id!r}) returned False — "
                    "escalation receiver not found in executor.node_registry"
                )
                inject_called.set()

    bus.subscribe(
        event_types=[EventType.CLIENT_INPUT_REQUESTED, EventType.CLIENT_OUTPUT_DELTA],
        handler=escalation_handler,
    )

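    # (The one-element holder list works because the closure captures the list
    # object itself; it is populated once the ExecutionStream is built below.)
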
    # -- Build and run ExecutionStream --
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    stream = ExecutionStream(
        stream_id="start",
        entry_spec=EntryPointSpec(
            id="start",
            name="Start",
            entry_node="parent",
            trigger_type="manual",
            isolation_level="shared",
        ),
        graph=graph,
        goal=goal,
        state_manager=SharedStateManager(),
        storage=storage,
        outcome_aggregator=OutcomeAggregator(goal, bus),
        event_bus=bus,
        llm=llm,
        tools=[],
        tool_executor=None,
    )
    stream_holder.append(stream)

    await stream.start()

    # Execute
    execution_id = await stream.execute({"query": "Find LinkedIn profiles"})
    result = await stream.wait_for_completion(execution_id, timeout=15)

    await stream.stop()
    await storage.stop()

    # -- Assertions --

    # 1. Execution completed successfully
    assert result is not None, "Execution should have completed"
    assert result.success, f"Execution should have succeeded, got: {result}"

    # 2. Escalation event was received and routed
    assert inject_called.is_set(), "inject_input should have been called for escalation"
    assert len(escalation_events) >= 1, "Should have received at least one escalation event"

    # 3. Escalation event has correct structure
    esc_event = escalation_events[0]
    assert ":escalation:" in esc_event.node_id
    assert esc_event.data["prompt"] == "Login required for LinkedIn. Please log in manually."

    # 4. CLIENT_OUTPUT_DELTA was emitted for the escalation message
    output_deltas = [
        e
        for e in all_events
        if e.type == EventType.CLIENT_OUTPUT_DELTA and "Login required" in e.data.get("content", "")
    ]
    assert len(output_deltas) >= 1, (
        "Should have emitted CLIENT_OUTPUT_DELTA with escalation message"
    )

    # 5. The parent node got the subagent's result
    assert "result" in result.output
    assert result.output["result"] == "LinkedIn profile data retrieved"

    # 6. The LLM was called the expected number of times
    assert llm._call_index >= 4, (
        f"Expected at least 4 LLM calls (delegate + escalation + set_output + finish), "
        f"got {llm._call_index}"
    )

    # 7. The user's escalation response appeared in the subagent's conversation
    # Call index 2 should be the subagent's second turn (after receiving "done logging in")
    assert len(llm.stream_calls) >= 3
    # The subagent's first call (stream_calls[1]) should have report_to_parent
    # in its tools (verifying the subagent got the right tool set)
    subagent_tools = llm.stream_calls[1]["tool_names"]
    assert "report_to_parent" in subagent_tools, (
        f"Subagent should have report_to_parent tool, got: {subagent_tools}"
    )


@pytest.mark.asyncio
async def test_escalation_cleanup_after_completion(tmp_path):
    """Verify that _EscalationReceiver is cleaned up from the registry after use.

    After the escalation flow completes, no escalation receivers should remain
    in the executor's node_registry.
    """
    from framework.graph.event_loop_node import _EscalationReceiver

    goal = Goal(
        id="cleanup-test",
        name="Cleanup Test",
        description="Test escalation cleanup",
        success_criteria=[
            SuccessCriterion(
                id="result",
                description="Result present",
                metric="output_contains",
                target="result",
            )
        ],
        constraints=[],
    )

    parent_node = NodeSpec(
        id="parent",
        name="Parent",
        description="Delegates to researcher",
        node_type="event_loop",
        input_keys=["query"],
        output_keys=["result"],
        sub_agents=["researcher"],
    )

    researcher_node = NodeSpec(
        id="researcher",
        name="Researcher",
        description="Researches topics",
        node_type="event_loop",
        input_keys=["task"],
        output_keys=["findings"],
    )

    graph = GraphSpec(
        id="cleanup-graph",
        goal_id=goal.id,
        version="1.0.0",
        entry_node="parent",
        entry_points={"start": "parent"},
        terminal_nodes=["parent"],
        pause_nodes=[],
        nodes=[parent_node, researcher_node],
        edges=[],
        default_model="mock",
        max_tokens=10,
    )

    scenarios = [
        # Parent delegates
        [
            ToolCallEvent(
                tool_name="delegate_to_sub_agent",
                tool_input={"agent_id": "researcher", "task": "Check page"},
                tool_use_id="d1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Subagent escalates
        [
            ToolCallEvent(
                tool_name="report_to_parent",
                tool_input={"message": "Need help", "wait_for_response": True},
                tool_use_id="r1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Subagent sets output
        [
            ToolCallEvent(
                tool_name="set_output",
                tool_input={"key": "findings", "value": "Done"},
                tool_use_id="s1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Subagent finish
        [
            TextDeltaEvent(content="Done.", snapshot="Done."),
            FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5, model="mock"),
        ],
        # Parent sets output
        [
            ToolCallEvent(
                tool_name="set_output",
                tool_input={"key": "result", "value": "Got it"},
                tool_use_id="s2",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Parent finish
        [
            TextDeltaEvent(content="Complete.", snapshot="Complete."),
            FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5, model="mock"),
        ],
    ]

    llm = SequencedLLM(scenarios)
    bus = EventBus()

    # Track node_registry contents via the executor
    registries_snapshot: list[dict] = []
    stream_holder: list[ExecutionStream] = []

    async def auto_respond(event: AgentEvent):
        if event.type == EventType.CLIENT_INPUT_REQUESTED and ":escalation:" in event.node_id:
            stream = stream_holder[0]

            # Snapshot the active executor's node_registry BEFORE responding
            for executor in stream._active_executors.values():
                escalation_keys = [k for k in executor.node_registry if ":escalation:" in k]
                registries_snapshot.append(
                    {
                        "phase": "before_inject",
                        "escalation_keys": escalation_keys,
                        "has_receiver": any(
                            isinstance(v, _EscalationReceiver)
                            for v in executor.node_registry.values()
                        ),
                    }
                )

            await asyncio.sleep(0.02)
            await stream.inject_input(event.node_id, "ok")

    bus.subscribe(
        event_types=[EventType.CLIENT_INPUT_REQUESTED],
        handler=auto_respond,
    )

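    # Taking the snapshot before inject_input shows the receiver is registered
    # exactly while the subagent is blocked waiting for the response.
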
    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    stream = ExecutionStream(
        stream_id="start",
        entry_spec=EntryPointSpec(
            id="start",
            name="Start",
            entry_node="parent",
            trigger_type="manual",
            isolation_level="shared",
        ),
        graph=graph,
        goal=goal,
        state_manager=SharedStateManager(),
        storage=storage,
        outcome_aggregator=OutcomeAggregator(goal, bus),
        event_bus=bus,
        llm=llm,
        tools=[],
        tool_executor=None,
    )
    stream_holder.append(stream)

    await stream.start()
    execution_id = await stream.execute({"query": "test"})
    result = await stream.wait_for_completion(execution_id, timeout=15)
    await stream.stop()
    await storage.stop()

    assert result is not None and result.success

    # The receiver WAS in the registry during escalation
    assert len(registries_snapshot) >= 1
    assert registries_snapshot[0]["has_receiver"] is True
    assert len(registries_snapshot[0]["escalation_keys"]) == 1

    # After completion, no active executors remain (they're cleaned up),
    # so no stale receivers can linger. The `finally` block in the callback
    # guarantees cleanup even within a single execution.


# ---------------------------------------------------------------------------
# Test: mark_complete e2e through ExecutionStream
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_mark_complete_e2e_through_execution_stream(tmp_path):
    """Full e2e: subagent uses report_to_parent(mark_complete=True) to terminate.

    Scenario:
    1. Parent delegates to "researcher" subagent
    2. Researcher calls report_to_parent(mark_complete=True, message="Found profiles", data={...})
    3. Subagent terminates immediately (no set_output needed)
    4. Parent receives subagent result with reports, sets its own output, completes
    """

    goal = Goal(
        id="mark-complete-test",
        name="Mark Complete Test",
        description="Test mark_complete subagent flow",
        success_criteria=[
            SuccessCriterion(
                id="result",
                description="Result present",
                metric="output_contains",
                target="result",
            )
        ],
        constraints=[],
    )

    parent_node = NodeSpec(
        id="parent",
        name="Parent",
        description="Parent that delegates to researcher",
        node_type="event_loop",
        input_keys=["query"],
        output_keys=["result"],
        sub_agents=["researcher"],
        system_prompt="You delegate research tasks to the researcher sub-agent.",
    )

    researcher_node = NodeSpec(
        id="researcher",
        name="Researcher",
        description="Researches topics and reports findings",
        node_type="event_loop",
        input_keys=["task"],
        output_keys=["findings"],
        system_prompt="You research topics. Use report_to_parent with mark_complete when done.",
    )

    graph = GraphSpec(
        id="mark-complete-graph",
        goal_id=goal.id,
        version="1.0.0",
        entry_node="parent",
        entry_points={"start": "parent"},
        terminal_nodes=["parent"],
        pause_nodes=[],
        nodes=[parent_node, researcher_node],
        edges=[],
        default_model="mock",
        max_tokens=10,
    )

    # LLM call sequence:
    # Call 0 (parent turn 1): delegate to researcher
    # Call 1 (subagent turn 1): report_to_parent(mark_complete=True) → sets flag
    # Call 2 (subagent turn 2): text finish (inner loop exit) → _evaluate sees flag → ACCEPT
    # Call 3 (parent turn 2): set_output("result", "...")
    # Call 4 (parent turn 3): text finish
    scenarios: list[list[StreamEvent]] = [
        # Call 0: Parent delegates
        [
            ToolCallEvent(
                tool_name="delegate_to_sub_agent",
                tool_input={"agent_id": "researcher", "task": "Find LinkedIn profiles"},
                tool_use_id="delegate_1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Call 1: Subagent reports with mark_complete=True
        [
            ToolCallEvent(
                tool_name="report_to_parent",
                tool_input={
                    "message": "Found 3 matching profiles",
                    "data": {"profiles": ["alice", "bob", "carol"]},
                    "mark_complete": True,
                },
                tool_use_id="report_1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Call 2: Subagent text finish (inner loop needs this to exit)
        [
            TextDeltaEvent(content="Done.", snapshot="Done."),
            FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5, model="mock"),
        ],
        # Call 3: Parent uses subagent result to set output
        [
            ToolCallEvent(
                tool_name="set_output",
                tool_input={"key": "result", "value": "Found 3 profiles: alice, bob, carol"},
                tool_use_id="set_1",
            ),
            FinishEvent(stop_reason="tool_use", input_tokens=10, output_tokens=5, model="mock"),
        ],
        # Call 4: Parent finishes
        [
            TextDeltaEvent(content="Task complete.", snapshot="Task complete."),
            FinishEvent(stop_reason="end_turn", input_tokens=5, output_tokens=5, model="mock"),
        ],
    ]

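    # Unlike the wait_for_response flow above, no CLIENT_INPUT_REQUESTED
    # subscriber is needed: mark_complete ends the subagent without blocking
    # on user input.
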
    llm = SequencedLLM(scenarios)
    bus = EventBus()

    # Track subagent report events
    report_events: list[AgentEvent] = []

    async def report_handler(event: AgentEvent):
        if event.type == EventType.SUBAGENT_REPORT:
            report_events.append(event)

    bus.subscribe(event_types=[EventType.SUBAGENT_REPORT], handler=report_handler)

    storage = ConcurrentStorage(tmp_path)
    await storage.start()

    stream = ExecutionStream(
        stream_id="start",
        entry_spec=EntryPointSpec(
            id="start",
            name="Start",
            entry_node="parent",
            trigger_type="manual",
            isolation_level="shared",
        ),
        graph=graph,
        goal=goal,
        state_manager=SharedStateManager(),
        storage=storage,
        outcome_aggregator=OutcomeAggregator(goal, bus),
        event_bus=bus,
        llm=llm,
        tools=[],
        tool_executor=None,
    )

    await stream.start()
    execution_id = await stream.execute({"query": "Find LinkedIn profiles"})
    result = await stream.wait_for_completion(execution_id, timeout=15)
    await stream.stop()
    await storage.stop()

    # -- Assertions --

    # 1. Execution completed successfully
    assert result is not None, "Execution should have completed"
    assert result.success, f"Execution should have succeeded, got: {result}"

    # 2. Parent got the final output
    assert "result" in result.output
    assert "3 profiles" in result.output["result"]

    # 3. Subagent report was emitted via event bus
    # (The subagent's EventLoopNode has event_bus=None, but _execute_subagent
    # wires its own callback that emits via the parent's bus)
    assert len(report_events) >= 1, "Should have received subagent report event"
    assert report_events[0].data["message"] == "Found 3 matching profiles"

    # 4. The subagent did NOT need to call set_output — it used mark_complete
    # Verify by checking LLM call count: subagent only needed 2 calls
    # (report_to_parent + text finish), not 3+ (report + set_output + text finish)
    assert llm._call_index == 5, (
        f"Expected 5 LLM calls total (delegate + report + finish + set_output + finish), "
        f"got {llm._call_index}"
    )
Some files were not shown because too many files have changed in this diff