Files
hive/core/framework/pipeline/runner.py
T
2026-04-07 15:20:31 -07:00

112 lines
4.0 KiB
Python

"""Pipeline runner -- executes registered stages in order."""
from __future__ import annotations
import logging
from typing import Any
from framework.pipeline.stage import (
PipelineContext,
PipelineRejectedError,
PipelineStage,
)
logger = logging.getLogger(__name__)
class PipelineRunner:
"""Executes a list of :class:`PipelineStage` instances in ``order``.
The runner is the orchestration layer that :class:`AgentRuntime` calls
on every trigger. Stages execute in ascending ``order`` (ties broken by
registration order). A stage returning ``reject`` short-circuits the
pipeline and causes the trigger to raise :class:`PipelineRejectedError`.
"""
def __init__(self, stages: list[PipelineStage] | None = None) -> None:
self._stages: list[PipelineStage] = sorted(stages or [], key=lambda s: s.order)
@property
def stages(self) -> list[PipelineStage]:
return list(self._stages)
def add_stage(self, stage: PipelineStage) -> None:
"""Add a stage after construction (for dynamic registration)."""
self._stages.append(stage)
self._stages.sort(key=lambda s: s.order)
async def initialize_all(self) -> None:
"""Call ``initialize`` on every registered stage."""
for stage in self._stages:
name = stage.__class__.__name__
logger.info("[pipeline] Initializing %s (order=%d)", name, stage.order)
await stage.initialize()
logger.info("[pipeline] %s initialized", name)
if self._stages:
logger.info(
"[pipeline] Ready: %d stages [%s]",
len(self._stages),
" -> ".join(s.__class__.__name__ for s in self._stages),
)
async def run(self, ctx: PipelineContext) -> PipelineContext:
"""Run all stages. Raises ``PipelineRejectedError`` on rejection.
Returns the (possibly transformed) context.
"""
if not self._stages:
return ctx
import time
pipeline_start = time.perf_counter()
logger.info(
"[pipeline] Running %d stages for entry_point=%s",
len(self._stages),
ctx.entry_point_id,
)
for stage in self._stages:
stage_name = stage.__class__.__name__
t0 = time.perf_counter()
result = await stage.process(ctx)
elapsed_ms = (time.perf_counter() - t0) * 1000
if result.action == "reject":
reason = result.rejection_reason or "(no reason given)"
logger.warning(
"[pipeline] REJECTED by %s (%.1fms): %s",
stage_name, elapsed_ms, reason,
)
raise PipelineRejectedError(stage_name, reason)
if result.action == "transform":
logger.info(
"[pipeline] %s TRANSFORMED input (%.1fms)",
stage_name, elapsed_ms,
)
if result.input_data is not None:
ctx.input_data = result.input_data
else:
logger.info(
"[pipeline] %s passed (%.1fms)",
stage_name, elapsed_ms,
)
total_ms = (time.perf_counter() - pipeline_start) * 1000
logger.info("[pipeline] Complete (%.1fms total)", total_ms)
return ctx
async def run_post(self, ctx: PipelineContext, result: Any) -> Any:
"""Run all stages' ``post_process`` hooks in order.
Each stage can transform the result; the final value is returned.
Exceptions are logged and swallowed -- post-processing must not
break a successful execution.
"""
current = result
for stage in self._stages:
try:
current = await stage.post_process(ctx, current)
except Exception:
logger.exception(
"Pipeline post_process raised in %s; continuing with previous result",
stage.__class__.__name__,
)
return current