Merge pull request #107 from uttam-salamander/test/add-unit-tests-security-plan-example
test: add unit tests for security, plan, and example_tool modules
This commit is contained in:
@@ -0,0 +1,125 @@
|
||||
"""Tests for example_tool - A simple text processing tool."""
|
||||
import pytest
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from aden_tools.tools.example_tool.example_tool import register_tools
|
||||
|
||||
|
||||
@pytest.fixture
def example_tool_fn(mcp: FastMCP):
    """Provide the plain callable behind the registered ``example_tool``.

    Runs ``register_tools`` against the ``mcp`` fixture and then pulls the
    underlying function out of the tool manager's registry so tests can
    invoke it directly.
    """
    register_tools(mcp)
    registry = mcp._tool_manager._tools
    return registry["example_tool"].fn
|
||||
|
||||
|
||||
class TestExampleTool:
    """Behavioural checks for the ``example_tool`` callable."""

    def test_valid_message(self, example_tool_fn):
        """A plain message comes back untouched."""
        assert example_tool_fn(message="Hello, World!") == "Hello, World!"

    def test_uppercase_true(self, example_tool_fn):
        """Passing uppercase=True upper-cases the message."""
        assert example_tool_fn(message="hello", uppercase=True) == "HELLO"

    def test_uppercase_false(self, example_tool_fn):
        """Passing uppercase=False leaves the casing alone."""
        assert example_tool_fn(message="Hello", uppercase=False) == "Hello"

    def test_repeat_multiple(self, example_tool_fn):
        """repeat=3 yields the message three times, space separated."""
        assert example_tool_fn(message="Hi", repeat=3) == "Hi Hi Hi"

    def test_repeat_default(self, example_tool_fn):
        """repeat=1 yields the message exactly once."""
        assert example_tool_fn(message="Hello", repeat=1) == "Hello"

    def test_uppercase_and_repeat_combined(self, example_tool_fn):
        """uppercase and repeat can be applied in the same call."""
        combined = example_tool_fn(message="hi", uppercase=True, repeat=2)
        assert combined == "HI HI"

    def test_empty_message_error(self, example_tool_fn):
        """An empty message produces an error string citing the 1-1000 range."""
        outcome = example_tool_fn(message="")
        for fragment in ("Error", "1-1000"):
            assert fragment in outcome

    def test_message_too_long_error(self, example_tool_fn):
        """A 1001-character message produces an error string."""
        outcome = example_tool_fn(message="x" * 1001)
        for fragment in ("Error", "1-1000"):
            assert fragment in outcome

    def test_message_at_max_length(self, example_tool_fn):
        """A message of exactly 1000 characters is accepted."""
        boundary_text = "x" * 1000
        assert example_tool_fn(message=boundary_text) == boundary_text

    def test_repeat_zero_error(self, example_tool_fn):
        """repeat=0 produces an error string citing the 1-10 range."""
        outcome = example_tool_fn(message="Hi", repeat=0)
        for fragment in ("Error", "1-10"):
            assert fragment in outcome

    def test_repeat_eleven_error(self, example_tool_fn):
        """repeat=11 produces an error string citing the 1-10 range."""
        outcome = example_tool_fn(message="Hi", repeat=11)
        for fragment in ("Error", "1-10"):
            assert fragment in outcome

    def test_repeat_at_max(self, example_tool_fn):
        """repeat=10 is accepted and yields ten copies."""
        ten_copies = " ".join("Hi" for _ in range(10))
        assert example_tool_fn(message="Hi", repeat=10) == ten_copies

    def test_repeat_negative_error(self, example_tool_fn):
        """A negative repeat produces an error string citing the 1-10 range."""
        outcome = example_tool_fn(message="Hi", repeat=-1)
        for fragment in ("Error", "1-10"):
            assert fragment in outcome

    def test_whitespace_only_message(self, example_tool_fn):
        """A whitespace-only message counts as non-empty and is returned as is."""
        assert example_tool_fn(message=" ") == " "

    def test_special_characters_in_message(self, example_tool_fn):
        """Punctuation and symbols pass through unchanged."""
        symbols = "Hello! @#$%^&*()"
        assert example_tool_fn(message=symbols) == "Hello! @#$%^&*()"

    def test_unicode_message(self, example_tool_fn):
        """Non-ASCII text passes through unchanged."""
        assert example_tool_fn(message="Hello 世界 🌍") == "Hello 世界 🌍"

    def test_unicode_uppercase(self, example_tool_fn):
        """Upper-casing also applies to accented characters."""
        assert example_tool_fn(message="café", uppercase=True) == "CAFÉ"
