diff --git a/.claude/settings.json b/.claude/settings.json index 4736a78d..e526549f 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -30,7 +30,11 @@ "Bash(grep -E \"\\\\.\\(ts|tsx\\)$\")", "Bash(xargs cat:*)", "Bash(find /home/timothy/aden/hive -path \"*/.venv\" -prune -o -name \"*.py\" -type f -exec grep -l \"frontend\\\\|UI\\\\|terminal\\\\|interactive\\\\|TUI\" {} \\\\;)", - "Bash(wc -l /home/timothy/.hive/backup/*/SKILL.md)" + "Bash(wc -l /home/timothy/.hive/backup/*/SKILL.md)", + "Bash(awk -F'::' '{print $1}')", + "Bash(wait)", + "Bash(pkill -f \"pytest.*test_event_loop_node\")", + "Bash(pkill -f \"pytest.*TestToolConcurrency\")" ], "additionalDirectories": [ "/home/timothy/.hive/skills/writing-hive-skills", diff --git a/core/framework/agent_loop/agent_loop.py b/core/framework/agent_loop/agent_loop.py index 6b588f9e..6e6a0069 100644 --- a/core/framework/agent_loop/agent_loop.py +++ b/core/framework/agent_loop/agent_loop.py @@ -586,6 +586,10 @@ class AgentLoop(AgentProtocol): output_keys=ctx.agent_spec.output_keys or None, store=self._conversation_store, run_id=ctx.effective_run_id, + compaction_buffer_tokens=self._config.compaction_buffer_tokens, + compaction_warning_buffer_tokens=( + self._config.compaction_warning_buffer_tokens + ), ) accumulator = OutputAccumulator( store=self._conversation_store, @@ -2411,6 +2415,33 @@ class AgentLoop(AgentProtocol): tool_calls: list[ToolCallEvent] = [] _stream_error: StreamErrorEvent | None = None + # Gap 1 - Streaming tool execution. Any tool flagged as + # concurrency_safe is kicked off the moment its ToolCallEvent + # arrives in the stream, instead of waiting for the full + # assistant message stop event. The dispatch phase below + # reuses these already-running tasks so read_file / grep / + # glob overlap with whatever text the model is still + # generating. Unsafe tools (bash, edits, browser actions) + # still wait for FinishEvent so we don't race a write + # against a decision the model hasn't finished making. + _early_safe_names = { + t.name for t in tools if getattr(t, "concurrency_safe", False) + } + _early_tasks: dict[str, asyncio.Task] = {} + + async def _timed_execute( + _tc: ToolCallEvent, + ) -> tuple[ToolResult | BaseException, str, float]: + """Execute a tool and return (result, start_iso, duration_s).""" + _s = time.time() + _iso = datetime.now(UTC).isoformat() + try: + _r = await self._execute_tool(_tc) + except BaseException as _exc: + _r = _exc + _dur = round(time.time() - _s, 3) + return _r, _iso, _dur + logger.debug( "[_run_single_turn] inner_turn=%d: Starting LLM stream with %d messages, %d tools", inner_turn, @@ -2447,6 +2478,9 @@ class AgentLoop(AgentProtocol): _msgs: list = messages, # noqa: B006 _tc: list[ToolCallEvent] = tool_calls, # noqa: B006 inner_turn: int = inner_turn, + _safe_names: set = _early_safe_names, # noqa: B006,B008 + _tasks: dict = _early_tasks, # noqa: B006,B008 + _exec_fn=_timed_execute, ) -> None: nonlocal accumulated_text, _stream_error, _stream_last_event_at _clean_snapshot = "" # visible-only text for the frontend @@ -2480,6 +2514,18 @@ class AgentLoop(AgentProtocol): elif isinstance(event, ToolCallEvent): _tc.append(event) + # Gap 1: start concurrency-safe tools immediately + # while the rest of the stream is still arriving, + # so read-heavy turns don't stall after the last + # text delta. Unsafe tools wait for FinishEvent. 
+ if ( + event.tool_name in _safe_names + and "_raw" not in event.tool_input + and event.tool_use_id not in _tasks + ): + _tasks[event.tool_use_id] = asyncio.create_task( + _exec_fn(event) + ) elif isinstance(event, FinishEvent): token_counts["input"] += event.input_tokens @@ -2552,6 +2598,12 @@ class AgentLoop(AgentProtocol): logger.debug("[_run_single_turn] inner_turn=%d: Stream task cancelled", inner_turn) if accumulated_text: await conversation.add_assistant_message(content=accumulated_text) + # Gap 1: kill any early-dispatched tool tasks too. + # Without this, a safe tool started during streaming + # would leak past cancellation and keep running. + for _early in _early_tasks.values(): + if not _early.done(): + _early.cancel() # Distinguish cancel_current_turn() (cancels the child # _stream_task) from stop_worker (cancels the parent # execution task). When the parent itself is cancelled, @@ -2566,6 +2618,12 @@ class AgentLoop(AgentProtocol): logger.exception( "[_run_single_turn] inner_turn=%d: Stream task failed: %s", inner_turn, e ) + # Don't orphan early tool tasks on a stream failure + # either - the outer retry loop will re-emit the tool + # calls on the next attempt. + for _early in _early_tasks.values(): + if not _early.done(): + _early.cancel() raise finally: self._stream_task = None @@ -2575,6 +2633,9 @@ class AgentLoop(AgentProtocol): # raise so the outer transient-error retry can handle it # with proper backoff instead of burning judge iterations. if _stream_error and not accumulated_text and not tool_calls: + for _early in _early_tasks.values(): + if not _early.done(): + _early.cancel() raise ConnectionError( f"Stream failed with recoverable error: {_stream_error.error}" ) @@ -2945,39 +3006,116 @@ class AgentLoop(AgentProtocol): else: pending_real.append(tc) - # Phase 2a: execute real tools in parallel. + # Phase 2a: partition real tools by concurrency safety. + # Read-only tools flagged concurrency_safe run in one parallel + # batch (bounded by a semaphore). Everything else - shell, file + # writes, browser actions, unknown MCP tools - runs serially + # afterwards so we can't race an edit against a bash command + # that touches the same path. Result ordering is preserved via + # results_by_id below; the split only affects scheduling. + # Reuses the same _early_safe_names set the stream used for + # Gap 1 early dispatch, so "safe" means exactly the same + # thing in both places. + parallel_batch: list[ToolCallEvent] = [] + serial_batch: list[ToolCallEvent] = [] + for tc in pending_real: + if tc.tool_name in _early_safe_names: + parallel_batch.append(tc) + else: + serial_batch.append(tc) + if pending_real: + # Cap on concurrent read-only tool executions. Ten matches + # Claude Code's StreamingToolExecutor default and keeps MCP + # server load bounded on turns where the model issues a + # big fan-out of reads. 
+ _PARALLEL_CAP = 10 + _parallel_sem = asyncio.Semaphore(_PARALLEL_CAP) - async def _timed_execute( + async def _capped( _tc: ToolCallEvent, + _sem: asyncio.Semaphore = _parallel_sem, # noqa: B008,B023 ) -> tuple[ToolResult | BaseException, str, float]: - """Execute a tool and return (result, start_iso, duration_s).""" - _s = time.time() - _iso = datetime.now(UTC).isoformat() - try: - _r = await self._execute_tool(_tc) - except BaseException as _exc: - _r = _exc - _dur = round(time.time() - _s, 3) - return _r, _iso, _dur + async with _sem: + return await _timed_execute(_tc) - self._tool_task = asyncio.ensure_future( - asyncio.gather( - *(_timed_execute(tc) for tc in pending_real), - return_exceptions=True, + timed_results_by_id: dict[ + str, tuple[ToolResult | BaseException, str, float] | BaseException + ] = {} + + # Phase 2b: resolve the concurrency-safe batch. Prefer + # any early task already started during streaming (Gap + # 1) so we don't accidentally execute the same tool + # twice; for everything else, schedule via the semaphore- + # capped wrapper as before. + if parallel_batch: + _awaitables: list = [] + for tc in parallel_batch: + early = _early_tasks.get(tc.tool_use_id) + if early is not None: + _awaitables.append(early) + else: + _awaitables.append(_capped(tc)) + self._tool_task = asyncio.ensure_future( + asyncio.gather(*_awaitables, return_exceptions=True) ) - ) - try: - timed_results = await self._tool_task - finally: - self._tool_task = None - # gather(return_exceptions=True) captures CancelledError - # as a return value instead of propagating it. Re-raise - # so stop_worker actually stops the execution. - for entry in timed_results: - if isinstance(entry, asyncio.CancelledError): - raise entry - for tc, entry in zip(pending_real, timed_results, strict=True): + try: + parallel_timed = await self._tool_task + finally: + self._tool_task = None + # gather(return_exceptions=True) captures CancelledError + # as a return value instead of propagating it. Re-raise + # so stop_worker actually stops the execution. + for entry in parallel_timed: + if isinstance(entry, asyncio.CancelledError): + raise entry + for tc, entry in zip(parallel_batch, parallel_timed, strict=True): + timed_results_by_id[tc.tool_use_id] = entry + + # Phase 2c: run unsafe tools sequentially. On a raised + # exception, cancel the remaining siblings with a clear + # error so the model sees the cascade instead of a silent + # drop. A ToolResult with is_error=True is a normal return + # (e.g. "file not found") and does NOT trip the cascade - + # the model should see subsequent errors too. + _serial_cascade_broken = False + for tc in serial_batch: + if _serial_cascade_broken: + timed_results_by_id[tc.tool_use_id] = ( + ToolResult( + tool_use_id=tc.tool_use_id, + content=( + "Cancelled: an earlier non-concurrent tool " + "in this turn raised an exception. Re-issue " + "this call once the previous error is resolved." 
+ ), + is_error=True, + ), + datetime.now(UTC).isoformat(), + 0.0, + ) + continue + + self._tool_task = asyncio.ensure_future(_timed_execute(tc)) + try: + entry = await self._tool_task + finally: + self._tool_task = None + + timed_results_by_id[tc.tool_use_id] = entry + raw_check = entry[0] if isinstance(entry, tuple) else entry + if isinstance(raw_check, BaseException) and not isinstance( + raw_check, asyncio.CancelledError + ): + _serial_cascade_broken = True + elif isinstance(raw_check, asyncio.CancelledError): + raise raw_check + + # Phase 2d: reassemble results in original call order so + # the rest of the loop sees no difference from the + # pre-partition world. + for tc in pending_real: + entry = timed_results_by_id[tc.tool_use_id] if isinstance(entry, BaseException): raw = entry _start_iso = datetime.now(UTC).isoformat() diff --git a/core/framework/agent_loop/conversation.py b/core/framework/agent_loop/conversation.py index 39fb3314..6240f328 100644 --- a/core/framework/agent_loop/conversation.py +++ b/core/framework/agent_loop/conversation.py @@ -381,10 +381,20 @@ class NodeConversation: output_keys: list[str] | None = None, store: ConversationStore | None = None, run_id: str | None = None, + compaction_buffer_tokens: int | None = None, + compaction_warning_buffer_tokens: int | None = None, ) -> None: self._system_prompt = system_prompt self._max_context_tokens = max_context_tokens self._compaction_threshold = compaction_threshold + # Buffer-based compaction trigger (Gap 7). When set, takes + # precedence over the multiplicative compaction_threshold so the + # loop reserves a fixed headroom for the next turn's input+output + # instead of trying to get exactly X% of the way to the hard + # limit. If left as None the legacy threshold-based rule is + # used, keeping old call sites behaving identically. + self._compaction_buffer_tokens = compaction_buffer_tokens + self._compaction_warning_buffer_tokens = compaction_warning_buffer_tokens self._output_keys = output_keys self._store = store self._messages: list[Message] = [] @@ -729,8 +739,37 @@ class NodeConversation: return self.estimate_tokens() / self._max_context_tokens def needs_compaction(self) -> bool: + """True when the conversation should be compacted before the + next LLM call. + + Buffer-based rule (Gap 7): trigger when the current estimate + plus the configured buffer would exceed the hard context limit. + Prevents compaction from firing only AFTER we're already over + the wire and forced into a reactive binary-split pass. + + When no buffer is configured, falls back to the multiplicative + threshold the old callers were built around. + """ + if self._max_context_tokens <= 0: + return False + if self._compaction_buffer_tokens is not None: + budget = self._max_context_tokens - self._compaction_buffer_tokens + return self.estimate_tokens() >= max(0, budget) return self.estimate_tokens() >= self._max_context_tokens * self._compaction_threshold + def compaction_warning(self) -> bool: + """True when the conversation has crossed the warning threshold + but not yet the hard compaction trigger. + + Used by telemetry / UI to show a "context getting tight" hint + before a compaction pass actually runs. Returns False when no + warning buffer is configured (legacy behaviour). 
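+
+        Example (same numbers as the accompanying unit test): with
+        max_context_tokens=1000, compaction_warning_buffer_tokens=400 and
+        compaction_buffer_tokens=200, the warning fires once the estimate
+        reaches 600 tokens, while needs_compaction() stays False until the
+        estimate reaches 800.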
+ """ + if self._max_context_tokens <= 0 or self._compaction_warning_buffer_tokens is None: + return False + warn_at = self._max_context_tokens - self._compaction_warning_buffer_tokens + return self.estimate_tokens() >= max(0, warn_at) + # --- Output-key extraction --------------------------------------------- def _extract_protected_values(self, messages: list[Message]) -> dict[str, str]: @@ -1264,6 +1303,10 @@ class NodeConversation: "system_prompt": self._system_prompt, "max_context_tokens": self._max_context_tokens, "compaction_threshold": self._compaction_threshold, + "compaction_buffer_tokens": self._compaction_buffer_tokens, + "compaction_warning_buffer_tokens": ( + self._compaction_warning_buffer_tokens + ), "output_keys": self._output_keys, } await self._store.write_meta(run_meta) @@ -1311,6 +1354,10 @@ class NodeConversation: output_keys=meta.get("output_keys"), store=store, run_id=run_id, + compaction_buffer_tokens=meta.get("compaction_buffer_tokens"), + compaction_warning_buffer_tokens=meta.get( + "compaction_warning_buffer_tokens" + ), ) conv._meta_persisted = True diff --git a/core/framework/agent_loop/internals/types.py b/core/framework/agent_loop/internals/types.py index 0c7633e8..cf9dcc05 100644 --- a/core/framework/agent_loop/internals/types.py +++ b/core/framework/agent_loop/internals/types.py @@ -54,6 +54,17 @@ class LoopConfig: stall_detection_threshold: int = 3 stall_similarity_threshold: float = 0.85 max_context_tokens: int = 32_000 + # Headroom reserved for the NEXT turn's input + output so that + # proactive compaction always finishes before the hard context limit + # is hit mid-stream. Scaled to match Claude Code's 13k-buffer-on- + # 200k-window ratio (~6.5%) applied to hive's default 32k window, + # with extra margin because hive's token estimator is char-based + # and less tight than Anthropic's own counting. Override via + # LoopConfig for larger windows. + compaction_buffer_tokens: int = 8_000 + # Warning is emitted one buffer earlier so the user/telemetry gets + # a "we're close" signal without triggering a compaction pass. + compaction_warning_buffer_tokens: int = 12_000 store_prefix: str = "" # Overflow margin for max_tool_calls_per_turn. Tool calls are only diff --git a/core/framework/llm/provider.py b/core/framework/llm/provider.py index fcdaa61d..853b6d8b 100644 --- a/core/framework/llm/provider.py +++ b/core/framework/llm/provider.py @@ -27,6 +27,12 @@ class Tool: name: str description: str parameters: dict[str, Any] = field(default_factory=dict) + # If True, this tool performs no filesystem/process/network writes and is + # safe to run concurrently with other safe-flagged tools inside the same + # assistant turn. Unsafe tools (writes, shell, browser actions) are always + # serialized after the safe batch. Default False - the conservative choice + # when a tool's behavior isn't explicitly vetted. + concurrency_safe: bool = False @dataclass diff --git a/core/framework/loader/tool_registry.py b/core/framework/loader/tool_registry.py index 2dd7efbf..e7434912 100644 --- a/core/framework/loader/tool_registry.py +++ b/core/framework/loader/tool_registry.py @@ -50,6 +50,33 @@ class ToolRegistry: # and auto-injected at call time for tools that accept them. CONTEXT_PARAMS = frozenset({"agent_id", "data_dir", "profile"}) + # Tools that perform no filesystem/process/network writes and are safe + # to run concurrently with other safe tools in the same assistant turn. 
+ # Unknown tools default to unsafe (serialized) - adding a name here is + # an explicit promise about that tool's side effects. Keep this list + # conservative: anything that mutates state, writes to disk, issues + # POST/PUT/DELETE requests, or drives a browser MUST NOT be listed. + CONCURRENCY_SAFE_TOOLS = frozenset( + { + # File system reads + "read_file", + "list_directory", + "grep", + "glob", + # Web reads + "web_search", + "web_fetch", + # Browser read-only snapshots (mutate-free observations) + "browser_screenshot", + "browser_snapshot", + "browser_console", + "browser_get_text", + # Background bash polling - reads output buffers only, does + # not touch the subprocess itself. + "bash_output", + } + ) + # Credential directory used for change detection _CREDENTIAL_DIR = Path("~/.hive/credentials/credentials").expanduser() @@ -152,6 +179,7 @@ class ToolRegistry: "properties": properties, "required": required, }, + concurrency_safe=tool_name in self.CONCURRENCY_SAFE_TOOLS, ) def executor(inputs: dict) -> Any: @@ -970,6 +998,7 @@ class ToolRegistry: "properties": properties, "required": required, }, + concurrency_safe=mcp_tool.name in self.CONCURRENCY_SAFE_TOOLS, ) return tool diff --git a/core/tests/test_event_loop_node.py b/core/tests/test_event_loop_node.py index 82a76f85..0c796516 100644 --- a/core/tests/test_event_loop_node.py +++ b/core/tests/test_event_loop_node.py @@ -1837,3 +1837,299 @@ class TestSubagentAccumulatorMemory: # Should return None (not raise PermissionError) assert scoped.read("tweet_content") is None assert scoped.read("user_request") == "hi" + + +# --------------------------------------------------------------------------- +# Tool concurrency partitioning (Gap 5) +# --------------------------------------------------------------------------- + + +def _multi_tool_scenario(*calls: tuple[str, dict, str]) -> list: + """Build a stream scenario that emits multiple tool calls in one turn. + + Each ``calls`` entry is ``(tool_name, tool_input, tool_use_id)``. + """ + events: list = [] + for name, inp, uid in calls: + events.append( + ToolCallEvent(tool_use_id=uid, tool_name=name, tool_input=inp) + ) + events.append( + FinishEvent(stop_reason="tool_calls", input_tokens=10, output_tokens=5, model="mock") + ) + return events + + +class TestToolConcurrencyPartition: + """Gap 5: safe tools run in parallel, unsafe tools serialize after them.""" + + @pytest.mark.asyncio + async def test_safe_tools_overlap_unsafe_tools_do_not( + self, runtime, node_spec, buffer + ): + """A turn with (safe, safe, unsafe) schedules safes in parallel and + runs unsafe strictly after both safes have started.""" + scenario = _multi_tool_scenario( + ("read_file", {"path": "/a"}, "call_1"), + ("read_file", {"path": "/b"}, "call_2"), + ("execute_command", {"command": "echo hi"}, "call_3"), + ) + # Second turn emits plain text so the loop terminates. + llm = MockStreamingLLM(scenarios=[scenario, text_scenario("done")]) + node_spec.output_keys = [] + + start_events: list[tuple[str, float]] = [] + end_events: list[tuple[str, float]] = [] + + async def tool_exec(tool_use: ToolUse) -> ToolResult: + start_events.append((tool_use.id, asyncio.get_event_loop().time())) + # The two safes sleep long enough that a serial scheduler + # would show them end-before-start, but a parallel scheduler + # overlaps them. execute_command also sleeps so we can prove + # it started AFTER both safes started. 
+ await asyncio.sleep(0.05) + end_events.append((tool_use.id, asyncio.get_event_loop().time())) + return ToolResult(tool_use_id=tool_use.id, content="ok", is_error=False) + + tools = [ + Tool( + name="read_file", + description="", + parameters={}, + concurrency_safe=True, + ), + Tool( + name="execute_command", + description="", + parameters={}, + concurrency_safe=False, + ), + ] + + ctx = build_ctx( + runtime, + node_spec, + buffer, + llm, + tools=tools, + is_subagent_mode=True, + ) + node = EventLoopNode( + tool_executor=tool_exec, + config=LoopConfig(max_iterations=3), + ) + await node.execute(ctx) + + # Build lookup dicts for readability. + starts = dict(start_events) + ends = dict(end_events) + + # Both safe reads must start (approximately) together and before + # either has finished - proving they ran concurrently. + assert starts["call_1"] < ends["call_2"] + assert starts["call_2"] < ends["call_1"] + + # The unsafe tool must start strictly AFTER both safes have ended - + # proving it was serialized after the parallel batch. + assert starts["call_3"] >= ends["call_1"] + assert starts["call_3"] >= ends["call_2"] + + @pytest.mark.asyncio + async def test_serial_exception_cascades_cancel_siblings( + self, runtime, node_spec, buffer + ): + """When an unsafe tool raises, the remaining unsafe siblings are + cancelled with a clear error rather than silently executed.""" + scenario = _multi_tool_scenario( + ("execute_command", {"command": "boom"}, "call_1"), + ("execute_command", {"command": "echo survivor"}, "call_2"), + ) + llm = MockStreamingLLM(scenarios=[scenario, text_scenario("done")]) + node_spec.output_keys = [] + + executed: list[str] = [] + + async def tool_exec(tool_use: ToolUse) -> ToolResult: + executed.append(tool_use.id) + if tool_use.id == "call_1": + raise RuntimeError("first tool exploded") + return ToolResult(tool_use_id=tool_use.id, content="ok", is_error=False) + + tools = [ + Tool( + name="execute_command", + description="", + parameters={}, + concurrency_safe=False, + ), + ] + ctx = build_ctx( + runtime, + node_spec, + buffer, + llm, + tools=tools, + is_subagent_mode=True, + ) + node = EventLoopNode( + tool_executor=tool_exec, + config=LoopConfig(max_iterations=3), + ) + await node.execute(ctx) + + # First tool ran (and raised); second tool must NOT have run. + assert executed == ["call_1"] + + @pytest.mark.asyncio + async def test_safe_tool_starts_before_finish_event( + self, runtime, node_spec, buffer + ): + """Gap 1: a concurrency-safe tool must start executing while the + stream is still in flight, not after the final FinishEvent. + + Builds a custom LLM that sleeps between the ToolCallEvent and + the FinishEvent. A well-behaved harness starts the tool as soon + as the ToolCallEvent arrives, so by the time FinishEvent lands + the tool has already been running for ~sleep_seconds. + """ + from framework.llm.stream_events import FinishEvent, ToolCallEvent + + delay = 0.25 + + class SlowStreamLLM(LLMProvider): + def __init__(self): + self._calls = 0 + + async def stream(self, messages, system="", tools=None, max_tokens=4096): + self._calls += 1 + if self._calls == 1: + # Emit the tool call, stall, then finish. 
+ yield ToolCallEvent( + tool_use_id="call_1", + tool_name="read_file", + tool_input={"path": "/a"}, + ) + await asyncio.sleep(delay) + yield FinishEvent( + stop_reason="tool_calls", + input_tokens=10, + output_tokens=5, + model="mock", + ) + else: + # Turn 2 needs to match text_scenario shape so the + # outer loop terminates cleanly (needs a text delta + # before the finish event; empty turns are treated + # as worker silence and fall into the escalation + # grace window). + yield TextDeltaEvent(content="done", snapshot="done") + yield FinishEvent( + stop_reason="stop", + input_tokens=1, + output_tokens=1, + model="mock", + ) + + def complete(self, messages, system="", **kwargs) -> LLMResponse: + return LLMResponse(content="", model="mock", stop_reason="stop") + + tool_started_at: list[float] = [] + tool_finished_at: list[float] = [] + + async def tool_exec(tool_use: ToolUse) -> ToolResult: + tool_started_at.append(asyncio.get_event_loop().time()) + # Short simulated work so the tool finishes before the stream + # does; this proves the tool was running concurrently with + # the sleep inside the LLM stream. + await asyncio.sleep(0.05) + tool_finished_at.append(asyncio.get_event_loop().time()) + return ToolResult(tool_use_id=tool_use.id, content="ok", is_error=False) + + tools = [ + Tool( + name="read_file", + description="", + parameters={}, + concurrency_safe=True, + ), + ] + node_spec.output_keys = [] + llm = SlowStreamLLM() + ctx = build_ctx( + runtime, + node_spec, + buffer, + llm, + tools=tools, + is_subagent_mode=True, + ) + node = EventLoopNode( + tool_executor=tool_exec, + config=LoopConfig(max_iterations=3), + ) + turn_started = asyncio.get_event_loop().time() + await node.execute(ctx) + turn_ended = asyncio.get_event_loop().time() + + assert tool_started_at, "tool never ran" + # The tool must have STARTED within the LLM's sleep window - + # i.e. before turn_started + delay, not after. A post-stream + # dispatcher would start the tool at turn_started + delay or + # later. + assert tool_started_at[0] < turn_started + delay, ( + f"tool started at +{tool_started_at[0] - turn_started:.3f}s, " + f"but the stream sleep was {delay}s - the harness is still " + f"waiting for FinishEvent before dispatching." + ) + # Sanity: the whole turn took at least the sleep window (the + # stream had to drain before dispatch). + assert turn_ended - turn_started >= delay + + @pytest.mark.asyncio + async def test_soft_error_does_not_cascade( + self, runtime, node_spec, buffer + ): + """A ToolResult with is_error=True (e.g. 
'file not found') is a + normal return and must NOT cancel subsequent serial siblings - the + model needs to see all tool errors to decide what to do next.""" + scenario = _multi_tool_scenario( + ("execute_command", {"command": "false"}, "call_1"), + ("execute_command", {"command": "echo two"}, "call_2"), + ) + llm = MockStreamingLLM(scenarios=[scenario, text_scenario("done")]) + node_spec.output_keys = [] + + executed: list[str] = [] + + async def tool_exec(tool_use: ToolUse) -> ToolResult: + executed.append(tool_use.id) + return ToolResult( + tool_use_id=tool_use.id, + content="soft error" if tool_use.id == "call_1" else "ok", + is_error=(tool_use.id == "call_1"), + ) + + tools = [ + Tool( + name="execute_command", + description="", + parameters={}, + concurrency_safe=False, + ), + ] + ctx = build_ctx( + runtime, + node_spec, + buffer, + llm, + tools=tools, + is_subagent_mode=True, + ) + node = EventLoopNode( + tool_executor=tool_exec, + config=LoopConfig(max_iterations=3), + ) + await node.execute(ctx) + + # Both tools must have run: soft errors don't cascade. + assert executed == ["call_1", "call_2"] diff --git a/core/tests/test_node_conversation.py b/core/tests/test_node_conversation.py index d426311f..e6c1811c 100644 --- a/core/tests/test_node_conversation.py +++ b/core/tests/test_node_conversation.py @@ -243,6 +243,61 @@ class TestNodeConversation: await conv.add_user_message("x" * 320) assert conv.needs_compaction() is True + @pytest.mark.asyncio + async def test_needs_compaction_uses_buffer_when_set(self): + """Gap 7: a compaction_buffer_tokens overrides the multiplicative + threshold - compaction triggers when estimate + buffer would + cross the hard context limit, not at a fractional threshold.""" + conv = NodeConversation( + max_context_tokens=1000, + compaction_threshold=0.9, # would normally trigger at 900 + compaction_buffer_tokens=300, # buffer wants 700 hard cap + ) + # 650 tokens is below the 700 budget - no compaction yet. + conv.update_token_count(650) + assert conv.needs_compaction() is False + # 700+ crosses the budget - compaction fires BEFORE reaching + # the hard 1000 limit, so the next turn's input has headroom. + conv.update_token_count(700) + assert conv.needs_compaction() is True + + @pytest.mark.asyncio + async def test_compaction_warning_fires_before_hard_trigger(self): + """Gap 7: the warning threshold is meant to surface early signal + to telemetry without actually triggering compaction.""" + conv = NodeConversation( + max_context_tokens=1000, + compaction_buffer_tokens=200, + compaction_warning_buffer_tokens=400, + ) + conv.update_token_count(500) + assert conv.compaction_warning() is False + assert conv.needs_compaction() is False + + # Cross 600 tokens: warning fires (1000 - 400) but compaction + # doesn't yet (1000 - 200 = 800 budget). + conv.update_token_count(650) + assert conv.compaction_warning() is True + assert conv.needs_compaction() is False + + # Cross 800: both fire. 
+ conv.update_token_count(820) + assert conv.compaction_warning() is True + assert conv.needs_compaction() is True + + @pytest.mark.asyncio + async def test_legacy_threshold_rule_still_works_without_buffer(self): + """Without compaction_buffer_tokens, the old multiplicative rule + applies so existing callers keep behaving identically.""" + conv = NodeConversation( + max_context_tokens=1000, + compaction_threshold=0.75, + ) + conv.update_token_count(700) + assert conv.needs_compaction() is False + conv.update_token_count(800) + assert conv.needs_compaction() is True + @pytest.mark.asyncio async def test_compact_replaces_with_summary(self): """keep_recent=0 replaces all messages; empty conversation is a no-op.""" diff --git a/core/tests/test_tool_registry.py b/core/tests/test_tool_registry.py index 8e5bcdd3..6a1e9495 100644 --- a/core/tests/test_tool_registry.py +++ b/core/tests/test_tool_registry.py @@ -797,3 +797,60 @@ def test_resync_returns_false_when_credentials_unchanged(tmp_path, monkeypatch): monkeypatch.setattr(registry, "_snapshot_credentials", lambda: set()) assert registry.resync_mcp_servers_if_needed() is False + + +# --------------------------------------------------------------------------- +# Concurrency-safe flag propagation +# --------------------------------------------------------------------------- + + +def test_mcp_tool_conversion_marks_known_safe_tools(): + """MCP tools whose names are in CONCURRENCY_SAFE_TOOLS become concurrency_safe.""" + from framework.loader.mcp_client import MCPTool + registry = ToolRegistry() + + safe_mcp = MCPTool( + name="read_file", + description="", + input_schema={"type": "object", "properties": {}, "required": []}, + server_name="stub", + ) + unsafe_mcp = MCPTool( + name="execute_command", + description="", + input_schema={"type": "object", "properties": {}, "required": []}, + server_name="stub", + ) + + safe_tool = registry._convert_mcp_tool_to_framework_tool(safe_mcp) # noqa: SLF001 + unsafe_tool = registry._convert_mcp_tool_to_framework_tool(unsafe_mcp) # noqa: SLF001 + + assert safe_tool.concurrency_safe is True + assert unsafe_tool.concurrency_safe is False + + +def test_concurrency_safe_allowlist_is_conservative(): + """Every listed name must denote a read-only operation. + + This test is a guard against someone casually adding a write-capable + tool to the allowlist. If a new name is added here, justify it in the + comment above the set in tool_registry.py. + """ + from framework.loader.tool_registry import ToolRegistry + + allowlist = ToolRegistry.CONCURRENCY_SAFE_TOOLS + + # Positive assertions: known-safe read operations are present. + for name in ("read_file", "grep", "glob", "list_directory", "web_search"): + assert name in allowlist, f"{name} should be concurrency-safe" + + # Negative assertions: nothing that mutates state is allowed in. + for forbidden in ( + "execute_command", + "write_file", + "hashline_edit", + "browser_click", + "browser_type", + "browser_navigate", + ): + assert forbidden not in allowlist, f"{forbidden} must not be concurrency-safe" diff --git a/skills/linkedin-connection-greeter/SKILL.md b/skills/linkedin-connection-greeter/SKILL.md new file mode 100644 index 00000000..5aa1df42 --- /dev/null +++ b/skills/linkedin-connection-greeter/SKILL.md @@ -0,0 +1,132 @@ +--- +name: linkedin-connection-greeter +description: Automates accepting LinkedIn connections and sending a welcome message about the HoneyComb prediction market. Handles shadow DOM and Lexical editors. 
+--- + +# LinkedIn Connection Greeter + +This skill outlines the exact flow to accept connection requests and send a specific welcome message without triggering spam filters. + +## 1. Load Ledger +Before starting, read `data/linkedin_contacts.json`. If it doesn't exist, initialize with `{"contacts": []}`. You will use this to skip people you've already messaged. + +## 2. Scan Pending Connections +Navigate to `https://www.linkedin.com/mynetwork/invitation-manager/received/`. Wait until load + sleep 4s. +Strip unload handlers: +`browser_evaluate("(function(){window.onbeforeunload=null;})()")` + +Extract cards using this specific snippet (handles changing classes and follow invites): +```javascript +(function(){ + const btns = Array.from(document.querySelectorAll('button')).filter(b => b.textContent.includes('Accept')); + let results = []; + for (let b of btns) { + let card = b.closest('[role="listitem"]'); + if (!card) continue; + let text = card.textContent.toLowerCase(); + if (text.includes('invited you to follow') || text.includes('invited you to subscribe')) continue; + + let nameEls = Array.from(card.querySelectorAll('a[href*="/in/"]')); + let nameEl = nameEls.find(el => el.textContent.trim().length > 0); + + let r = b.getBoundingClientRect(); + results.push({ + first_name: nameEl ? nameEl.textContent.trim().split(/\s+/)[0] : 'there', + profile_url: nameEl ? nameEl.href : '', + cx: r.x + r.width/2, + cy: r.y + r.height/2 + }); + } + return results; +})(); +``` + +## 3. Process Each Card (Max 10 per run) +For each card, check if `profile_url` is already in the ledger. If not: +1. `browser_click_coordinate(cx, cy)` to click the specific Accept button. +2. `sleep(2)` +3. `browser_navigate(profile_url, wait_until="load")` +4. `sleep(4)` +5. `browser_evaluate("(function(){window.onbeforeunload=null; window.addEventListener('beforeunload', e => e.stopImmediatePropagation(), true);})()")` + +## 4. Message the User +Click Message Button on their profile: +```javascript +(function(){ + const links = Array.from(document.querySelectorAll('a[href*="/messaging/compose/"]')); + for (const a of links){ + if (!a.href.includes('NON_SELF_PROFILE_VIEW') || a.href.includes('body=')) continue; + const r = a.getBoundingClientRect(); + if (r.width === 0 || r.x > 700) continue; + return {cx: r.x + r.width / 2, cy: r.y + r.height / 2}; + } + return null; +})(); +``` +Click that coordinate, then `sleep(2.5)`. + +Find Textarea (it is hidden inside shadow DOM): +```javascript +(function(){ + const vh = window.innerHeight, vw = window.innerWidth; + const candidates = []; + function walk(root){ + const els = root.querySelectorAll ? root.querySelectorAll('div.msg-form__contenteditable') : []; + for (const el of els){ + const r = el.getBoundingClientRect(); + if (r.width > 0 && r.height > 0 && r.y >= 0 && r.y + r.height <= vh && r.x >= 0 && r.x + r.width <= vw) { + candidates.push({cx: r.x + r.width/2, cy: r.y + r.height/2, area: r.width * r.height}); + } + } + const all = root.querySelectorAll ? root.querySelectorAll('*') : []; + for (const host of all){ if (host.shadowRoot) walk(host.shadowRoot); } + } + walk(document); + candidates.sort((a, b) => b.area - a.area); + return candidates.length ? candidates[0] : null; +})(); +``` +Click that coordinate, `sleep(1)`. + +Inject text and Send: +Construct the message: `Hey {first_name}, thanks for the connection invite! I'm currently building a prediction market for jobs: https://honeycomb.open-hive.com/. 
If you could check it out and share some feedback, I'd really appreciate it.` + +Escape the string properly for JS injection, then run: +```javascript +// Replace MSG_TEXT with your actual string +browser_evaluate("(function(){ document.execCommand('insertText', false, `MSG_TEXT`); return true; })()") +``` + +Find Send button (also inside shadow DOM): +```javascript +(function(){ + const vh = window.innerHeight; + function walk(root){ + const btns = root.querySelectorAll ? root.querySelectorAll('button') : []; + for (const b of btns){ + const cls = (b.className || '').toString(); + if (!cls.includes('send-button') && b.textContent.trim() !== 'Send') continue; + const r = b.getBoundingClientRect(); + if (r.width <= 0 || r.y + r.height > vh) continue; + return { cx: r.x + r.width/2, cy: r.y + r.height/2, disabled: b.disabled || b.getAttribute('aria-disabled') === 'true' }; + } + const all = root.querySelectorAll ? root.querySelectorAll('*') : []; + for (const host of all){ if (host.shadowRoot) { const got = walk(host.shadowRoot); if (got) return got; } } + return null; + } + return walk(document); +})(); +``` +Click send coordinate, `sleep(2)`. + +## 5. Update Ledger +Append the user to `data/linkedin_contacts.json`. +```json +{ + "profile_url": "...", + "name": "...", + "action": "connection_accepted+message_sent", + "timestamp": "2026-..." +} +``` +`sleep(5)` before moving to the next card to mimic human pacing. diff --git a/tools/src/aden_tools/file_ops.py b/tools/src/aden_tools/file_ops.py index b9bf5127..08f93b1f 100644 --- a/tools/src/aden_tools/file_ops.py +++ b/tools/src/aden_tools/file_ops.py @@ -31,6 +31,7 @@ from pathlib import Path from fastmcp import FastMCP +from aden_tools.file_state_cache import Freshness, check_fresh, record_read from aden_tools.hashline import ( HASHLINE_MAX_FILE_BYTES, compute_line_hash, @@ -377,8 +378,16 @@ def register_file_tools( return f"Binary file: {path} ({size:,} bytes). Cannot display binary content." try: - with open(resolved, encoding="utf-8", errors="replace") as f: - content = f.read() + # Read raw bytes once; use them both for the line-formatted + # return value and to hash into the file-state cache so a + # later edit can detect external writes without a second + # open. Hash is computed even on partial/offset reads so the + # guard still fires when the model only read the start of a + # large file before editing deeper into it. + with open(resolved, "rb") as fb: + raw_bytes = fb.read() + content = raw_bytes.decode("utf-8", errors="replace") + record_read(None, resolved, content_bytes=raw_bytes) # Use splitlines() for consistent line splitting with hashline module all_lines = content.splitlines() @@ -434,6 +443,27 @@ def register_file_tools( resolved = _resolve(path) resolved_path = Path(resolved) + # Stale-edit guard: an existing file must have been read recently + # and still match the on-disk content. Writing over a file the + # model has never seen (or that changed since it last saw it) + # risks clobbering the user's work. Brand-new files are allowed + # without a prior read - there's nothing to clobber. + if resolved_path.is_file(): + _fresh = check_fresh(None, resolved) + if _fresh.status is Freshness.UNREAD: + return ( + f"Refusing to overwrite '{path}': call read_file('{path}') " + f"first so the harness can track its state before you " + f"replace it. If you intend to discard the current " + f"contents, read it first to acknowledge what you are " + f"overwriting." 
+ ) + if _fresh.status is Freshness.STALE: + return ( + f"Refusing to overwrite '{path}': {_fresh.detail}. " + f"Re-read the file with read_file before writing." + ) + try: # Create parent dirs first (before git snapshot) so structure exists resolved_path.parent.mkdir(parents=True, exist_ok=True) @@ -452,6 +482,14 @@ def register_file_tools( f.flush() os.fsync(f.fileno()) + # Record the post-write state so a later edit in the same + # turn doesn't trip the stale-edit guard against the file + # this call just created or overwrote. + try: + record_read(None, resolved, content_bytes=content_str.encode("utf-8")) + except Exception: + pass + line_count = content_str.count("\n") + ( 1 if content_str and not content_str.endswith("\n") else 0 ) @@ -478,6 +516,23 @@ def register_file_tools( if not os.path.isfile(resolved): return f"Error: File not found: {path}" + # Stale-edit guard: refuse unless a recent read is on record and + # the file on disk still matches it. Prevents the model from + # overwriting changes the user made in their editor between + # calling read_file and edit_file. + _fresh = check_fresh(None, resolved) + if _fresh.status is Freshness.UNREAD: + return ( + f"Refusing to edit '{path}': call read_file('{path}') " + f"first so the harness can track its state before you " + f"edit it." + ) + if _fresh.status is Freshness.STALE: + return ( + f"Refusing to edit '{path}': {_fresh.detail}. Re-read " + f"the file with read_file before editing." + ) + try: with open(resolved, encoding="utf-8") as f: content = f.read() @@ -532,6 +587,13 @@ def register_file_tools( with open(resolved, "w", encoding="utf-8") as f: f.write(new_content) + # Re-record post-write state so a second edit in the same + # turn doesn't trip its own stale guard. + try: + record_read(None, resolved, content_bytes=new_content.encode("utf-8")) + except Exception: + pass + diff = _compute_diff(content, new_content, path) match_info = f" (matched via {strategy_used})" if strategy_used != "exact" else "" result = f"Replaced {count} occurrence(s) in {path}{match_info}" @@ -771,6 +833,25 @@ def register_file_tools( if not os.path.isfile(resolved): return f"Error: File not found: {path}" + # Stale-edit guard: require a prior read_file that still matches + # disk. hashline_edit already rehashes anchors, but anchor hashes + # only protect the exact lines touched - content drift around + # those lines (e.g. new imports the user added) would still slip + # through silently. This guard closes that gap. + _fresh = check_fresh(None, resolved) + if _fresh.status is Freshness.UNREAD: + return ( + f"Error: Refusing to edit '{path}': call read_file" + f"('{path}', hashline=True) first so the harness can " + f"track its state before you edit it." + ) + if _fresh.status is Freshness.STALE: + return ( + f"Error: Refusing to edit '{path}': {_fresh.detail}. " + f"Re-read the file with read_file(hashline=True) before " + f"editing." + ) + try: with open(resolved, "rb") as f: raw_head = f.read(8192) @@ -1074,6 +1155,14 @@ def register_file_tools( except Exception as e: return f"Error: Failed to write file: {e}" + # Refresh the file-state cache so chained edits in the same turn + # see the new hash instead of tripping the stale guard against + # the post-write disk state. + try: + record_read(None, resolved, content_bytes=joined.encode(encoding)) + except Exception: + pass + # 10. 
Build response updated_lines = joined.splitlines() total_lines = len(updated_lines) diff --git a/tools/src/aden_tools/file_state_cache.py b/tools/src/aden_tools/file_state_cache.py new file mode 100644 index 00000000..ac391e34 --- /dev/null +++ b/tools/src/aden_tools/file_state_cache.py @@ -0,0 +1,177 @@ +"""Per-agent tracking of files the model has Read, so Edit can detect +staleness from external writes (e.g. the user saving the file in their +editor between a Read and an Edit). + +The cache lives in the MCP server process and is keyed on +``(scope, absolute_path)`` where ``scope`` is the agent_id when available +(the normal case) or ``"__global__"`` as a last-resort fallback. That +keeps two agents running in the same MCP server process from sharing +(or corrupting) each other's read-state view. + +Freshness is decided by ``(size, mtime_ns, sha256)``: +- If the file's ``size`` and ``mtime_ns`` both match the recorded values, + we trust the read (fast path, no hashing). +- If either differs, we hash the current content and compare to the + recorded sha. mtime preservation by some editors means mtime alone is + unreliable; hashing only on a mismatch keeps the happy path cheap. + +The cache is bounded (LRU, 256 entries per scope) so a chatty agent +cannot grow it without bound. +""" + +from __future__ import annotations + +import hashlib +import os +import threading +from collections import OrderedDict +from dataclasses import dataclass +from enum import Enum + + +@dataclass(frozen=True) +class FileReadRecord: + size: int + mtime_ns: int + sha256: str + + +class Freshness(Enum): + FRESH = "fresh" + STALE = "stale" + UNREAD = "unread" + + +@dataclass +class FreshResult: + status: Freshness + detail: str = "" + + +_MAX_ENTRIES_PER_SCOPE = 256 + +# scope -> ordered dict of absolute_path -> FileReadRecord. +# Ordered so we can evict least-recently-read entries. +_cache: dict[str, "OrderedDict[str, FileReadRecord]"] = {} +_lock = threading.Lock() + + +def _scope_key(agent_id: str | None) -> str: + return agent_id or "__global__" + + +def _hash_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def _hash_file(abs_path: str) -> str: + h = hashlib.sha256() + with open(abs_path, "rb") as fh: + for chunk in iter(lambda: fh.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + +def record_read( + agent_id: str | None, + abs_path: str, + content_bytes: bytes | None = None, +) -> None: + """Record that ``abs_path`` was just successfully read. + + If ``content_bytes`` is provided the hash is computed from that; this + is the fast path and avoids a second open. Otherwise we re-open the + file to hash it. Silently ignores files that disappear between the + read and the record (race with concurrent deletion). + """ + try: + st = os.stat(abs_path) + except OSError: + return + + try: + sha = _hash_bytes(content_bytes) if content_bytes is not None else _hash_file(abs_path) + except OSError: + return + + rec = FileReadRecord(size=st.st_size, mtime_ns=st.st_mtime_ns, sha256=sha) + scope = _scope_key(agent_id) + with _lock: + entries = _cache.setdefault(scope, OrderedDict()) + entries[abs_path] = rec + entries.move_to_end(abs_path) + while len(entries) > _MAX_ENTRIES_PER_SCOPE: + entries.popitem(last=False) + + +def check_fresh(agent_id: str | None, abs_path: str) -> FreshResult: + """Check whether ``abs_path`` is safe to edit. + + Returns FRESH if the file on disk matches the recorded read. + Returns STALE if it was read previously but has since changed. 
+ Returns UNREAD if the agent has never read this path via read_file. + """ + scope = _scope_key(agent_id) + with _lock: + entries = _cache.get(scope) + rec = entries.get(abs_path) if entries else None + if rec is not None and entries is not None: + entries.move_to_end(abs_path) + + if rec is None: + return FreshResult(Freshness.UNREAD) + + try: + st = os.stat(abs_path) + except FileNotFoundError: + return FreshResult(Freshness.STALE, "file has been deleted since it was read") + except OSError as e: + return FreshResult(Freshness.STALE, f"stat failed: {e}") + + if st.st_size == rec.size and st.st_mtime_ns == rec.mtime_ns: + return FreshResult(Freshness.FRESH) + + # mtime/size differ - fall through to a content hash so that editors + # that rewrite the file with identical content don't trip a false + # stale. This is the only path where we pay the O(file) hashing cost. + try: + current_sha = _hash_file(abs_path) + except OSError as e: + return FreshResult(Freshness.STALE, f"hash failed: {e}") + + if current_sha == rec.sha256: + # Content is unchanged even though metadata differs (e.g. editor + # rewrote with preserved content). Refresh the record so future + # checks hit the fast path again. + rec = FileReadRecord(size=st.st_size, mtime_ns=st.st_mtime_ns, sha256=current_sha) + with _lock: + entries = _cache.setdefault(scope, OrderedDict()) + entries[abs_path] = rec + entries.move_to_end(abs_path) + return FreshResult(Freshness.FRESH) + + return FreshResult( + Freshness.STALE, + "content changed on disk since the last read (sha256 differs)", + ) + + +def forget(agent_id: str | None, abs_path: str) -> None: + """Drop a single cache entry. Used in tests to force UNREAD.""" + scope = _scope_key(agent_id) + with _lock: + entries = _cache.get(scope) + if entries is not None: + entries.pop(abs_path, None) + + +def clear_scope(agent_id: str | None) -> None: + """Drop all entries for one agent (used at session teardown).""" + with _lock: + _cache.pop(_scope_key(agent_id), None) + + +def reset_all() -> None: + """Test hook: wipe every scope.""" + with _lock: + _cache.clear() diff --git a/tools/src/aden_tools/tools/file_system_toolkits/data_tools/data_tools.py b/tools/src/aden_tools/tools/file_system_toolkits/data_tools/data_tools.py index 5b6dd95b..b4387e36 100644 --- a/tools/src/aden_tools/tools/file_system_toolkits/data_tools/data_tools.py +++ b/tools/src/aden_tools/tools/file_system_toolkits/data_tools/data_tools.py @@ -15,6 +15,8 @@ from pathlib import Path from mcp.server.fastmcp import FastMCP +from aden_tools.file_state_cache import record_read + # ~/.hive/ is always allowed for cross-agent file access HIVE_DIR = os.path.expanduser("~/.hive") @@ -71,6 +73,7 @@ def register_tools(mcp: FastMCP) -> None: offset: int = 1, limit: int = 0, data_dir: str = "", + agent_id: str = "", ) -> str: """Read file contents with line numbers. @@ -83,6 +86,8 @@ def register_tools(mcp: FastMCP) -> None: offset: Starting line number, 1-indexed (default: 1). limit: Max lines to return, 0 = up to 2000 (default: 0). data_dir: Auto-injected - the session's data directory. + agent_id: Auto-injected - the calling agent id, used to scope + the file-state cache that powers stale-edit detection. 
""" try: resolved = _resolve_path(path, data_dir) @@ -112,8 +117,17 @@ def register_tools(mcp: FastMCP) -> None: pass try: - with open(resolved, encoding="utf-8", errors="replace") as f: - content = f.read() + # Read as bytes first so we can hash them for the state cache + # without a second open, then decode for the line-formatted + # return value the model sees. + with open(resolved, "rb") as f: + raw_bytes = f.read() + content = raw_bytes.decode("utf-8", errors="replace") + # Record this read in the per-agent state cache so a later + # hashline_edit/write_file call can detect external writes + # that happened between now and then. Scoped to agent_id so + # two agents sharing the MCP server can't see each other. + record_read(agent_id or None, resolved, content_bytes=raw_bytes) all_lines = content.splitlines() total_lines = len(all_lines) diff --git a/tools/src/aden_tools/tools/file_system_toolkits/execute_command_tool/background_jobs.py b/tools/src/aden_tools/tools/file_system_toolkits/execute_command_tool/background_jobs.py new file mode 100644 index 00000000..ae7f5001 --- /dev/null +++ b/tools/src/aden_tools/tools/file_system_toolkits/execute_command_tool/background_jobs.py @@ -0,0 +1,213 @@ +"""In-process registry of long-running shell jobs spawned by +``execute_command_tool(run_in_background=True)``. + +Jobs are keyed on a short id the tool returns to the agent. The agent +can then call ``bash_output(id=...)`` to poll for new output and +``bash_kill(id=...)`` to terminate. Each job is scoped to an +``agent_id`` so two agents sharing the same MCP server can't see or +kill each other's work. + +The stdout/stderr buffers are bounded rolling tail buffers (64 KB each) +so a runaway process can't exhaust memory. Older bytes are dropped with +a one-time ``[truncated N bytes]`` marker prepended to the returned +text. +""" + +from __future__ import annotations + +import asyncio +import time +from collections import deque +from dataclasses import dataclass, field +from uuid import uuid4 + +# 64 KB rolling window per stream. Large enough for long build logs, +# small enough that a bash infinite loop can't OOM the MCP process. +_MAX_BUFFER_BYTES = 64 * 1024 + + +@dataclass +class _RingBuffer: + """Append-only byte buffer with a hard byte ceiling and per-read + offset tracking so each bash_output call only returns new bytes. + """ + + max_bytes: int = _MAX_BUFFER_BYTES + # deque of (global_offset, bytes) chunks. global_offset is the total + # bytes written prior to this chunk; lets us compute "bytes since + # last poll" without copying. + _chunks: deque[tuple[int, bytes]] = field(default_factory=deque) + _total_written: int = 0 + _total_dropped: int = 0 + _read_cursor: int = 0 + + def write(self, data: bytes) -> None: + if not data: + return + self._chunks.append((self._total_written, data)) + self._total_written += len(data) + # Evict from the front until we're under the ceiling. + current_bytes = sum(len(c) for _, c in self._chunks) + while current_bytes > self.max_bytes and self._chunks: + dropped_offset, dropped = self._chunks.popleft() + self._total_dropped += len(dropped) + current_bytes -= len(dropped) + # Push the read cursor forward if the reader was still + # pointing at bytes we just evicted. + if self._read_cursor < dropped_offset + len(dropped): + self._read_cursor = dropped_offset + len(dropped) + + def read_new(self) -> str: + """Return any bytes since the last call, as decoded text. 
+ + Includes a ``[truncated N bytes]`` prefix if rolling-window + eviction dropped any bytes the reader hadn't yet consumed. + """ + chunks_out: list[bytes] = [] + cursor = self._read_cursor + for offset, chunk in self._chunks: + end = offset + len(chunk) + if end <= cursor: + continue + start_in_chunk = max(0, cursor - offset) + chunks_out.append(chunk[start_in_chunk:]) + cursor = end + self._read_cursor = cursor + raw = b"".join(chunks_out) + text = raw.decode("utf-8", errors="replace") + # Surface eviction ONCE per poll so the agent knows to check + # the file system for larger logs instead of assuming it's got + # the full output. + if self._total_dropped > 0 and text: + text = f"[truncated {self._total_dropped} earlier bytes]\n" + text + return text + + +@dataclass +class BackgroundJob: + id: str + agent_id: str + command: str + cwd: str + started_at: float + process: asyncio.subprocess.Process + stdout_buf: _RingBuffer = field(default_factory=_RingBuffer) + stderr_buf: _RingBuffer = field(default_factory=_RingBuffer) + _pump_task: asyncio.Task | None = None + exit_code: int | None = None + + def status(self) -> str: + if self.exit_code is not None: + return f"exited({self.exit_code})" + if self.process.returncode is not None: + # Not yet surfaced by the pump but already finished. + return f"exited({self.process.returncode})" + return "running" + + +# agent_id -> {job_id -> BackgroundJob} +_jobs: dict[str, dict[str, BackgroundJob]] = {} +_jobs_lock = asyncio.Lock() + + +def _short_id() -> str: + return uuid4().hex[:8] + + +async def _pump(job: BackgroundJob) -> None: + """Drain the child process's stdout/stderr into the ring buffers.""" + proc = job.process + + async def _drain(stream: asyncio.StreamReader | None, buf: _RingBuffer) -> None: + if stream is None: + return + while True: + chunk = await stream.read(4096) + if not chunk: + return + buf.write(chunk) + + await asyncio.gather( + _drain(proc.stdout, job.stdout_buf), + _drain(proc.stderr, job.stderr_buf), + ) + job.exit_code = await proc.wait() + + +async def spawn( + command: str, cwd: str, agent_id: str +) -> BackgroundJob: + """Start a subprocess in the background and register it. The caller + holds the job id returned from here and can poll via ``get()``. + """ + proc = await asyncio.create_subprocess_shell( + command, + cwd=cwd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + job = BackgroundJob( + id=_short_id(), + agent_id=agent_id, + command=command, + cwd=cwd, + started_at=time.time(), + process=proc, + ) + # Start pumping IO in the background so the ring buffers stay warm + # even if the agent doesn't poll for a while. + job._pump_task = asyncio.create_task(_pump(job)) + + async with _jobs_lock: + _jobs.setdefault(agent_id, {})[job.id] = job + return job + + +async def get(agent_id: str, job_id: str) -> BackgroundJob | None: + async with _jobs_lock: + return _jobs.get(agent_id, {}).get(job_id) + + +async def kill(agent_id: str, job_id: str, grace_seconds: float = 3.0) -> str: + """SIGTERM a background job, escalating to SIGKILL after a grace + period. Returns a human-readable status string. 
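+
+    Illustrative return values (exact exit codes depend on the platform
+    and how the child handles signals): "terminated cleanly (exit=-15)",
+    "killed (SIGKILL, exit=-9)", "already exited with code 0", or
+    "no background job with id 'ab12cd34'".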
+ """ + job = await get(agent_id, job_id) + if job is None: + return f"no background job with id '{job_id}'" + if job.process.returncode is not None: + status = f"already exited with code {job.process.returncode}" + else: + try: + job.process.terminate() + except ProcessLookupError: + pass + try: + await asyncio.wait_for(job.process.wait(), timeout=grace_seconds) + status = f"terminated cleanly (exit={job.process.returncode})" + except asyncio.TimeoutError: + try: + job.process.kill() + except ProcessLookupError: + pass + await job.process.wait() + status = f"killed (SIGKILL, exit={job.process.returncode})" + # Deregister after kill so the id is no longer reachable. + async with _jobs_lock: + scope = _jobs.get(agent_id) + if scope is not None: + scope.pop(job_id, None) + return status + + +async def clear_agent(agent_id: str) -> None: + """Test hook: kill every job owned by ``agent_id``.""" + async with _jobs_lock: + scope = _jobs.pop(agent_id, {}) + for job in scope.values(): + if job.process.returncode is None: + try: + job.process.kill() + except ProcessLookupError: + pass + await job.process.wait() diff --git a/tools/src/aden_tools/tools/file_system_toolkits/execute_command_tool/execute_command_tool.py b/tools/src/aden_tools/tools/file_system_toolkits/execute_command_tool/execute_command_tool.py index efc047b4..5d0cb617 100644 --- a/tools/src/aden_tools/tools/file_system_toolkits/execute_command_tool/execute_command_tool.py +++ b/tools/src/aden_tools/tools/file_system_toolkits/execute_command_tool/execute_command_tool.py @@ -1,78 +1,228 @@ +"""Shell command execution tool. + +Three tools are registered: + +* ``execute_command_tool`` runs a command synchronously with a per-call + timeout (default 120s, max 600s). Uses ``asyncio.create_subprocess_shell`` + so the MCP event loop is not blocked while the child runs. +* ``bash_output`` polls a background job started with + ``execute_command_tool(run_in_background=True)`` and returns any new + stdout/stderr since the last poll plus the current status. +* ``bash_kill`` terminates a background job (SIGTERM then SIGKILL after + a 3-second grace period). + +All three go through the same pre-execution safety blocklist in +``command_sanitizer.py``. +""" + +from __future__ import annotations + +import asyncio import os -import subprocess +import time from mcp.server.fastmcp import FastMCP from ..command_sanitizer import CommandBlockedError, validate_command from ..security import AGENT_SANDBOXES_DIR, get_sandboxed_path +from .background_jobs import get as get_job +from .background_jobs import kill as kill_job +from .background_jobs import spawn as spawn_job + +# Bounds on per-call timeout. 1s minimum prevents accidental zeros that +# would cause every command to fail. 600s maximum (10 min) is the same +# ceiling Claude Code uses for its Bash tool; builds and test suites +# longer than that should use run_in_background instead. 
+_MIN_TIMEOUT = 1 +_MAX_TIMEOUT = 600 +_DEFAULT_TIMEOUT = 120 + + +def _resolve_cwd(cwd: str | None, agent_id: str) -> str: + agent_root = os.path.join(AGENT_SANDBOXES_DIR, agent_id, "current") + os.makedirs(agent_root, exist_ok=True) + if cwd: + return get_sandboxed_path(cwd, agent_id) + return agent_root def register_tools(mcp: FastMCP) -> None: """Register command execution tools with the MCP server.""" @mcp.tool() - def execute_command_tool(command: str, agent_id: str, cwd: str | None = None) -> dict: + async def execute_command_tool( + command: str, + agent_id: str, + cwd: str | None = None, + timeout_seconds: int = _DEFAULT_TIMEOUT, + run_in_background: bool = False, + ) -> dict: """ Purpose Execute a shell command within the agent sandbox. When to use - Run validators or linters + Run validators, linters, builds, test suites Generate derived artifacts (indexes, summaries) Perform controlled maintenance tasks + Start long-running processes via ``run_in_background=True`` + (dev servers, watchers, file-triggered builds) Rules & Constraints No network access unless explicitly allowed No destructive commands (rm -rf, system modification) - Output must be treated as data, not truth - Commands are validated against a safety blocklist before execution - Commands still run through shell=True, so the blocklist only - prevents explicit nested shell executables; it does not remove - shell parsing entirely + Commands are validated against a safety blocklist before + execution. The blocklist runs through shell=True, so it + only prevents explicit nested shell executables. + timeout_seconds is clamped to [1, 600]. For longer-running + work use run_in_background=True + bash_output to poll. Args: - command: The shell command to execute - agent_id: The ID of the agent - cwd: The working directory for the command (relative to agent sandbox, optional) + command: The shell command to execute. + agent_id: The ID of the agent (auto-injected). + cwd: Working directory for the command (relative to the + agent sandbox). Defaults to the sandbox root. + timeout_seconds: Max wall-clock seconds the foreground + command is allowed to run. Ignored when + run_in_background=True. Default 120, max 600. + run_in_background: If True, spawn the command and return + immediately with a job id. Use bash_output(id=...) to + read output and bash_kill(id=...) to stop it. Returns: - Dict with command output and execution details, or error dict + For foreground commands: dict with stdout, stderr, return_code, + elapsed_seconds. + For background commands: dict with id, pid, started_at, and + instructions for polling / killing the job. + On error: dict with an "error" key. 
""" - # Validate command against safety blocklist before execution try: validate_command(command) except CommandBlockedError as e: return {"error": f"Command blocked: {e}", "blocked": True} try: - # Default cwd is the agent sandbox root - agent_root = os.path.join(AGENT_SANDBOXES_DIR, agent_id, "current") - os.makedirs(agent_root, exist_ok=True) - - if cwd: - secure_cwd = get_sandboxed_path(cwd, agent_id) - else: - secure_cwd = agent_root - - result = subprocess.run( - command, - shell=True, - cwd=secure_cwd, - capture_output=True, - text=True, - timeout=60, - encoding="utf-8", - ) + secure_cwd = _resolve_cwd(cwd, agent_id) + except Exception as e: + return {"error": f"Failed to resolve cwd: {e}"} + if run_in_background: + try: + job = await spawn_job(command, secure_cwd, agent_id) + except Exception as e: + return {"error": f"Failed to spawn background job: {e}"} return { "success": True, + "background": True, + "id": job.id, + "pid": job.process.pid, "command": command, - "return_code": result.returncode, - "stdout": result.stdout, - "stderr": result.stderr, "cwd": cwd or ".", + "started_at": job.started_at, + "hint": ( + "Background job started. Call " + f"bash_output(id='{job.id}') to read output, or " + f"bash_kill(id='{job.id}') to terminate it." + ), } - except subprocess.TimeoutExpired: - return {"error": "Command timed out after 60 seconds"} + + # Foreground path: clamp timeout, spawn, wait with a watchdog. + try: + timeout = max(_MIN_TIMEOUT, min(_MAX_TIMEOUT, int(timeout_seconds))) + except (TypeError, ValueError): + timeout = _DEFAULT_TIMEOUT + + started = time.monotonic() + try: + proc = await asyncio.create_subprocess_shell( + command, + cwd=secure_cwd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) except Exception as e: - return {"error": f"Failed to execute command: {str(e)}"} + return {"error": f"Failed to execute command: {e}"} + + try: + stdout_b, stderr_b = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + # Child is still running: kill it, drain what it already + # wrote so the agent gets a partial log, then report. + try: + proc.kill() + except ProcessLookupError: + pass + try: + stdout_b, stderr_b = await asyncio.wait_for( + proc.communicate(), timeout=2.0 + ) + except (asyncio.TimeoutError, Exception): + stdout_b, stderr_b = b"", b"" + elapsed = round(time.monotonic() - started, 2) + return { + "error": ( + f"Command timed out after {timeout} seconds. " + f"For longer work pass timeout_seconds (max 600) or " + f"run_in_background=True." + ), + "timed_out": True, + "elapsed_seconds": elapsed, + "stdout": stdout_b.decode("utf-8", errors="replace"), + "stderr": stderr_b.decode("utf-8", errors="replace"), + } + except Exception as e: + return {"error": f"Failed while running command: {e}"} + + return { + "success": True, + "command": command, + "return_code": proc.returncode, + "stdout": stdout_b.decode("utf-8", errors="replace"), + "stderr": stderr_b.decode("utf-8", errors="replace"), + "cwd": cwd or ".", + "elapsed_seconds": round(time.monotonic() - started, 2), + } + + @mcp.tool() + async def bash_output(id: str, agent_id: str) -> dict: + """Poll a background command for new output and its current status. + + Returns any stdout/stderr bytes written since the last call. + The status is one of "running", "exited(N)", or "killed". + When the job has finished and all output has been consumed, it + is removed from the registry on the next poll. 
+ + Args: + id: The job id returned from + execute_command_tool(run_in_background=True). + agent_id: The ID of the agent (auto-injected). + """ + job = await get_job(agent_id, id) + if job is None: + return {"error": f"no background job with id '{id}'"} + new_stdout = job.stdout_buf.read_new() + new_stderr = job.stderr_buf.read_new() + return { + "id": id, + "status": job.status(), + "stdout": new_stdout, + "stderr": new_stderr, + "elapsed_seconds": round(time.time() - job.started_at, 2), + } + + @mcp.tool() + async def bash_kill(id: str, agent_id: str) -> dict: + """Terminate a background command. + + Sends SIGTERM, waits up to 3 seconds, then escalates to SIGKILL + if the process is still alive. The job id is then deregistered. + + Args: + id: The job id returned from + execute_command_tool(run_in_background=True). + agent_id: The ID of the agent (auto-injected). + """ + status = await kill_job(agent_id, id) + return {"id": id, "status": status} diff --git a/tools/src/aden_tools/tools/file_system_toolkits/hashline_edit/hashline_edit.py b/tools/src/aden_tools/tools/file_system_toolkits/hashline_edit/hashline_edit.py index b2398597..0c523927 100644 --- a/tools/src/aden_tools/tools/file_system_toolkits/hashline_edit/hashline_edit.py +++ b/tools/src/aden_tools/tools/file_system_toolkits/hashline_edit/hashline_edit.py @@ -18,6 +18,8 @@ from aden_tools.hashline import ( validate_anchor, ) +from aden_tools.file_state_cache import Freshness, check_fresh, record_read + from ..security import get_sandboxed_path @@ -87,6 +89,29 @@ def register_tools(mcp: FastMCP) -> None: if not os.path.isfile(secure_path): return {"error": f"Path is not a file: {path}"} + # Stale-edit guard: refuse to edit unless we have a recent + # read recorded in the file-state cache and the file on disk + # still matches it. This catches the case where the user + # saved the file in their editor between the model's Read + # and Edit, which would otherwise cause the model to write + # against a stale mental model. + fresh = check_fresh(agent_id, secure_path) + if fresh.status is Freshness.UNREAD: + return { + "error": ( + f"Refusing to edit '{path}': you must call " + f"read_file('{path}') first so the harness can " + f"track its state before you edit it." + ) + } + if fresh.status is Freshness.STALE: + return { + "error": ( + f"Refusing to edit '{path}': {fresh.detail}. " + f"Re-read the file with read_file before editing." + ) + } + with open(secure_path, "rb") as f: raw_head = f.read(8192) eol = "\r\n" if b"\r\n" in raw_head else "\n" @@ -405,6 +430,15 @@ def register_tools(mcp: FastMCP) -> None: except Exception as e: return {"error": f"Failed to write file: {e}"} + # Re-record the new file state so a second edit in the same turn + # sees the post-write hash instead of tripping the stale guard. + try: + record_read(agent_id, secure_path, content_bytes=joined.encode(encoding)) + except Exception: + # Hash record is best-effort; a failure here must not break + # the edit that already succeeded on disk. + pass + # 10. Build response updated_lines = joined.splitlines() hashline_content = format_hashlines(updated_lines) diff --git a/tools/tests/test_file_state_cache.py b/tools/tests/test_file_state_cache.py new file mode 100644 index 00000000..c887ccd6 --- /dev/null +++ b/tools/tests/test_file_state_cache.py @@ -0,0 +1,206 @@ +"""Tests for aden_tools.file_state_cache and its integration with file_ops. 
+ +These tests cover the stale-edit guard added for Gap 4: +- read_file records a per-file hash snapshot +- edit_file / write_file / hashline_edit refuse to run when the on-disk + file has diverged from the last recorded read +- write_file is allowed without a prior read when the target doesn't + exist yet (brand-new file, nothing to clobber) +- re-recording after a successful write keeps chained edits working +""" + +from __future__ import annotations + +import os +import time +from pathlib import Path + +import pytest +from fastmcp import FastMCP + +from aden_tools import file_state_cache +from aden_tools.file_ops import register_file_tools + + +def _find_tool(mcp: FastMCP, name: str): + """Pull a tool function out of an MCP registration for direct testing.""" + # fastmcp stores tools in a ToolManager. We reach through it to grab + # the underlying callable so tests can invoke tools directly without + # a full MCP round-trip. + manager = getattr(mcp, "_tool_manager", None) or getattr(mcp, "tool_manager", None) + assert manager is not None, "could not locate fastmcp tool manager" + tools = getattr(manager, "_tools", None) or getattr(manager, "tools", None) + assert tools is not None, "could not locate fastmcp tools dict" + tool = tools[name] + return getattr(tool, "fn", None) or getattr(tool, "func", None) or tool + + +@pytest.fixture +def sandbox(tmp_path: Path): + """A sandbox directory the tools are allowed to read/write within.""" + file_state_cache.reset_all() + return tmp_path + + +@pytest.fixture +def tools(sandbox: Path): + """Register file_ops onto a fresh FastMCP and return the tool callables.""" + mcp = FastMCP("test-server") + + def resolve(path: str) -> str: + # Absolute paths under the sandbox are fine; relative paths + # resolve against the sandbox root. + if os.path.isabs(path): + return os.path.abspath(path) + return str(sandbox / path) + + register_file_tools(mcp, resolve_path=resolve) + + return { + "read_file": _find_tool(mcp, "read_file"), + "write_file": _find_tool(mcp, "write_file"), + "edit_file": _find_tool(mcp, "edit_file"), + "hashline_edit": _find_tool(mcp, "hashline_edit"), + } + + +# --------------------------------------------------------------------------- +# Cache primitives +# --------------------------------------------------------------------------- + + +def test_check_fresh_returns_unread_when_never_recorded(sandbox: Path): + target = sandbox / "nope.txt" + target.write_text("hi") + result = file_state_cache.check_fresh(None, str(target)) + assert result.status is file_state_cache.Freshness.UNREAD + + +def test_record_then_check_returns_fresh(sandbox: Path): + target = sandbox / "a.txt" + target.write_text("one") + file_state_cache.record_read(None, str(target), content_bytes=b"one") + result = file_state_cache.check_fresh(None, str(target)) + assert result.status is file_state_cache.Freshness.FRESH + + +def test_external_write_makes_check_return_stale(sandbox: Path): + target = sandbox / "b.txt" + target.write_text("original") + file_state_cache.record_read(None, str(target), content_bytes=b"original") + + # Simulate an external editor save with different content. Sleep + # briefly to ensure mtime moves (some filesystems have 1s resolution + # but most Linux fs have ns; this is belt-and-braces). 
+ time.sleep(0.01) + target.write_text("hijacked by the user") + os.utime(str(target), None) # bump mtime in case the write was too fast + + result = file_state_cache.check_fresh(None, str(target)) + assert result.status is file_state_cache.Freshness.STALE + assert "changed on disk" in result.detail or "differs" in result.detail + + +def test_identical_content_rewrite_stays_fresh(sandbox: Path): + """Editors that rewrite a file without changing its bytes shouldn't + be reported as stale even though mtime moved.""" + target = sandbox / "c.txt" + target.write_text("same") + file_state_cache.record_read(None, str(target), content_bytes=b"same") + + time.sleep(0.01) + target.write_text("same") # different mtime, same content + os.utime(str(target), None) + + result = file_state_cache.check_fresh(None, str(target)) + assert result.status is file_state_cache.Freshness.FRESH + + +def test_agent_scopes_are_isolated(sandbox: Path): + target = sandbox / "d.txt" + target.write_text("xyz") + file_state_cache.record_read("agent-A", str(target), content_bytes=b"xyz") + + # Another agent hasn't read this file yet. + result = file_state_cache.check_fresh("agent-B", str(target)) + assert result.status is file_state_cache.Freshness.UNREAD + + +# --------------------------------------------------------------------------- +# file_ops integration +# --------------------------------------------------------------------------- + + +def test_edit_file_refuses_without_prior_read(sandbox: Path, tools): + target = sandbox / "e.py" + target.write_text("print('hello')\n") + # Clear the cache first so there's definitely no recorded read. + file_state_cache.reset_all() + + result = tools["edit_file"]("e.py", "hello", "world") + assert "Refusing to edit" in result + assert "read_file" in result + + +def test_edit_file_proceeds_after_read(sandbox: Path, tools): + target = sandbox / "f.py" + target.write_text("print('hello')\n") + file_state_cache.reset_all() + + tools["read_file"]("f.py") + result = tools["edit_file"]("f.py", "hello", "world") + assert "Replaced" in result + assert target.read_text() == "print('world')\n" + + +def test_edit_file_refuses_when_file_changed_between_read_and_edit( + sandbox: Path, tools +): + target = sandbox / "g.py" + target.write_text("print('hello')\n") + file_state_cache.reset_all() + + tools["read_file"]("g.py") + + # Simulate the user editing the file outside the agent. 
+ time.sleep(0.01) + target.write_text("print('bye')\n") + os.utime(str(target), None) + + result = tools["edit_file"]("g.py", "hello", "world") + assert "Refusing to edit" in result + assert "Re-read" in result + + +def test_write_file_allowed_for_new_file_without_prior_read(sandbox: Path, tools): + file_state_cache.reset_all() + result = tools["write_file"]("brand_new.txt", "first contents\n") + assert "Created" in result + assert (sandbox / "brand_new.txt").read_text() == "first contents\n" + + +def test_write_file_refuses_overwrite_without_prior_read(sandbox: Path, tools): + target = sandbox / "existing.txt" + target.write_text("do not clobber\n") + file_state_cache.reset_all() + + result = tools["write_file"]("existing.txt", "clobbered\n") + assert "Refusing to overwrite" in result + assert target.read_text() == "do not clobber\n" # unchanged + + +def test_chained_edits_in_same_turn_do_not_self_invalidate( + sandbox: Path, tools +): + target = sandbox / "chained.py" + target.write_text("print('a')\nprint('b')\n") + file_state_cache.reset_all() + + tools["read_file"]("chained.py") + r1 = tools["edit_file"]("chained.py", "a", "A") + assert "Replaced" in r1 + # Immediate second edit must NOT trip the stale guard because + # edit_file re-records the post-write state. + r2 = tools["edit_file"]("chained.py", "b", "B") + assert "Replaced" in r2 + assert target.read_text() == "print('A')\nprint('B')\n" diff --git a/tools/tests/tools/test_file_ops_hashline.py b/tools/tests/tools/test_file_ops_hashline.py index 365b12ac..22cebc65 100644 --- a/tools/tests/tools/test_file_ops_hashline.py +++ b/tools/tests/tools/test_file_ops_hashline.py @@ -11,6 +11,21 @@ from fastmcp import FastMCP from aden_tools.hashline import compute_line_hash +@pytest.fixture(autouse=True) +def _bypass_stale_edit_guard(): + """These tests exercise edit logic directly without a prior read_file, + so the Gap 4 stale-edit guard would reject every call. Force + check_fresh to always return FRESH here; the cache itself is + covered by ``tools/tests/test_file_state_cache.py``. + """ + from aden_tools.file_state_cache import FreshResult, Freshness + with patch( + "aden_tools.file_ops.check_fresh", + return_value=FreshResult(Freshness.FRESH), + ): + yield + + def _anchor(line_num, line_text): """Build an anchor string N:hhhh.""" return f"{line_num}:{compute_line_hash(line_text)}" diff --git a/tools/tests/tools/test_file_system_toolkits.py b/tools/tests/tools/test_file_system_toolkits.py index df4f191e..5c547b35 100644 --- a/tools/tests/tools/test_file_system_toolkits.py +++ b/tools/tests/tools/test_file_system_toolkits.py @@ -1,5 +1,6 @@ """Tests for file_system_toolkits tools (FastMCP).""" +import asyncio import json import os from unittest.mock import patch @@ -8,6 +9,22 @@ import pytest from fastmcp import FastMCP +@pytest.fixture(autouse=True) +def _bypass_stale_edit_guard(): + """These tests exercise edit logic directly without a prior read_file, + so the Gap 4 stale-edit guard would reject every call. Force + check_fresh to always return FRESH here; the cache itself is + covered by ``tools/tests/test_file_state_cache.py``. + """ + from aden_tools.file_state_cache import FreshResult, Freshness + with patch( + "aden_tools.tools.file_system_toolkits.hashline_edit." 
+ "hashline_edit.check_fresh", + return_value=FreshResult(Freshness.FRESH), + ): + yield + + @pytest.fixture def mcp(): """Create a FastMCP instance.""" @@ -336,51 +353,222 @@ class TestExecuteCommandTool: register_tools(mcp) return mcp._tool_manager._tools["execute_command_tool"].fn - def test_execute_simple_command(self, execute_command_fn, mock_workspace, mock_secure_path): + async def test_execute_simple_command( + self, execute_command_fn, mock_workspace, mock_secure_path + ): """Executing a simple command returns output.""" - result = execute_command_fn(command="echo 'Hello World'", **mock_workspace) + result = await execute_command_fn(command="echo 'Hello World'", **mock_workspace) assert result["success"] is True assert result["return_code"] == 0 assert "Hello World" in result["stdout"] - def test_execute_failing_command(self, execute_command_fn, mock_workspace, mock_secure_path): + async def test_execute_failing_command( + self, execute_command_fn, mock_workspace, mock_secure_path + ): """Executing a failing command returns non-zero exit code.""" - result = execute_command_fn(command="exit 1", **mock_workspace) + result = await execute_command_fn(command="exit 1", **mock_workspace) assert result["success"] is True assert result["return_code"] == 1 - def test_execute_command_with_stderr( + async def test_execute_command_with_stderr( self, execute_command_fn, mock_workspace, mock_secure_path ): """Executing a command that writes to stderr captures it.""" - result = execute_command_fn(command="echo 'error message' >&2", **mock_workspace) + result = await execute_command_fn( + command="echo 'error message' >&2", **mock_workspace + ) assert result["success"] is True assert "error message" in result.get("stderr", "") - def test_execute_command_list_files( + async def test_execute_command_list_files( self, execute_command_fn, mock_workspace, mock_secure_path, tmp_path ): """Executing ls command lists files.""" # Create a test file (tmp_path / "testfile.txt").write_text("content", encoding="utf-8") - result = execute_command_fn(command=f"ls {tmp_path}", **mock_workspace) + result = await execute_command_fn(command=f"ls {tmp_path}", **mock_workspace) assert result["success"] is True assert result["return_code"] == 0 assert "testfile.txt" in result["stdout"] - def test_execute_command_with_pipe(self, execute_command_fn, mock_workspace, mock_secure_path): + async def test_execute_command_with_pipe( + self, execute_command_fn, mock_workspace, mock_secure_path + ): """Executing a command with pipe works correctly.""" - result = execute_command_fn(command="echo 'hello world' | tr 'a-z' 'A-Z'", **mock_workspace) + result = await execute_command_fn( + command="echo 'hello world' | tr 'a-z' 'A-Z'", **mock_workspace + ) assert result["success"] is True assert result["return_code"] == 0 assert "HELLO WORLD" in result["stdout"] + # ── Gap 3: async, per-call timeout, background jobs ────────────── + + @pytest.fixture + def bash_output_fn(self, mcp): + from aden_tools.tools.file_system_toolkits.execute_command_tool import ( + register_tools, + ) + + register_tools(mcp) + return mcp._tool_manager._tools["bash_output"].fn + + @pytest.fixture + def bash_kill_fn(self, mcp): + from aden_tools.tools.file_system_toolkits.execute_command_tool import ( + register_tools, + ) + + register_tools(mcp) + return mcp._tool_manager._tools["bash_kill"].fn + + async def test_per_call_timeout_overrides_default( + self, execute_command_fn, mock_workspace, mock_secure_path + ): + """A per-call timeout under the default 
kills the command early.""" + import time + + start = time.monotonic() + result = await execute_command_fn( + command="sleep 10", + timeout_seconds=1, + **mock_workspace, + ) + elapsed = time.monotonic() - start + + assert result.get("timed_out") is True + assert "1 seconds" in result.get("error", "") + # Must include the watchdog grace but stay well under 10s. + assert elapsed < 5, f"timeout did not kill the command promptly ({elapsed:.2f}s)" + + async def test_timeout_is_clamped_upwards( + self, execute_command_fn, mock_workspace, mock_secure_path + ): + """A timeout above the 600s ceiling is silently clamped.""" + # We don't actually sleep 600s - we just run a quick command + # with a nonsense timeout to prove the clamp doesn't raise. + result = await execute_command_fn( + command="echo fast", + timeout_seconds=99999, + **mock_workspace, + ) + assert result["success"] is True + assert "fast" in result["stdout"] + + async def test_event_loop_unblocked_while_command_runs( + self, execute_command_fn, mock_workspace, mock_secure_path + ): + """The event loop keeps servicing other tasks while a bash + command is running, unlike the old blocking subprocess.run.""" + ticks = 0 + + async def ticker(): + nonlocal ticks + for _ in range(20): + await asyncio.sleep(0.05) + ticks += 1 + + ticker_task = asyncio.create_task(ticker()) + # A 0.5s command: if the event loop were blocked, ticks would + # stay at 0 until it returned. We expect several ticks to land. + result = await execute_command_fn(command="sleep 0.5", **mock_workspace) + await ticker_task + + assert result["success"] is True + assert ticks >= 5, ( + f"event loop looked blocked during subprocess " + f"(only {ticks} ticks in 1s)" + ) + + async def test_background_job_start_poll_and_complete( + self, + execute_command_fn, + bash_output_fn, + mock_workspace, + mock_secure_path, + ): + """A run_in_background job can be started, polled, and reports + its exit status once the command finishes.""" + start_result = await execute_command_fn( + command=( + "python -c 'import time,sys;" + "print(\"one\");sys.stdout.flush();time.sleep(0.1);" + "print(\"two\");sys.stdout.flush();time.sleep(0.1);" + "print(\"three\")'" + ), + run_in_background=True, + **mock_workspace, + ) + assert start_result["background"] is True + job_id = start_result["id"] + + # Wait for the command to finish. + deadline = asyncio.get_event_loop().time() + 5.0 + seen_text = "" + while asyncio.get_event_loop().time() < deadline: + poll = await bash_output_fn(id=job_id, **mock_workspace) + seen_text += poll["stdout"] + if poll["status"].startswith("exited"): + break + await asyncio.sleep(0.05) + + assert "one" in seen_text + assert "two" in seen_text + assert "three" in seen_text + assert poll["status"] == "exited(0)" + + async def test_background_job_kill( + self, + execute_command_fn, + bash_output_fn, + bash_kill_fn, + mock_workspace, + mock_secure_path, + ): + """bash_kill terminates a long-running background job.""" + start_result = await execute_command_fn( + command="sleep 30", + run_in_background=True, + **mock_workspace, + ) + job_id = start_result["id"] + + kill_result = await bash_kill_fn(id=job_id, **mock_workspace) + assert kill_result["id"] == job_id + assert ( + "terminated" in kill_result["status"] + or "killed" in kill_result["status"] + ) + + # Job id should be deregistered after kill. 
+ poll = await bash_output_fn(id=job_id, **mock_workspace) + assert "no background job" in poll.get("error", "") + + async def test_bash_output_isolated_across_agents( + self, execute_command_fn, bash_output_fn, mock_secure_path + ): + """Agent A's job id is not reachable from agent B.""" + start = await execute_command_fn( + command="sleep 5", + run_in_background=True, + agent_id="agent-A", + ) + poll_b = await bash_output_fn(id=start["id"], agent_id="agent-B") + assert "no background job" in poll_b.get("error", "") + + # Clean up. + from aden_tools.tools.file_system_toolkits.execute_command_tool import ( + background_jobs, + ) + + await background_jobs.clear_agent("agent-A") + class TestApplyDiffTool: """Tests for apply_diff tool.""" diff --git a/tools/tests/tools/test_hashline_edit.py b/tools/tests/tools/test_hashline_edit.py index b16b9e81..6e8aa9bd 100644 --- a/tools/tests/tools/test_hashline_edit.py +++ b/tools/tests/tools/test_hashline_edit.py @@ -11,6 +11,22 @@ from fastmcp import FastMCP from aden_tools.hashline import compute_line_hash +@pytest.fixture(autouse=True) +def _bypass_stale_edit_guard(): + """These tests exercise edit logic directly without a prior read_file, + so the Gap 4 file-state cache would reject every single call. Patch + the imported ``check_fresh`` symbol to always return FRESH here; the + cache itself is covered by ``tests/test_file_state_cache.py``. + """ + from aden_tools.file_state_cache import FreshResult, Freshness + with patch( + "aden_tools.tools.file_system_toolkits.hashline_edit." + "hashline_edit.check_fresh", + return_value=FreshResult(Freshness.FRESH), + ): + yield + + @pytest.fixture def mcp(): """Create a FastMCP instance."""
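The stale-edit guard wired into hashline_edit.py and exercised by tools/tests/test_file_state_cache.py relies on aden_tools.file_state_cache, whose implementation is not part of this hunk. Below is a minimal sketch of the shape those call sites assume. The names (Freshness, FreshResult, check_fresh, record_read, reset_all, content_bytes) come from the imports and calls in this diff; the internals (an in-memory dict keyed by agent, an mtime/size fast path, a SHA-256 fallback so a byte-identical rewrite still counts as fresh) are assumptions inferred from the tests, not the real module.

```python
# Sketch of the file_state_cache API assumed by this patch - illustrative only.
from __future__ import annotations

import hashlib
import os
from dataclasses import dataclass
from enum import Enum, auto


class Freshness(Enum):
    UNREAD = auto()  # no read recorded for this (agent, path)
    FRESH = auto()   # on-disk content matches the last recorded read
    STALE = auto()   # file changed on disk since the last recorded read


@dataclass
class FreshResult:
    status: Freshness
    detail: str = ""


@dataclass
class _Record:
    digest: str
    size: int
    mtime_ns: int


# (agent_id or None) -> {absolute path -> last recorded read}
_cache: dict[str | None, dict[str, _Record]] = {}


def reset_all() -> None:
    _cache.clear()


def record_read(agent_id: str | None, path: str, *, content_bytes: bytes) -> None:
    st = os.stat(path)
    _cache.setdefault(agent_id, {})[os.path.abspath(path)] = _Record(
        digest=hashlib.sha256(content_bytes).hexdigest(),
        size=st.st_size,
        mtime_ns=st.st_mtime_ns,
    )


def check_fresh(agent_id: str | None, path: str) -> FreshResult:
    rec = _cache.get(agent_id, {}).get(os.path.abspath(path))
    if rec is None:
        return FreshResult(Freshness.UNREAD, "no recorded read for this file")
    st = os.stat(path)
    if st.st_mtime_ns == rec.mtime_ns and st.st_size == rec.size:
        return FreshResult(Freshness.FRESH)
    # mtime moved: fall back to hashing so an editor save that did not
    # change any bytes is still reported as fresh.
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    if digest == rec.digest:
        return FreshResult(Freshness.FRESH)
    return FreshResult(
        Freshness.STALE, "file content changed on disk since the last read"
    )
```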