Merge pull request #107 from uttam-salamander/test/add-unit-tests-security-plan-example

test: add unit tests for security, plan, and example_tool modules
This commit is contained in:
RichardTang-Aden
2026-01-23 14:50:27 -08:00
committed by GitHub
3 changed files with 928 additions and 0 deletions
+125
View File
@@ -0,0 +1,125 @@
"""Tests for example_tool - A simple text processing tool."""
import pytest
from fastmcp import FastMCP
from aden_tools.tools.example_tool.example_tool import register_tools
@pytest.fixture
def example_tool_fn(mcp: FastMCP):
    """Provide the callable behind the registered ``example_tool``.

    Registers the example tools on the supplied server and then looks the
    tool up by name in the server's tool manager. The ``mcp`` argument is
    expected to come from a fixture defined outside this file.
    """
    register_tools(mcp)
    registered_tools = mcp._tool_manager._tools
    return registered_tools["example_tool"].fn
class TestExampleTool:
    """Behavioural checks for the ``example_tool`` callable."""

    def test_valid_message(self, example_tool_fn):
        """A plain message comes back untouched."""
        text = "Hello, World!"
        assert example_tool_fn(message=text) == text

    def test_uppercase_true(self, example_tool_fn):
        """Passing uppercase=True upper-cases the message."""
        output = example_tool_fn(message="hello", uppercase=True)
        assert output == "HELLO"

    def test_uppercase_false(self, example_tool_fn):
        """Passing uppercase=False leaves the case as given."""
        text = "Hello"
        assert example_tool_fn(message=text, uppercase=False) == text

    def test_repeat_multiple(self, example_tool_fn):
        """repeat=3 yields the message three times, space separated."""
        output = example_tool_fn(message="Hi", repeat=3)
        assert output == "Hi Hi Hi"

    def test_repeat_default(self, example_tool_fn):
        """repeat=1 yields the message exactly once."""
        text = "Hello"
        assert example_tool_fn(message=text, repeat=1) == text

    def test_uppercase_and_repeat_combined(self, example_tool_fn):
        """Upper-casing and repetition compose."""
        output = example_tool_fn(message="hi", uppercase=True, repeat=2)
        assert output == "HI HI"

    def test_empty_message_error(self, example_tool_fn):
        """An empty message produces an error string naming the 1-1000 range."""
        outcome = example_tool_fn(message="")
        for fragment in ("Error", "1-1000"):
            assert fragment in outcome

    def test_message_too_long_error(self, example_tool_fn):
        """A 1001-character message produces an error string."""
        oversized = "x" * 1001
        outcome = example_tool_fn(message=oversized)
        for fragment in ("Error", "1-1000"):
            assert fragment in outcome

    def test_message_at_max_length(self, example_tool_fn):
        """A message of exactly 1000 characters is accepted unchanged."""
        longest_allowed = "x" * 1000
        assert example_tool_fn(message=longest_allowed) == longest_allowed

    def test_repeat_zero_error(self, example_tool_fn):
        """repeat=0 produces an error string naming the 1-10 range."""
        outcome = example_tool_fn(message="Hi", repeat=0)
        for fragment in ("Error", "1-10"):
            assert fragment in outcome

    def test_repeat_eleven_error(self, example_tool_fn):
        """repeat=11 produces an error string naming the 1-10 range."""
        outcome = example_tool_fn(message="Hi", repeat=11)
        for fragment in ("Error", "1-10"):
            assert fragment in outcome

    def test_repeat_at_max(self, example_tool_fn):
        """repeat=10, the upper bound, is accepted."""
        output = example_tool_fn(message="Hi", repeat=10)
        assert output == " ".join("Hi" for _ in range(10))

    def test_repeat_negative_error(self, example_tool_fn):
        """A negative repeat produces an error string naming the 1-10 range."""
        outcome = example_tool_fn(message="Hi", repeat=-1)
        for fragment in ("Error", "1-10"):
            assert fragment in outcome

    def test_whitespace_only_message(self, example_tool_fn):
        """A whitespace-only message counts as non-empty and is returned as is."""
        blank = " "
        assert example_tool_fn(message=blank) == blank

    def test_special_characters_in_message(self, example_tool_fn):
        """Punctuation and symbols pass through unchanged."""
        text = "Hello! @#$%^&*()"
        assert example_tool_fn(message=text) == text

    def test_unicode_message(self, example_tool_fn):
        """Non-ASCII text and emoji pass through unchanged."""
        text = "Hello 世界 🌍"
        assert example_tool_fn(message=text) == text

    def test_unicode_uppercase(self, example_tool_fn):
        """Upper-casing applies to accented characters too."""
        output = example_tool_fn(message="café", uppercase=True)
        assert output == "CAFÉ"