|
||||
@@ -0,0 +1,215 @@
|
||||
"""Tests for security.py - get_secure_path() function."""
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
class TestGetSecurePath:
    """Tests for get_secure_path() function."""

    @pytest.fixture(autouse=True)
    def setup_workspaces_dir(self, tmp_path):
        """Patch WORKSPACES_DIR to a temp directory and expose shared handles.

        Sets on the test instance:
          * ``workspaces_dir`` - the temporary workspaces root (created here).
          * ``session_dir`` - the path the IDs from the ``ids`` fixture map
            to; it is only computed here, never created.
          * ``get_secure_path`` - the function under test, imported lazily so
            the import still happens at test time rather than at collection.
        """
        from aden_tools.tools.file_system_toolkits.security import get_secure_path

        self.get_secure_path = get_secure_path
        self.workspaces_dir = tmp_path / "workspaces"
        self.workspaces_dir.mkdir()
        self.session_dir = (
            self.workspaces_dir / "test-workspace" / "test-agent" / "test-session"
        )
        with patch(
            "aden_tools.tools.file_system_toolkits.security.WORKSPACES_DIR",
            str(self.workspaces_dir),
        ):
            yield

    @pytest.fixture
    def ids(self):
        """Standard workspace, agent, and session IDs."""
        return {
            "workspace_id": "test-workspace",
            "agent_id": "test-agent",
            "session_id": "test-session",
        }

    def test_creates_session_directory(self, ids):
        """Session directory is created if it doesn't exist."""
        # Only the side effect matters here, so the return value is ignored.
        self.get_secure_path("file.txt", **ids)

        assert self.session_dir.exists()
        assert self.session_dir.is_dir()

    def test_relative_path_resolved(self, ids):
        """Relative paths are resolved within session directory."""
        result = self.get_secure_path("subdir/file.txt", **ids)

        assert result == str(self.session_dir / "subdir" / "file.txt")

    def test_absolute_path_treated_as_relative(self, ids):
        """Absolute paths are treated as relative to session root."""
        result = self.get_secure_path("/etc/passwd", **ids)

        assert result == str(self.session_dir / "etc" / "passwd")

    def test_path_traversal_blocked(self, ids):
        """Path traversal attempts are blocked."""
        with pytest.raises(ValueError, match="outside the session sandbox"):
            self.get_secure_path("../../../etc/passwd", **ids)

    def test_path_traversal_with_nested_dotdot(self, ids):
        """Nested path traversal with valid prefix is blocked."""
        with pytest.raises(ValueError, match="outside the session sandbox"):
            self.get_secure_path("valid/../../..", **ids)

    def test_path_traversal_absolute_with_dotdot(self, ids):
        """Absolute path with traversal is blocked."""
        with pytest.raises(ValueError, match="outside the session sandbox"):
            self.get_secure_path("/foo/../../../etc/passwd", **ids)

    def test_missing_workspace_id_raises(self, ids):
        """Empty workspace_id raises ValueError."""
        with pytest.raises(ValueError, match="workspace_id.*required"):
            self.get_secure_path(
                "file.txt",
                workspace_id="",
                agent_id=ids["agent_id"],
                session_id=ids["session_id"],
            )

    def test_missing_agent_id_raises(self, ids):
        """Empty agent_id raises ValueError."""
        with pytest.raises(ValueError, match="agent_id.*required"):
            self.get_secure_path(
                "file.txt",
                workspace_id=ids["workspace_id"],
                agent_id="",
                session_id=ids["session_id"],
            )

    def test_missing_session_id_raises(self, ids):
        """Empty session_id raises ValueError."""
        with pytest.raises(ValueError, match="session_id.*required"):
            self.get_secure_path(
                "file.txt",
                workspace_id=ids["workspace_id"],
                agent_id=ids["agent_id"],
                session_id="",
            )

    def test_none_ids_raise(self):
        """A None workspace_id raises ValueError.

        Only workspace_id=None is exercised here; None for the other IDs is
        not covered by this test.
        """
        with pytest.raises(ValueError):
            self.get_secure_path(
                "file.txt", workspace_id=None, agent_id="agent", session_id="session"
            )

    def test_simple_filename(self, ids):
        """Simple filename resolves correctly."""
        result = self.get_secure_path("file.txt", **ids)

        assert result == str(self.session_dir / "file.txt")

    def test_current_dir_path(self, ids):
        """Current directory path (.) resolves to session dir."""
        result = self.get_secure_path(".", **ids)

        assert result == str(self.session_dir)

    def test_dot_slash_path(self, ids):
        """Dot-slash paths resolve correctly."""
        result = self.get_secure_path("./subdir/file.txt", **ids)

        assert result == str(self.session_dir / "subdir" / "file.txt")

    def test_deeply_nested_path(self, ids):
        """Deeply nested paths work correctly."""
        result = self.get_secure_path("a/b/c/d/e/file.txt", **ids)

        expected = self.session_dir / "a" / "b" / "c" / "d" / "e" / "file.txt"
        assert result == str(expected)

    def test_path_with_spaces(self, ids):
        """Paths with spaces work correctly."""
        result = self.get_secure_path("my folder/my file.txt", **ids)

        assert result == str(self.session_dir / "my folder" / "my file.txt")

    def test_path_with_special_characters(self, ids):
        """Paths with special characters work correctly."""
        result = self.get_secure_path("file-name_v2.0.txt", **ids)

        assert result == str(self.session_dir / "file-name_v2.0.txt")

    def test_empty_path(self, ids):
        """Empty string path resolves to session directory."""
        result = self.get_secure_path("", **ids)

        assert result == str(self.session_dir)

    def test_symlink_within_sandbox_works(self, ids):
        """Symlinks that stay within the sandbox are allowed."""
        # Create session directory structure
        self.session_dir.mkdir(parents=True, exist_ok=True)

        # Create a target file and a symlink to it
        target_file = self.session_dir / "target.txt"
        target_file.write_text("content")
        symlink_path = self.session_dir / "link_to_target"
        symlink_path.symlink_to(target_file)

        # Path through symlink should resolve
        result = self.get_secure_path("link_to_target", **ids)

        assert result == str(symlink_path)

    def test_symlink_escape_detected_with_realpath(self, ids):
        """Symlinks pointing outside sandbox can be detected using realpath.

        Note: get_secure_path uses abspath (not realpath), so it validates the
        lexical path. To fully protect against symlink attacks, callers should
        verify realpath(result) is still within the sandbox before file I/O.
        This test documents that pattern.
        """
        # Create session directory
        self.session_dir.mkdir(parents=True, exist_ok=True)

        # Create a symlink inside session pointing outside
        outside_target = self.workspaces_dir / "outside_file.txt"
        outside_target.write_text("sensitive data")
        symlink_path = self.session_dir / "escape_link"
        symlink_path.symlink_to(outside_target)

        # get_secure_path accepts the lexical path (symlink is inside session)
        result = self.get_secure_path("escape_link", **ids)
        assert result == str(symlink_path)

        # However, realpath reveals the escape - callers should check this.
        # Both sides are resolved so that a symlinked temp root cannot make
        # the inequality hold for reasons unrelated to the escape.
        real_target = os.path.realpath(result)
        real_session = os.path.realpath(self.session_dir)
        assert os.path.commonpath([real_target, real_session]) != real_session
