""" Session State Schema - Unified state for session execution. This schema consolidates data from Run, ExecutionResult, and runtime logs into a single source of truth for session status and resumability. """ from datetime import datetime from enum import StrEnum from typing import TYPE_CHECKING, Any from pydantic import AliasChoices, BaseModel, Field, computed_field if TYPE_CHECKING: from framework.orchestrator.orchestrator import ExecutionResult from framework.schemas.run import Run class SessionStatus(StrEnum): """Status of a session execution.""" ACTIVE = "active" # Currently executing PAUSED = "paused" # Waiting for resume (client input, pause node) COMPLETED = "completed" # Finished successfully FAILED = "failed" # Finished with error CANCELLED = "cancelled" # User/system cancelled class SessionTimestamps(BaseModel): """Timestamps tracking session lifecycle.""" started_at: str # ISO 8601 format updated_at: str # ISO 8601 format (updated on every state write) completed_at: str | None = None paused_at_time: str | None = None # When it was paused model_config = {"extra": "allow"} class SessionProgress(BaseModel): """Execution progress tracking.""" current_node: str | None = None paused_at: str | None = None # Node ID where paused resume_from: str | None = None # Entry point or node ID to resume from steps_executed: int = 0 total_tokens: int = 0 total_latency_ms: int = 0 path: list[str] = Field(default_factory=list) # Node IDs traversed # Quality metrics (from ExecutionResult) total_retries: int = 0 nodes_with_failures: list[str] = Field(default_factory=list) retry_details: dict[str, int] = Field(default_factory=dict) had_partial_failures: bool = False execution_quality: str = "clean" # "clean", "degraded", or "failed" node_visit_counts: dict[str, int] = Field(default_factory=dict) model_config = {"extra": "allow"} class SessionResult(BaseModel): """Final result of session execution.""" success: bool | None = None # None if still running error: str | None = None output: dict[str, Any] = Field(default_factory=dict) model_config = {"extra": "allow"} class SessionMetrics(BaseModel): """Execution metrics (from Run.metrics).""" decision_count: int = 0 problem_count: int = 0 total_input_tokens: int = 0 total_output_tokens: int = 0 nodes_executed: list[str] = Field(default_factory=list) edges_traversed: list[str] = Field(default_factory=list) model_config = {"extra": "allow"} class SessionState(BaseModel): """ Complete state for a session execution. This is the single source of truth for session status and resumability. Consolidates data from ExecutionResult, ExecutionContext, Run, and runtime logs. Version History: - v1.0: Initial schema (2026-02-06) - v1.1: Added checkpoint support (2026-02-08) """ # Schema version for forward/backward compatibility schema_version: str = "1.1" # Identity session_id: str # Format: session_YYYYMMDD_HHMMSS_{uuid_8char} stream_id: str = "" # Which ExecutionStream created this correlation_id: str = "" # For correlating related executions # Status status: SessionStatus = SessionStatus.ACTIVE # Goal/Agent context goal_id: str agent_id: str = "" entry_point: str = "start" # Timestamps timestamps: SessionTimestamps # Progress progress: SessionProgress = Field(default_factory=SessionProgress) # Result result: SessionResult = Field(default_factory=SessionResult) # Data buffer (for resumability) data_buffer: dict[str, Any] = Field( default_factory=dict, validation_alias=AliasChoices("data_buffer", "memory"), ) # Metrics metrics: SessionMetrics = Field(default_factory=SessionMetrics) # Problems (from Run.problems) problems: list[dict[str, Any]] = Field(default_factory=list) # Decisions (from Run.decisions - can be large, so store references) decisions: list[dict[str, Any]] = Field(default_factory=list) # Input data (for debugging/replay) input_data: dict[str, Any] = Field(default_factory=dict) current_run_id: str | None = None # Process ID of the owning process (for cross-process stale session detection) pid: int | None = None # Isolation level (from ExecutionContext) isolation_level: str = "shared" # Checkpointing (for crash recovery and resume-from-failure) checkpoint_enabled: bool = False latest_checkpoint_id: str | None = None # Trigger activation state (IDs of triggers the queen/user turned on) active_triggers: list[str] = Field(default_factory=list) # Per-trigger task strings (user overrides, keyed by trigger ID) trigger_tasks: dict[str, str] = Field(default_factory=dict) # True after first successful worker execution (gates trigger delivery on restart) worker_configured: bool = Field(default=False) model_config = {"extra": "allow"} @property def memory(self) -> dict[str, Any]: """Backward-compatible alias for legacy callers.""" return self.data_buffer @memory.setter def memory(self, value: dict[str, Any]) -> None: """Backward-compatible alias for legacy callers.""" self.data_buffer = value @computed_field @property def duration_ms(self) -> int: """Duration of the session in milliseconds.""" if not self.timestamps.completed_at: return 0 started = datetime.fromisoformat(self.timestamps.started_at) completed = datetime.fromisoformat(self.timestamps.completed_at) return int((completed - started).total_seconds() * 1000) @computed_field @property def is_resumable(self) -> bool: """Can this session be resumed? Only sessions with a valid checkpoint can be resumed. State-based resume (without a checkpoint) is no longer supported. """ return self.is_resumable_from_checkpoint @computed_field @property def is_resumable_from_checkpoint(self) -> bool: """Can this session be resumed from a checkpoint?""" # ANY session with checkpoints can be resumed (not just failed ones) # This enables: pause/resume, iterative execution, continuation after completion return self.checkpoint_enabled and self.latest_checkpoint_id is not None @classmethod def from_execution_result( cls, session_id: str, goal_id: str, result: "ExecutionResult", stream_id: str = "", correlation_id: str = "", started_at: str = "", input_data: dict[str, Any] | None = None, agent_id: str = "", entry_point: str = "start", ) -> "SessionState": """Create SessionState from ExecutionResult.""" now = datetime.now().isoformat() # Determine status based on execution result if result.paused_at: status = SessionStatus.PAUSED elif result.success: status = SessionStatus.COMPLETED else: status = SessionStatus.FAILED return cls( session_id=session_id, stream_id=stream_id, correlation_id=correlation_id, goal_id=goal_id, agent_id=agent_id, entry_point=entry_point, status=status, timestamps=SessionTimestamps( started_at=started_at or now, updated_at=now, completed_at=now if not result.paused_at else None, paused_at_time=now if result.paused_at else None, ), progress=SessionProgress( current_node=result.paused_at or (result.path[-1] if result.path else None), paused_at=result.paused_at, resume_from=result.session_state.get("resume_from") if result.session_state else None, steps_executed=result.steps_executed, total_tokens=result.total_tokens, total_latency_ms=result.total_latency_ms, path=result.path, total_retries=result.total_retries, nodes_with_failures=result.nodes_with_failures, retry_details=result.retry_details, had_partial_failures=result.had_partial_failures, execution_quality=result.execution_quality, node_visit_counts=result.node_visit_counts, ), result=SessionResult( success=result.success, error=result.error, output=result.output, ), data_buffer=result.session_state.get( "data_buffer", result.session_state.get("memory", {}) ) if result.session_state else {}, input_data=input_data or {}, ) @classmethod def from_legacy_run(cls, run: "Run", session_id: str, stream_id: str = "") -> "SessionState": """Create SessionState from legacy Run object.""" from framework.schemas.run import RunStatus now = datetime.now().isoformat() # Map RunStatus to SessionStatus status_mapping = { RunStatus.RUNNING: SessionStatus.ACTIVE, RunStatus.COMPLETED: SessionStatus.COMPLETED, RunStatus.FAILED: SessionStatus.FAILED, RunStatus.CANCELLED: SessionStatus.CANCELLED, RunStatus.STUCK: SessionStatus.FAILED, } status = status_mapping.get(run.status, SessionStatus.FAILED) return cls( schema_version="1.0", session_id=session_id, stream_id=stream_id, goal_id=run.goal_id, status=status, timestamps=SessionTimestamps( started_at=run.started_at.isoformat(), updated_at=now, completed_at=run.completed_at.isoformat() if run.completed_at else None, ), result=SessionResult( success=run.status == RunStatus.COMPLETED, output=run.output_data, ), metrics=SessionMetrics( decision_count=run.metrics.total_decisions, problem_count=len(run.problems), total_input_tokens=run.metrics.total_tokens, # Approximate total_output_tokens=0, # Not tracked in old format nodes_executed=run.metrics.nodes_executed, edges_traversed=run.metrics.edges_traversed, ), decisions=[d.model_dump() for d in run.decisions], problems=[p.model_dump() for p in run.problems], input_data=run.input_data, ) def to_session_state_dict(self) -> dict[str, Any]: """Convert to session_state format for GraphExecutor.execute(). NOTE: state-based resume via paused_at/resume_from is deprecated. Use checkpoint-based resume (``resume_from_checkpoint`` key) instead. """ # Derive resume target: explicit > last node in path > entry point resume_from = ( self.progress.resume_from or self.progress.paused_at or (self.progress.path[-1] if self.progress.path else None) ) return { "paused_at": resume_from, "resume_from": resume_from, "data_buffer": self.data_buffer, "execution_path": self.progress.path, "node_visit_counts": self.progress.node_visit_counts, }