+215
View File
@@ -0,0 +1,215 @@
"""Tests for security.py - get_secure_path() function."""
import os
import pytest
from unittest.mock import patch
class TestGetSecurePath:
    """Tests for get_secure_path() function.

    Every test runs with ``WORKSPACES_DIR`` patched to a fresh temporary
    directory, so no test touches a real workspace tree.
    """

    @pytest.fixture(autouse=True)
    def setup_workspaces_dir(self, tmp_path):
        """Patch WORKSPACES_DIR to use a temp directory while a test runs."""
        self.workspaces_dir = tmp_path / "workspaces"
        self.workspaces_dir.mkdir()
        # Sandbox root the tests expect for the IDs returned by the ``ids``
        # fixture; built once here instead of being repeated in every test.
        self.session_dir = (
            self.workspaces_dir / "test-workspace" / "test-agent" / "test-session"
        )
        with patch(
            "aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR",
            str(self.workspaces_dir),
        ):
            yield

    @pytest.fixture
    def ids(self):
        """Standard workspace, agent, and session IDs."""
        return {
            "workspace_id": "test-workspace",
            "agent_id": "test-agent",
            "session_id": "test-session",
        }

    @staticmethod
    def _symlink_or_skip(link, target):
        """Create ``link`` pointing at ``target``; skip the test if that fails.

        Symlink creation can be refused by the platform or by missing
        privileges. That is an environment limitation, not a defect in the
        code under test, so it is reported as a skip instead of an error.
        """
        try:
            link.symlink_to(target)
        except (OSError, NotImplementedError) as exc:
            pytest.skip(f"symlinks cannot be created in this environment: {exc}")

    def test_creates_session_directory(self, ids):
        """Session directory is created if it doesn't exist."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        # Only the directory-creating side effect matters here, so the
        # returned path is deliberately not captured.
        get_secure_path("file.txt", **ids)

        assert self.session_dir.exists()
        assert self.session_dir.is_dir()

    def test_relative_path_resolved(self, ids):
        """Relative paths are resolved within session directory."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("subdir/file.txt", **ids)

        expected = self.session_dir / "subdir" / "file.txt"
        assert result == str(expected)

    def test_absolute_path_treated_as_relative(self, ids):
        """Absolute paths are treated as relative to session root."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("/etc/passwd", **ids)

        expected = self.session_dir / "etc" / "passwd"
        assert result == str(expected)

    def test_path_traversal_blocked(self, ids):
        """Path traversal attempts are blocked."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        with pytest.raises(ValueError, match="outside the session sandbox"):
            get_secure_path("../../../etc/passwd", **ids)

    def test_path_traversal_with_nested_dotdot(self, ids):
        """Nested path traversal with valid prefix is blocked."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        with pytest.raises(ValueError, match="outside the session sandbox"):
            get_secure_path("valid/../../..", **ids)

    def test_path_traversal_absolute_with_dotdot(self, ids):
        """Absolute path with traversal is blocked."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        with pytest.raises(ValueError, match="outside the session sandbox"):
            get_secure_path("/foo/../../../etc/passwd", **ids)

    def test_missing_workspace_id_raises(self, ids):
        """Missing workspace_id raises ValueError."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        with pytest.raises(ValueError, match="workspace_id.*required"):
            get_secure_path(
                "file.txt",
                workspace_id="",
                agent_id=ids["agent_id"],
                session_id=ids["session_id"],
            )

    def test_missing_agent_id_raises(self, ids):
        """Missing agent_id raises ValueError."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        with pytest.raises(ValueError, match="agent_id.*required"):
            get_secure_path(
                "file.txt",
                workspace_id=ids["workspace_id"],
                agent_id="",
                session_id=ids["session_id"],
            )

    def test_missing_session_id_raises(self, ids):
        """Missing session_id raises ValueError."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        with pytest.raises(ValueError, match="session_id.*required"):
            get_secure_path(
                "file.txt",
                workspace_id=ids["workspace_id"],
                agent_id=ids["agent_id"],
                session_id="",
            )

    def test_none_ids_raise(self):
        """A None workspace_id raises ValueError."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        with pytest.raises(ValueError):
            get_secure_path(
                "file.txt",
                workspace_id=None,
                agent_id="agent",
                session_id="session",
            )

    def test_simple_filename(self, ids):
        """Simple filename resolves correctly."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("file.txt", **ids)

        expected = self.session_dir / "file.txt"
        assert result == str(expected)

    def test_current_dir_path(self, ids):
        """Current directory path (.) resolves to session dir."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path(".", **ids)

        assert result == str(self.session_dir)

    def test_dot_slash_path(self, ids):
        """Dot-slash paths resolve correctly."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("./subdir/file.txt", **ids)

        expected = self.session_dir / "subdir" / "file.txt"
        assert result == str(expected)

    def test_deeply_nested_path(self, ids):
        """Deeply nested paths work correctly."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("a/b/c/d/e/file.txt", **ids)

        expected = self.session_dir / "a" / "b" / "c" / "d" / "e" / "file.txt"
        assert result == str(expected)

    def test_path_with_spaces(self, ids):
        """Paths with spaces work correctly."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("my folder/my file.txt", **ids)

        expected = self.session_dir / "my folder" / "my file.txt"
        assert result == str(expected)

    def test_path_with_special_characters(self, ids):
        """Paths with special characters work correctly."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("file-name_v2.0.txt", **ids)

        expected = self.session_dir / "file-name_v2.0.txt"
        assert result == str(expected)

    def test_empty_path(self, ids):
        """Empty string path resolves to session directory."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        result = get_secure_path("", **ids)

        assert result == str(self.session_dir)

    def test_symlink_within_sandbox_works(self, ids):
        """Symlinks that stay within the sandbox are allowed."""
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        # Create session directory structure
        self.session_dir.mkdir(parents=True, exist_ok=True)

        # Create a target file and a symlink to it
        target_file = self.session_dir / "target.txt"
        target_file.write_text("content")
        symlink_path = self.session_dir / "link_to_target"
        self._symlink_or_skip(symlink_path, target_file)

        # Path through symlink should resolve
        result = get_secure_path("link_to_target", **ids)
        assert result == str(symlink_path)

    def test_symlink_escape_detected_with_realpath(self, ids):
        """Symlinks pointing outside sandbox can be detected using realpath.

        Note: get_secure_path uses abspath (not realpath), so it validates the
        lexical path. To fully protect against symlink attacks, callers should
        verify realpath(result) is still within the sandbox before file I/O.
        This test documents that pattern.
        """
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        # Create session directory
        self.session_dir.mkdir(parents=True, exist_ok=True)

        # Create a symlink inside session pointing outside
        outside_target = self.workspaces_dir / "outside_file.txt"
        outside_target.write_text("sensitive data")
        symlink_path = self.session_dir / "escape_link"
        self._symlink_or_skip(symlink_path, outside_target)

        # get_secure_path accepts the lexical path (symlink is inside session)
        result = get_secure_path("escape_link", **ids)
        assert result == str(symlink_path)

        # However, realpath reveals the escape - callers should check this.
        # Both sides are canonicalised so the comparison cannot succeed merely
        # because the temp directory itself sits behind a symlinked component.
        real_path = os.path.realpath(result)
        real_session_dir = os.path.realpath(self.session_dir)
        assert os.path.commonpath([real_path, real_session_dir]) != real_session_dir
+588
View File
@@ -0,0 +1,588 @@
"""Tests for plan.py - Plan enums and Pydantic models."""
import json
import pytest
from framework.graph.plan import (
ActionType,
StepStatus,
ApprovalDecision,
JudgmentAction,
ExecutionStatus,
ActionSpec,
PlanStep,
Plan,
)
class TestActionTypeEnum:
    """Checks on the members of the ActionType enum."""

    def test_action_type_values_exist(self):
        """Each of the five members carries its expected string value."""
        expectations = (
            (ActionType.LLM_CALL, "llm_call"),
            (ActionType.TOOL_USE, "tool_use"),
            (ActionType.SUB_GRAPH, "sub_graph"),
            (ActionType.FUNCTION, "function"),
            (ActionType.CODE_EXECUTION, "code_execution"),
        )
        for member, wanted in expectations:
            assert member.value == wanted

    def test_action_type_count(self):
        """The enum defines exactly five members."""
        member_total = len(ActionType)
        assert member_total == 5

    def test_action_type_string_enum(self):
        """Members are ``str`` instances and compare equal to their value."""
        member = ActionType.LLM_CALL
        assert isinstance(member, str)
        assert member == "llm_call"
class TestStepStatusEnum:
    """Checks on the StepStatus enum and on reassigning a step's status."""

    @staticmethod
    def _step_in(status):
        """Build a minimal function step that starts out in ``status``."""
        return PlanStep(
            id="step_1",
            description="Test step",
            action=ActionSpec(action_type=ActionType.FUNCTION),
            status=status,
        )

    def test_step_status_values_exist(self):
        """Each of the seven members carries its expected string value."""
        expectations = (
            (StepStatus.PENDING, "pending"),
            (StepStatus.AWAITING_APPROVAL, "awaiting_approval"),
            (StepStatus.IN_PROGRESS, "in_progress"),
            (StepStatus.COMPLETED, "completed"),
            (StepStatus.FAILED, "failed"),
            (StepStatus.SKIPPED, "skipped"),
            (StepStatus.REJECTED, "rejected"),
        )
        for member, wanted in expectations:
            assert member.value == wanted

    def test_step_status_count(self):
        """The enum defines exactly seven members."""
        member_total = len(StepStatus)
        assert member_total == 7

    def test_step_status_transition_pending_to_in_progress(self):
        """A PENDING step can be reassigned to IN_PROGRESS."""
        plan_step = self._step_in(StepStatus.PENDING)
        plan_step.status = StepStatus.IN_PROGRESS
        assert plan_step.status == StepStatus.IN_PROGRESS

    def test_step_status_transition_in_progress_to_completed(self):
        """An IN_PROGRESS step can be reassigned to COMPLETED."""
        plan_step = self._step_in(StepStatus.IN_PROGRESS)
        plan_step.status = StepStatus.COMPLETED
        assert plan_step.status == StepStatus.COMPLETED

    def test_step_status_transition_in_progress_to_failed(self):
        """An IN_PROGRESS step can be reassigned to FAILED."""
        plan_step = self._step_in(StepStatus.IN_PROGRESS)
        plan_step.status = StepStatus.FAILED
        assert plan_step.status == StepStatus.FAILED