|
||||
@@ -0,0 +1,588 @@
|
||||
"""Tests for plan.py - Plan enums and Pydantic models."""
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from framework.graph.plan import (
|
||||
ActionType,
|
||||
StepStatus,
|
||||
ApprovalDecision,
|
||||
JudgmentAction,
|
||||
ExecutionStatus,
|
||||
ActionSpec,
|
||||
PlanStep,
|
||||
Plan,
|
||||
)
|
||||
|
||||
|
||||
class TestActionTypeEnum:
    """Checks on the members of the ActionType enum."""

    def test_action_type_values_exist(self):
        """Each of the five ActionType members carries its expected value."""
        expected_values = {
            ActionType.LLM_CALL: "llm_call",
            ActionType.TOOL_USE: "tool_use",
            ActionType.SUB_GRAPH: "sub_graph",
            ActionType.FUNCTION: "function",
            ActionType.CODE_EXECUTION: "code_execution",
        }
        for member, raw in expected_values.items():
            assert member.value == raw

    def test_action_type_count(self):
        """ActionType defines five members and no more."""
        member_total = len(ActionType)
        assert member_total == 5

    def test_action_type_string_enum(self):
        """ActionType members are str instances comparable to raw strings."""
        member = ActionType.LLM_CALL
        assert isinstance(member, str)
        assert member == "llm_call"
