fix(security): enforce safe_eval execution timeout
This commit is contained in:
@@ -1,11 +1,18 @@
|
||||
import ast
|
||||
import operator
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
# Power operations can allocate extremely large integers. Keep conservative
|
||||
# limits here so untrusted edge conditions cannot exhaust CPU or memory.
|
||||
MAX_POWER_ABS_EXPONENT = 1_000
|
||||
MAX_POWER_RESULT_BITS = 4_096
|
||||
# Typical edge-condition evaluations in this repo complete well under 1ms.
|
||||
# 100ms leaves ample headroom for legitimate checks while failing fast on abuse.
|
||||
DEFAULT_TIMEOUT_MS = 100
|
||||
|
||||
|
||||
def _safe_pow(base: Any, exp: Any) -> Any:
|
||||
@@ -25,6 +32,47 @@ def _safe_pow(base: Any, exp: Any) -> Any:
|
||||
return operator.pow(base, exp)
|
||||
|
||||
|
||||
def _timeout_message(timeout_ms: int) -> str:
|
||||
return f"safe_eval exceeded {timeout_ms}ms execution timeout"
|
||||
|
||||
|
||||
def _check_timeout(deadline: float | None, timeout_ms: int | None) -> None:
|
||||
if deadline is not None and timeout_ms is not None and time.perf_counter() >= deadline:
|
||||
raise TimeoutError(_timeout_message(timeout_ms))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _execution_timeout(timeout_ms: int | None):
|
||||
if timeout_ms is None:
|
||||
yield
|
||||
return
|
||||
|
||||
if timeout_ms <= 0:
|
||||
raise ValueError("timeout_ms must be greater than 0")
|
||||
|
||||
can_use_alarm = (
|
||||
hasattr(signal, "SIGALRM")
|
||||
and hasattr(signal, "ITIMER_REAL")
|
||||
and hasattr(signal, "setitimer")
|
||||
and threading.current_thread() is threading.main_thread()
|
||||
)
|
||||
if not can_use_alarm:
|
||||
yield
|
||||
return
|
||||
|
||||
def _handle_timeout(signum, frame):
|
||||
raise TimeoutError(_timeout_message(timeout_ms))
|
||||
|
||||
old_handler = signal.getsignal(signal.SIGALRM)
|
||||
signal.signal(signal.SIGALRM, _handle_timeout)
|
||||
signal.setitimer(signal.ITIMER_REAL, timeout_ms / 1000)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.setitimer(signal.ITIMER_REAL, 0)
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
|
||||
|
||||
# Safe operators whitelist
|
||||
SAFE_OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
@@ -77,10 +125,19 @@ SAFE_FUNCTIONS = {
|
||||
|
||||
|
||||
class SafeEvalVisitor(ast.NodeVisitor):
|
||||
def __init__(self, context: dict[str, Any]):
|
||||
def __init__(
|
||||
self,
|
||||
context: dict[str, Any],
|
||||
*,
|
||||
deadline: float | None = None,
|
||||
timeout_ms: int | None = None,
|
||||
):
|
||||
self.context = context
|
||||
self.deadline = deadline
|
||||
self.timeout_ms = timeout_ms
|
||||
|
||||
def visit(self, node: ast.AST) -> Any:
|
||||
_check_timeout(self.deadline, self.timeout_ms)
|
||||
# Override visit to prevent default behavior and ensure only explicitly allowed nodes work
|
||||
method = "visit_" + node.__class__.__name__
|
||||
visitor = getattr(self, method, self.generic_visit)
|
||||
@@ -206,6 +263,7 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
raise AttributeError(f"Object has no attribute '{node.attr}'")
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> Any:
|
||||
_check_timeout(self.deadline, self.timeout_ms)
|
||||
# Only allow calling whitelisted functions
|
||||
func = self.visit(node.func)
|
||||
|
||||
@@ -249,16 +307,24 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
|
||||
|
||||
_check_timeout(self.deadline, self.timeout_ms)
|
||||
return func(*args, **keywords)
|
||||
|
||||
|
||||
def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
|
||||
def safe_eval(
|
||||
expr: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
*,
|
||||
timeout_ms: int | None = DEFAULT_TIMEOUT_MS,
|
||||
) -> Any:
|
||||
"""
|
||||
Safely evaluate a python expression string.
|
||||
|
||||
Args:
|
||||
expr: The expression string to evaluate.
|
||||
context: Dictionary of variables available in the expression.
|
||||
timeout_ms: Maximum evaluation time in milliseconds. Use ``None`` to
|
||||
disable the timeout.
|
||||
|
||||
Returns:
|
||||
The result of the evaluation.
|
||||
@@ -274,10 +340,18 @@ def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
|
||||
full_context = context.copy()
|
||||
full_context.update(SAFE_FUNCTIONS)
|
||||
|
||||
try:
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(f"Invalid syntax in expression: {e}") from e
|
||||
deadline = None if timeout_ms is None else time.perf_counter() + (timeout_ms / 1000)
|
||||
|
||||
visitor = SafeEvalVisitor(full_context)
|
||||
return visitor.visit(tree)
|
||||
with _execution_timeout(timeout_ms):
|
||||
try:
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(f"Invalid syntax in expression: {e}") from e
|
||||
|
||||
_check_timeout(deadline, timeout_ms)
|
||||
visitor = SafeEvalVisitor(
|
||||
full_context,
|
||||
deadline=deadline,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
return visitor.visit(tree)
|
||||
|
||||
Reference in New Issue
Block a user