class TestApprovalDecisionEnum:
    """Checks on the members of the ApprovalDecision enum."""

    def test_approval_decision_values_exist(self):
        """Each of the four members carries its expected string value."""
        expectations = (
            (ApprovalDecision.APPROVE, "approve"),
            (ApprovalDecision.REJECT, "reject"),
            (ApprovalDecision.MODIFY, "modify"),
            (ApprovalDecision.ABORT, "abort"),
        )
        for member, wanted in expectations:
            assert member.value == wanted

    def test_approval_decision_count(self):
        """The enum defines exactly four members."""
        member_total = len(ApprovalDecision)
        assert member_total == 4
class TestJudgmentActionEnum:
    """Checks on the members of the JudgmentAction enum."""

    def test_judgment_action_values_exist(self):
        """Each of the four members carries its expected string value."""
        expectations = (
            (JudgmentAction.ACCEPT, "accept"),
            (JudgmentAction.RETRY, "retry"),
            (JudgmentAction.REPLAN, "replan"),
            (JudgmentAction.ESCALATE, "escalate"),
        )
        for member, wanted in expectations:
            assert member.value == wanted

    def test_judgment_action_count(self):
        """The enum defines exactly four members."""
        member_total = len(JudgmentAction)
        assert member_total == 4
class TestExecutionStatusEnum:
    """Checks on the members of the ExecutionStatus enum."""

    def test_execution_status_values_exist(self):
        """Each of the seven members carries its expected string value."""
        expectations = (
            (ExecutionStatus.COMPLETED, "completed"),
            (ExecutionStatus.AWAITING_APPROVAL, "awaiting_approval"),
            (ExecutionStatus.NEEDS_REPLAN, "needs_replan"),
            (ExecutionStatus.NEEDS_ESCALATION, "needs_escalation"),
            (ExecutionStatus.REJECTED, "rejected"),
            (ExecutionStatus.ABORTED, "aborted"),
            (ExecutionStatus.FAILED, "failed"),
        )
        for member, wanted in expectations:
            assert member.value == wanted

    def test_execution_status_count(self):
        """The enum defines exactly seven members."""
        member_total = len(ExecutionStatus)
        assert member_total == 7