|
||||
|
||||
|
||||
class TestStepStatusEnum:
    """Checks on the StepStatus enum and on reassigning a step's status."""

    @staticmethod
    def _step_with_status(initial):
        """Build the FUNCTION step used by the transition tests."""
        function_action = ActionSpec(action_type=ActionType.FUNCTION)
        return PlanStep(
            id="step_1",
            description="Test step",
            action=function_action,
            status=initial,
        )

    def test_step_status_values_exist(self):
        """Each of the seven StepStatus members carries its expected value."""
        expected_values = {
            StepStatus.PENDING: "pending",
            StepStatus.AWAITING_APPROVAL: "awaiting_approval",
            StepStatus.IN_PROGRESS: "in_progress",
            StepStatus.COMPLETED: "completed",
            StepStatus.FAILED: "failed",
            StepStatus.SKIPPED: "skipped",
            StepStatus.REJECTED: "rejected",
        }
        for member, raw in expected_values.items():
            assert member.value == raw

    def test_step_status_count(self):
        """StepStatus defines seven members and no more."""
        member_total = len(StepStatus)
        assert member_total == 7

    def test_step_status_transition_pending_to_in_progress(self):
        """A PENDING step can be reassigned to IN_PROGRESS."""
        subject = self._step_with_status(StepStatus.PENDING)
        subject.status = StepStatus.IN_PROGRESS
        assert subject.status == StepStatus.IN_PROGRESS

    def test_step_status_transition_in_progress_to_completed(self):
        """An IN_PROGRESS step can be reassigned to COMPLETED."""
        subject = self._step_with_status(StepStatus.IN_PROGRESS)
        subject.status = StepStatus.COMPLETED
        assert subject.status == StepStatus.COMPLETED

    def test_step_status_transition_in_progress_to_failed(self):
        """An IN_PROGRESS step can be reassigned to FAILED."""
        subject = self._step_with_status(StepStatus.IN_PROGRESS)
        subject.status = StepStatus.FAILED
        assert subject.status == StepStatus.FAILED
|
||||
|
||||
|
||||
class TestApprovalDecisionEnum:
    """Checks on the members of the ApprovalDecision enum."""

    def test_approval_decision_values_exist(self):
        """Each of the four ApprovalDecision members carries its expected value."""
        expected_values = {
            ApprovalDecision.APPROVE: "approve",
            ApprovalDecision.REJECT: "reject",
            ApprovalDecision.MODIFY: "modify",
            ApprovalDecision.ABORT: "abort",
        }
        for member, raw in expected_values.items():
            assert member.value == raw

    def test_approval_decision_count(self):
        """ApprovalDecision defines four members and no more."""
        member_total = len(ApprovalDecision)
        assert member_total == 4
|
||||
|
||||
|
||||
class TestJudgmentActionEnum:
    """Checks on the members of the JudgmentAction enum."""

    def test_judgment_action_values_exist(self):
        """Each of the four JudgmentAction members carries its expected value."""
        expected_values = {
            JudgmentAction.ACCEPT: "accept",
            JudgmentAction.RETRY: "retry",
            JudgmentAction.REPLAN: "replan",
            JudgmentAction.ESCALATE: "escalate",
        }
        for member, raw in expected_values.items():
            assert member.value == raw

    def test_judgment_action_count(self):
        """JudgmentAction defines four members and no more."""
        member_total = len(JudgmentAction)
        assert member_total == 4
|
||||
|
||||
|
||||
class TestExecutionStatusEnum:
    """Checks on the members of the ExecutionStatus enum."""

    def test_execution_status_values_exist(self):
        """Each of the seven ExecutionStatus members carries its expected value."""
        expected_values = {
            ExecutionStatus.COMPLETED: "completed",
            ExecutionStatus.AWAITING_APPROVAL: "awaiting_approval",
            ExecutionStatus.NEEDS_REPLAN: "needs_replan",
            ExecutionStatus.NEEDS_ESCALATION: "needs_escalation",
            ExecutionStatus.REJECTED: "rejected",
            ExecutionStatus.ABORTED: "aborted",
            ExecutionStatus.FAILED: "failed",
        }
        for member, raw in expected_values.items():
            assert member.value == raw

    def test_execution_status_count(self):
        """ExecutionStatus defines seven members and no more."""
        member_total = len(ExecutionStatus)
        assert member_total == 7
