116 lines
4.1 KiB
Python
116 lines
4.1 KiB
Python
"""Pipeline runner -- executes registered stages in order."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from framework.pipeline.stage import (
|
|
PipelineContext,
|
|
PipelineRejectedError,
|
|
PipelineStage,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PipelineRunner:
    """Executes a list of :class:`PipelineStage` instances in ``order``.

    The runner is the orchestration layer that :class:`AgentRuntime` calls
    on every trigger. Stages execute in ascending ``order`` (ties broken by
    registration order). A stage returning ``reject`` short-circuits the
    pipeline and causes the trigger to raise :class:`PipelineRejectedError`.

    Every method that awaits while walking the stages iterates over a
    snapshot of the stage list. :meth:`add_stage` mutates the list in place,
    so without the snapshot a stage registered while a pass is in flight
    would shift the remaining entries under the iterator and make a stage
    run twice or be skipped. With the snapshot, a newly registered stage
    simply takes part from the next pass onwards.
    """

    def __init__(self, stages: list[PipelineStage] | None = None) -> None:
        """Store ``stages`` sorted by ascending ``order``.

        Args:
            stages: Initial stages; ``None`` or an empty list yields an
                empty runner. The caller's list is not modified.
        """
        # ``sorted`` is stable, so stages sharing an ``order`` keep the
        # relative position they had in the caller's list.
        self._stages: list[PipelineStage] = sorted(stages or [], key=lambda s: s.order)

    @property
    def stages(self) -> list[PipelineStage]:
        """Return a copy of the registered stages in execution order."""
        return list(self._stages)

    def add_stage(self, stage: PipelineStage) -> None:
        """Add a stage after construction (for dynamic registration)."""
        self._stages.append(stage)
        # ``list.sort`` is stable: appending first and sorting afterwards
        # places the new stage behind existing stages with the same ``order``.
        self._stages.sort(key=lambda s: s.order)

    async def initialize_all(self) -> None:
        """Call ``initialize`` on every registered stage.

        Stages are initialized one after another in execution order. An
        exception raised by a stage's ``initialize`` propagates to the
        caller. Stages registered while this call is in flight are not
        initialized by it.
        """
        # Snapshot: ``add_stage`` may run while we are suspended in an await.
        stages = tuple(self._stages)
        for stage in stages:
            name = stage.__class__.__name__
            logger.info("[pipeline] Initializing %s (order=%d)", name, stage.order)
            await stage.initialize()
            logger.info("[pipeline] %s initialized", name)
        if stages:
            logger.info(
                "[pipeline] Ready: %d stages [%s]",
                len(stages),
                " -> ".join(s.__class__.__name__ for s in stages),
            )

    async def run(self, ctx: PipelineContext) -> PipelineContext:
        """Run all stages. Raises ``PipelineRejectedError`` on rejection.

        Each stage's ``process`` is awaited in execution order and its
        result's ``action`` decides what happens next:

        * ``"reject"`` -- log a warning and raise immediately; later stages
          do not run.
        * ``"transform"`` -- if the result carries a non-``None``
          ``input_data``, it replaces ``ctx.input_data``.
        * anything else -- the stage is logged as passed.

        Args:
            ctx: The context handed to every stage; mutated in place by
                transforming stages.

        Returns:
            The (possibly transformed) context -- the same object as ``ctx``.

        Raises:
            PipelineRejectedError: When a stage returns ``reject``. Built
                from the stage's class name and its rejection reason, or
                ``"(no reason given)"`` when the reason is empty.
        """
        # Snapshot: ``add_stage`` may run while we are suspended in an await.
        stages = tuple(self._stages)
        if not stages:
            return ctx

        import time

        pipeline_start = time.perf_counter()
        logger.info(
            "[pipeline] Running %d stages for entry_point=%s",
            len(stages),
            ctx.entry_point_id,
        )
        for stage in stages:
            stage_name = stage.__class__.__name__
            t0 = time.perf_counter()
            result = await stage.process(ctx)
            elapsed_ms = (time.perf_counter() - t0) * 1000

            if result.action == "reject":
                reason = result.rejection_reason or "(no reason given)"
                logger.warning(
                    "[pipeline] REJECTED by %s (%.1fms): %s",
                    stage_name,
                    elapsed_ms,
                    reason,
                )
                raise PipelineRejectedError(stage_name, reason)

            if result.action == "transform":
                logger.info(
                    "[pipeline] %s TRANSFORMED input (%.1fms)",
                    stage_name,
                    elapsed_ms,
                )
                # A transform without a payload leaves the input untouched.
                if result.input_data is not None:
                    ctx.input_data = result.input_data
            else:
                logger.info(
                    "[pipeline] %s passed (%.1fms)",
                    stage_name,
                    elapsed_ms,
                )

        total_ms = (time.perf_counter() - pipeline_start) * 1000
        logger.info("[pipeline] Complete (%.1fms total)", total_ms)
        return ctx

    async def run_post(self, ctx: PipelineContext, result: Any) -> Any:
        """Run all stages' ``post_process`` hooks in order.

        Each stage can transform the result; the final value is returned.
        Exceptions are logged and swallowed -- post-processing must not
        break a successful execution. A stage whose hook raises is skipped:
        the next stage receives the last successfully produced value.

        Args:
            ctx: The context passed to every ``post_process`` hook.
            result: The value handed to the first stage.

        Returns:
            The value returned by the last hook that completed, or
            ``result`` itself when there are no stages or every hook raised.
        """
        current = result
        # Snapshot: ``add_stage`` may run while we are suspended in an await.
        for stage in tuple(self._stages):
            try:
                current = await stage.post_process(ctx, current)
            except Exception:
                logger.exception(
                    "Pipeline post_process raised in %s; continuing with previous result",
                    stage.__class__.__name__,
                )
        return current
|