class TestPlanStepIsReady:
    """Checks on PlanStep.is_ready() for various dependency/status mixes."""

    @staticmethod
    def _make_step(step_id, description, dependencies, status):
        """Build a function step with the given dependencies and status."""
        return PlanStep(
            id=step_id,
            description=description,
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=dependencies,
            status=status,
        )

    def test_plan_step_is_ready_no_deps(self):
        """A PENDING step without dependencies is ready."""
        candidate = self._make_step("step_1", "Test step", [], StepStatus.PENDING)
        assert candidate.is_ready(set()) is True

    def test_plan_step_is_ready_deps_met(self):
        """A PENDING step is ready once its single dependency is completed."""
        candidate = self._make_step(
            "step_2", "Second step", ["step_1"], StepStatus.PENDING
        )
        finished = {"step_1"}
        assert candidate.is_ready(finished) is True

    def test_plan_step_not_ready_deps_missing(self):
        """A step is not ready while any dependency is still outstanding."""
        candidate = self._make_step(
            "step_2", "Second step", ["step_1", "step_3"], StepStatus.PENDING
        )
        # step_3 has not completed yet, only step_1 has.
        finished = {"step_1"}
        assert candidate.is_ready(finished) is False

    def test_plan_step_not_ready_wrong_status(self):
        """An IN_PROGRESS step is not ready even without dependencies."""
        candidate = self._make_step(
            "step_1", "Test step", [], StepStatus.IN_PROGRESS
        )
        assert candidate.is_ready(set()) is False

    def test_plan_step_not_ready_completed_status(self):
        """A COMPLETED step is not ready to run again."""
        candidate = self._make_step("step_1", "Test step", [], StepStatus.COMPLETED)
        assert candidate.is_ready(set()) is False

    def test_plan_step_is_ready_multiple_deps_all_met(self):
        """A step with several dependencies is ready when all are completed."""
        candidate = self._make_step(
            "step_4",
            "Fourth step",
            ["step_1", "step_2", "step_3"],
            StepStatus.PENDING,
        )
        finished = {"step_1", "step_2", "step_3"}
        assert candidate.is_ready(finished) is True