|
||||
|
||||
|
||||
class TestPlanStepIsReady:
    """Checks on PlanStep.is_ready() for different dependencies and statuses."""

    @staticmethod
    def _build_step(step_id, description, dependencies, status):
        """Build a FUNCTION step with the given id, dependencies and status."""
        return PlanStep(
            id=step_id,
            description=description,
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=dependencies,
            status=status,
        )

    def test_plan_step_is_ready_no_deps(self):
        """A PENDING step without dependencies is ready."""
        candidate = self._build_step("step_1", "Test step", [], StepStatus.PENDING)
        assert candidate.is_ready(set()) is True

    def test_plan_step_is_ready_deps_met(self):
        """A PENDING step is ready once its single dependency is completed."""
        candidate = self._build_step(
            "step_2", "Second step", ["step_1"], StepStatus.PENDING
        )
        finished = {"step_1"}
        assert candidate.is_ready(finished) is True

    def test_plan_step_not_ready_deps_missing(self):
        """A step is not ready while one of its dependencies is unfinished."""
        candidate = self._build_step(
            "step_2", "Second step", ["step_1", "step_3"], StepStatus.PENDING
        )
        # step_3 is absent from the completed set, so the step must wait.
        finished = {"step_1"}
        assert candidate.is_ready(finished) is False

    def test_plan_step_not_ready_wrong_status(self):
        """An IN_PROGRESS step is not ready even with no dependencies."""
        candidate = self._build_step(
            "step_1", "Test step", [], StepStatus.IN_PROGRESS
        )
        assert candidate.is_ready(set()) is False

    def test_plan_step_not_ready_completed_status(self):
        """A COMPLETED step is not reported as ready again."""
        candidate = self._build_step("step_1", "Test step", [], StepStatus.COMPLETED)
        assert candidate.is_ready(set()) is False

    def test_plan_step_is_ready_multiple_deps_all_met(self):
        """A step with three dependencies is ready when all three are completed."""
        candidate = self._build_step(
            "step_4",
            "Fourth step",
            ["step_1", "step_2", "step_3"],
            StepStatus.PENDING,
        )
        finished = {"step_1", "step_2", "step_3"}
        assert candidate.is_ready(finished) is True