class TestPlanFromJson:
    """Exercises Plan.from_json() with string, dict, wrapped and invalid input."""

    @staticmethod
    def _plan_dict(steps, plan_id="plan_1", goal_id="goal_1", description="Test plan"):
        """Assemble a raw plan payload around the given ``steps`` value."""
        return {
            "id": plan_id,
            "goal_id": goal_id,
            "description": description,
            "steps": steps,
        }

    @staticmethod
    def _step_dict(step_id, description, action):
        """Assemble a raw step payload."""
        return {"id": step_id, "description": description, "action": action}

    def test_plan_from_json_string(self):
        """A JSON string is parsed into a Plan."""
        action = {"action_type": "function", "function_name": "do_something"}
        payload = self._plan_dict([self._step_dict("step_1", "First step", action)])

        parsed = Plan.from_json(json.dumps(payload))

        assert parsed.id == "plan_1"
        assert parsed.goal_id == "goal_1"
        assert len(parsed.steps) == 1
        assert parsed.steps[0].id == "step_1"

    def test_plan_from_json_dict(self):
        """A dict is accepted directly, without a JSON round trip."""
        action = {"action_type": "function"}
        payload = self._plan_dict([self._step_dict("step_1", "First step", action)])

        parsed = Plan.from_json(payload)

        assert parsed.id == "plan_1"
        assert parsed.goal_id == "goal_1"

    def test_plan_from_json_nested_plan_key(self):
        """A payload wrapped under a top-level "plan" key is unwrapped."""
        wrapped = {"plan": self._plan_dict([])}

        parsed = Plan.from_json(wrapped)

        assert parsed.id == "plan_1"

    def test_plan_from_json_action_type_conversion(self):
        """A string action_type becomes the matching ActionType member."""
        action = {"action_type": "llm_call", "prompt": "Hello"}
        payload = self._plan_dict([self._step_dict("step_1", "LLM step", action)])

        parsed = Plan.from_json(payload)

        assert parsed.steps[0].action.action_type == ActionType.LLM_CALL

    def test_plan_from_json_all_action_types(self):
        """Every known action type string converts to a member with that value."""
        known_types = ["llm_call", "tool_use", "sub_graph", "function", "code_execution"]
        for type_name in known_types:
            payload = self._plan_dict(
                [self._step_dict("step", "Step", {"action_type": type_name})],
                plan_id="plan",
                goal_id="goal",
                description="Test",
            )
            parsed = Plan.from_json(payload)
            assert parsed.steps[0].action.action_type.value == type_name

    def test_from_json_invalid_action_type(self):
        """An unrecognised action_type is rejected with ValueError."""
        action = {"action_type": "invalid_type"}
        payload = self._plan_dict([self._step_dict("step_1", "Invalid step", action)])

        with pytest.raises(ValueError):
            Plan.from_json(payload)

    def test_from_json_malformed_json_string(self):
        """Syntactically invalid JSON raises json.JSONDecodeError."""
        with pytest.raises(json.JSONDecodeError):
            Plan.from_json("{ invalid json }")

    def test_from_json_missing_step_id(self):
        """A step lacking its 'id' key raises KeyError."""
        step_without_id = {
            "description": "Step without ID",
            "action": {"action_type": "function"},
        }
        payload = self._plan_dict([step_without_id])

        with pytest.raises(KeyError):
            Plan.from_json(payload)

    def test_from_json_wrong_type_for_steps(self):
        """A string in place of the steps list raises AttributeError."""
        payload = self._plan_dict("not a list")

        with pytest.raises(AttributeError):
            Plan.from_json(payload)

    def test_from_json_empty_data(self):
        """An empty dict yields a plan populated with default values."""
        parsed = Plan.from_json({})

        assert parsed.id == "plan"
        assert parsed.goal_id == ""
        assert parsed.steps == []