|
||||
|
||||
|
||||
class TestPlanFromJson:
    """Checks on Plan.from_json() for valid and invalid inputs."""

    @staticmethod
    def _plan_payload(steps):
        """Return the standard plan dict used by these tests with *steps* attached."""
        return {
            "id": "plan_1",
            "goal_id": "goal_1",
            "description": "Test plan",
            "steps": steps,
        }

    @staticmethod
    def _step_payload(description, action, step_id="step_1"):
        """Return a step dict with the given description and action dict."""
        return {"id": step_id, "description": description, "action": action}

    def test_plan_from_json_string(self):
        """A JSON string is parsed into a Plan."""
        first_step = self._step_payload(
            "First step",
            {"action_type": "function", "function_name": "do_something"},
        )
        encoded = json.dumps(self._plan_payload([first_step]))

        parsed = Plan.from_json(encoded)

        assert parsed.id == "plan_1"
        assert parsed.goal_id == "goal_1"
        assert len(parsed.steps) == 1
        assert parsed.steps[0].id == "step_1"

    def test_plan_from_json_dict(self):
        """A dict is accepted directly, without JSON encoding."""
        first_step = self._step_payload("First step", {"action_type": "function"})
        payload = self._plan_payload([first_step])

        parsed = Plan.from_json(payload)

        assert parsed.id == "plan_1"
        assert parsed.goal_id == "goal_1"

    def test_plan_from_json_nested_plan_key(self):
        """A {"plan": {...}} wrapper, as from export_graph(), is unwrapped."""
        wrapped = {"plan": self._plan_payload([])}

        parsed = Plan.from_json(wrapped)

        assert parsed.id == "plan_1"

    def test_plan_from_json_action_type_conversion(self):
        """A string action_type becomes the matching ActionType member."""
        llm_step = self._step_payload(
            "LLM step", {"action_type": "llm_call", "prompt": "Hello"}
        )
        payload = self._plan_payload([llm_step])

        parsed = Plan.from_json(payload)

        assert parsed.steps[0].action.action_type == ActionType.LLM_CALL

    def test_plan_from_json_all_action_types(self):
        """Every known action_type string survives conversion."""
        known_types = ["llm_call", "tool_use", "sub_graph", "function", "code_execution"]

        for raw_type in known_types:
            only_step = self._step_payload(
                "Step", {"action_type": raw_type}, step_id="step"
            )
            payload = {
                "id": "plan",
                "goal_id": "goal",
                "description": "Test",
                "steps": [only_step],
            }
            parsed = Plan.from_json(payload)
            assert parsed.steps[0].action.action_type.value == raw_type

    def test_from_json_invalid_action_type(self):
        """An unrecognised action_type string raises ValueError."""
        bad_step = self._step_payload("Invalid step", {"action_type": "invalid_type"})
        payload = self._plan_payload([bad_step])

        with pytest.raises(ValueError):
            Plan.from_json(payload)

    def test_from_json_malformed_json_string(self):
        """A syntactically invalid JSON string raises json.JSONDecodeError."""
        broken = "{ invalid json }"

        with pytest.raises(json.JSONDecodeError):
            Plan.from_json(broken)

    def test_from_json_missing_step_id(self):
        """A step dict lacking 'id' raises KeyError."""
        anonymous_step = {
            "description": "Step without ID",
            "action": {"action_type": "function"},
        }
        payload = self._plan_payload([anonymous_step])

        with pytest.raises(KeyError):
            Plan.from_json(payload)

    def test_from_json_wrong_type_for_steps(self):
        """A string in place of the steps list raises AttributeError."""
        payload = self._plan_payload("not a list")

        with pytest.raises(AttributeError):
            Plan.from_json(payload)

    def test_from_json_empty_data(self):
        """An empty dict produces a plan populated with defaults."""
        parsed = Plan.from_json({})

        assert parsed.id == "plan"
        assert parsed.goal_id == ""
        assert parsed.steps == []
|
||||
|
||||
|
||||
class TestPlanMethods:
    """Checks on the query and serialization helpers of Plan instances."""

    @pytest.fixture
    def sample_plan(self):
        """Build a three-step plan: one completed, one pending, one failed."""
        finished = PlanStep(
            id="step_1",
            description="First step",
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=[],
            status=StepStatus.COMPLETED,
            result={"data": "result1"},
        )
        waiting = PlanStep(
            id="step_2",
            description="Second step",
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=["step_1"],
            status=StepStatus.PENDING,
        )
        broken = PlanStep(
            id="step_3",
            description="Third step",
            action=ActionSpec(action_type=ActionType.FUNCTION),
            dependencies=["step_1"],
            status=StepStatus.FAILED,
            error="Something went wrong",
            attempts=3,
        )
        return Plan(
            id="test_plan",
            goal_id="goal_1",
            description="Test plan",
            steps=[finished, waiting, broken],
        )

    def test_plan_get_step(self, sample_plan):
        """get_step returns the step whose ID matches."""
        found = sample_plan.get_step("step_2")

        assert found is not None
        assert found.id == "step_2"
        assert found.description == "Second step"

    def test_plan_get_step_not_found(self, sample_plan):
        """get_step returns None for an unknown ID."""
        assert sample_plan.get_step("nonexistent") is None

    def test_plan_get_ready_steps(self, sample_plan):
        """get_ready_steps returns only the step that can run now."""
        runnable = sample_plan.get_ready_steps()

        assert len(runnable) == 1
        assert runnable[0].id == "step_2"

    def test_plan_get_completed_steps(self, sample_plan):
        """get_completed_steps returns only the completed step."""
        done = sample_plan.get_completed_steps()

        assert len(done) == 1
        assert done[0].id == "step_1"

    def test_plan_is_complete_false(self, sample_plan):
        """A plan with pending and failed steps is not complete."""
        assert sample_plan.is_complete() is False

    def test_plan_is_complete_true(self):
        """A plan whose steps are all COMPLETED is complete."""
        completed_steps = [
            PlanStep(
                id=step_id,
                description=label,
                action=ActionSpec(action_type=ActionType.FUNCTION),
                status=StepStatus.COMPLETED,
            )
            for step_id, label in (("step_1", "First step"), ("step_2", "Second step"))
        ]
        finished_plan = Plan(
            id="test_plan",
            goal_id="goal_1",
            description="Test plan",
            steps=completed_steps,
        )
        assert finished_plan.is_complete() is True

    def test_plan_is_complete_empty(self):
        """A plan without any steps counts as complete."""
        stepless = Plan(
            id="empty_plan",
            goal_id="goal_1",
            description="Empty plan",
            steps=[],
        )
        assert stepless.is_complete() is True

    def test_plan_to_feedback_context(self, sample_plan):
        """to_feedback_context reports plan id, revision, completed and failed steps."""
        feedback = sample_plan.to_feedback_context()

        assert feedback["plan_id"] == "test_plan"
        assert feedback["revision"] == 1

        done = feedback["completed_steps"]
        assert len(done) == 1
        assert done[0]["id"] == "step_1"

        failures = feedback["failed_steps"]
        assert len(failures) == 1
        assert failures[0]["id"] == "step_3"
        assert failures[0]["error"] == "Something went wrong"
|
||||
|
||||
|
||||
class TestPlanRoundTrip:
    """Checks that Plan and PlanStep data survives serialization."""

    def test_plan_round_trip_model_dump(self):
        """Feeding model_dump() output to from_json keeps the plan's data."""
        llm_action = ActionSpec(
            action_type=ActionType.LLM_CALL,
            prompt="Hello world",
        )
        greeting_step = PlanStep(
            id="step_1",
            description="First step",
            action=llm_action,
            dependencies=[],
            expected_outputs=["greeting"],
        )
        source_plan = Plan(
            id="plan_1",
            goal_id="goal_1",
            description="Test plan",
            steps=[greeting_step],
            context={"key": "value"},
            revision=2,
        )

        # Dict round trip.
        rebuilt = Plan.from_json(source_plan.model_dump())

        assert rebuilt.id == source_plan.id
        assert rebuilt.goal_id == source_plan.goal_id
        assert rebuilt.description == source_plan.description
        assert rebuilt.context == source_plan.context
        assert rebuilt.revision == source_plan.revision
        assert len(rebuilt.steps) == len(source_plan.steps)
        rebuilt_step, source_step = rebuilt.steps[0], source_plan.steps[0]
        assert rebuilt_step.id == source_step.id
        assert rebuilt_step.action.action_type == source_step.action.action_type

    def test_plan_round_trip_json_string(self):
        """Feeding model_dump_json() output to from_json keeps the plan's data."""
        tool_action = ActionSpec(
            action_type=ActionType.TOOL_USE,
            tool_name="my_tool",
            tool_args={"arg1": "value1"},
        )
        tool_step = PlanStep(
            id="step_1",
            description="First step",
            action=tool_action,
            dependencies=[],
        )
        source_plan = Plan(
            id="plan_1",
            goal_id="goal_1",
            description="Test plan",
            steps=[tool_step],
        )

        # JSON-string round trip.
        rebuilt = Plan.from_json(source_plan.model_dump_json())

        assert rebuilt.id == source_plan.id
        assert len(rebuilt.steps) == 1
        assert rebuilt.steps[0].action.tool_name == "my_tool"

    def test_plan_step_serialization(self):
        """model_dump() on a PlanStep exposes the fields it was built with."""
        code_action = ActionSpec(
            action_type=ActionType.CODE_EXECUTION,
            code="print('hello')",
            language="python",
        )
        subject = PlanStep(
            id="step_1",
            description="Test step",
            action=code_action,
            inputs={"input1": "value1"},
            expected_outputs=["output1", "output2"],
            dependencies=["dep1", "dep2"],
            requires_approval=True,
            approval_message="Please approve",
        )

        dumped = subject.model_dump()
        dumped_action = dumped["action"]

        assert dumped["id"] == "step_1"
        assert dumped_action["action_type"] == "code_execution"
        assert dumped_action["code"] == "print('hello')"
        assert dumped["inputs"] == {"input1": "value1"}
        assert dumped["expected_outputs"] == ["output1", "output2"]
        assert dumped["dependencies"] == ["dep1", "dep2"]
        assert dumped["requires_approval"] is True
|
||||
Reference in New Issue
Block a user