class TestPlanMethods:
    """Checks on the query and serialization helpers of a Plan instance."""

    @pytest.fixture
    def sample_plan(self):
        """Three-step plan: one completed, one pending, one failed step."""
        done = PlanStep(
            id="step_1",
            description="First step",
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=[],
            status=StepStatus.COMPLETED,
            result={"data": "result1"},
        )
        waiting = PlanStep(
            id="step_2",
            description="Second step",
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=["step_1"],
            status=StepStatus.PENDING,
        )
        broken = PlanStep(
            id="step_3",
            description="Third step",
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=["step_1"],
            status=StepStatus.FAILED,
            error="Something went wrong",
            attempts=3,
        )
        return Plan(
            id="test_plan",
            goal_id="goal_1",
            description="Test plan",
            steps=[done, waiting, broken],
        )

    def test_plan_get_step(self, sample_plan):
        """get_step returns the step whose ID matches."""
        found = sample_plan.get_step("step_2")
        assert found is not None
        assert found.id == "step_2"
        assert found.description == "Second step"

    def test_plan_get_step_not_found(self, sample_plan):
        """get_step returns None for an unknown ID."""
        assert sample_plan.get_step("nonexistent") is None

    def test_plan_get_ready_steps(self, sample_plan):
        """Only the pending step with satisfied dependencies is ready."""
        runnable = sample_plan.get_ready_steps()
        assert len(runnable) == 1
        first_runnable = runnable[0]
        assert first_runnable.id == "step_2"

    def test_plan_get_completed_steps(self, sample_plan):
        """Only the completed step is reported as completed."""
        finished = sample_plan.get_completed_steps()
        assert len(finished) == 1
        first_finished = finished[0]
        assert first_finished.id == "step_1"

    def test_plan_is_complete_false(self, sample_plan):
        """A plan with pending or failed steps is not complete."""
        outcome = sample_plan.is_complete()
        assert outcome is False

    def test_plan_is_complete_true(self):
        """A plan whose steps are all COMPLETED is complete."""
        finished_steps = [
            PlanStep(
                id="step_1",
                description="First step",
                action=ActionSpec(action_type=ActionType.FUNCTION),
                status=StepStatus.COMPLETED,
            ),
            PlanStep(
                id="step_2",
                description="Second step",
                action=ActionSpec(action_type=ActionType.FUNCTION),
                status=StepStatus.COMPLETED,
            ),
        ]
        all_done = Plan(
            id="test_plan",
            goal_id="goal_1",
            description="Test plan",
            steps=finished_steps,
        )
        assert all_done.is_complete() is True

    def test_plan_is_complete_empty(self):
        """A plan without any steps counts as complete."""
        no_steps = Plan(
            id="empty_plan",
            goal_id="goal_1",
            description="Empty plan",
            steps=[],
        )
        assert no_steps.is_complete() is True

    def test_plan_to_feedback_context(self, sample_plan):
        """The feedback context lists plan identity, completed and failed steps."""
        feedback = sample_plan.to_feedback_context()

        assert feedback["plan_id"] == "test_plan"
        assert feedback["revision"] == 1

        completed_entries = feedback["completed_steps"]
        assert len(completed_entries) == 1
        assert completed_entries[0]["id"] == "step_1"

        failed_entries = feedback["failed_steps"]
        assert len(failed_entries) == 1
        assert failed_entries[0]["id"] == "step_3"
        assert failed_entries[0]["error"] == "Something went wrong"
class TestPlanRoundTrip:
    """Checks that plans survive dumping and re-parsing, and how steps dump."""

    def test_plan_round_trip_model_dump(self):
        """A plan dumped to a dict and re-parsed keeps its data."""
        llm_step = PlanStep(
            id="step_1",
            description="First step",
            action=ActionSpec(
                action_type=ActionType.LLM_CALL,
                prompt="Hello world",
            ),
            dependencies=[],
            expected_outputs=["greeting"],
        )
        source_plan = Plan(
            id="plan_1",
            goal_id="goal_1",
            description="Test plan",
            steps=[llm_step],
            context={"key": "value"},
            revision=2,
        )

        # Dict round trip: dump, then parse the dump back.
        rebuilt = Plan.from_json(source_plan.model_dump())

        for field_name in ("id", "goal_id", "description", "context", "revision"):
            assert getattr(rebuilt, field_name) == getattr(source_plan, field_name)
        assert len(rebuilt.steps) == len(source_plan.steps)

        rebuilt_step, source_step = rebuilt.steps[0], source_plan.steps[0]
        assert rebuilt_step.id == source_step.id
        assert rebuilt_step.action.action_type == source_step.action.action_type

    def test_plan_round_trip_json_string(self):
        """A plan dumped to a JSON string and re-parsed keeps its data."""
        tool_step = PlanStep(
            id="step_1",
            description="First step",
            action=ActionSpec(
                action_type=ActionType.TOOL_USE,
                tool_name="my_tool",
                tool_args={"arg1": "value1"},
            ),
            dependencies=[],
        )
        source_plan = Plan(
            id="plan_1",
            goal_id="goal_1",
            description="Test plan",
            steps=[tool_step],
        )

        # String round trip: dump to JSON text, then parse the text back.
        rebuilt = Plan.from_json(source_plan.model_dump_json())

        assert rebuilt.id == source_plan.id
        assert len(rebuilt.steps) == 1
        assert rebuilt.steps[0].action.tool_name == "my_tool"

    def test_plan_step_serialization(self):
        """A PlanStep dumps to a dict carrying the values it was built with.

        Only the serialized form is inspected here; the dump is not parsed
        back into a PlanStep.
        """
        code_action = ActionSpec(
            action_type=ActionType.CODE_EXECUTION,
            code="print('hello')",
            language="python",
        )
        plan_step = PlanStep(
            id="step_1",
            description="Test step",
            action=code_action,
            inputs={"input1": "value1"},
            expected_outputs=["output1", "output2"],
            dependencies=["dep1", "dep2"],
            requires_approval=True,
            approval_message="Please approve",
        )

        dumped = plan_step.model_dump()

        assert dumped["id"] == "step_1"
        dumped_action = dumped["action"]
        assert dumped_action["action_type"] == "code_execution"
        assert dumped_action["code"] == "print('hello')"
        assert dumped["inputs"] == {"input1": "value1"}
        assert dumped["expected_outputs"] == ["output1", "output2"]
        assert dumped["dependencies"] == ["dep1", "dep2"]
        assert dumped["requires_approval"] is True