Merge pull request #6898 from sundaram2021/fix/ast_pow_ddos_mitigation

micro-fix(security): mitigate ast.Pow DoS and enforce safe_eval timeout
This commit is contained in:
Bryan @ Aden
2026-04-06 13:36:03 -07:00
committed by GitHub
2 changed files with 227 additions and 9 deletions
+112 -9
View File
@@ -1,7 +1,84 @@
import ast
import operator
import signal
import threading
import time
from contextlib import contextmanager
from typing import Any
# Power operations can allocate extremely large integers. Keep conservative
# limits here so untrusted edge conditions cannot exhaust CPU or memory.
MAX_POWER_ABS_EXPONENT = 1_000
MAX_POWER_RESULT_BITS = 4_096
# Typical edge-condition evaluations in this repo complete well under 1ms.
# 100ms leaves ample headroom for legitimate checks while failing fast on abuse.
DEFAULT_TIMEOUT_MS = 100
def _safe_pow(base: Any, exp: Any) -> Any:
if isinstance(exp, (int, float)) and abs(exp) > MAX_POWER_ABS_EXPONENT:
raise ValueError(f"Power exponent exceeds safe limit ({MAX_POWER_ABS_EXPONENT})")
if isinstance(base, int) and isinstance(exp, int) and exp > 0:
abs_base = abs(base)
if abs_base > 1:
# Estimate bit growth instead of materializing a huge integer.
estimated_bits = exp * abs_base.bit_length()
if estimated_bits > MAX_POWER_RESULT_BITS:
raise ValueError("Power operation exceeds safe size limit")
return operator.pow(base, exp)
def _timeout_message(timeout_ms: int) -> str:
return f"safe_eval exceeded {timeout_ms}ms execution timeout"
def _check_timeout(deadline: float | None, timeout_ms: int | None) -> None:
if deadline is not None and timeout_ms is not None and time.perf_counter() >= deadline:
raise TimeoutError(_timeout_message(timeout_ms))
@contextmanager
def _execution_timeout(timeout_ms: int | None):
if timeout_ms is None:
yield
return
if timeout_ms <= 0:
raise ValueError("timeout_ms must be greater than 0")
can_use_alarm = (
hasattr(signal, "SIGALRM")
and hasattr(signal, "ITIMER_REAL")
and hasattr(signal, "getitimer")
and hasattr(signal, "setitimer")
and threading.current_thread() is threading.main_thread()
)
if not can_use_alarm:
yield
return
current_delay, current_interval = signal.getitimer(signal.ITIMER_REAL)
if current_delay > 0 or current_interval > 0:
# safe_eval runs inside a shared framework process, so it must not
# replace a timer another subsystem already owns.
yield
return
def _handle_timeout(signum, frame):
raise TimeoutError(_timeout_message(timeout_ms))
old_handler = signal.getsignal(signal.SIGALRM)
signal.signal(signal.SIGALRM, _handle_timeout)
old_delay, old_interval = signal.setitimer(signal.ITIMER_REAL, timeout_ms / 1000)
try:
yield
finally:
signal.signal(signal.SIGALRM, old_handler)
signal.setitimer(signal.ITIMER_REAL, old_delay, old_interval)
# Safe operators whitelist
SAFE_OPERATORS = {
ast.Add: operator.add,
@@ -10,7 +87,7 @@ SAFE_OPERATORS = {
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
ast.Pow: _safe_pow,
ast.LShift: operator.lshift,
ast.RShift: operator.rshift,
ast.BitOr: operator.or_,
@@ -54,10 +131,19 @@ SAFE_FUNCTIONS = {
class SafeEvalVisitor(ast.NodeVisitor):
def __init__(self, context: dict[str, Any]):
def __init__(
self,
context: dict[str, Any],
*,
deadline: float | None = None,
timeout_ms: int | None = None,
):
self.context = context
self.deadline = deadline
self.timeout_ms = timeout_ms
def visit(self, node: ast.AST) -> Any:
_check_timeout(self.deadline, self.timeout_ms)
# Override visit to prevent default behavior and ensure only explicitly allowed nodes work
method = "visit_" + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
@@ -183,6 +269,7 @@ class SafeEvalVisitor(ast.NodeVisitor):
raise AttributeError(f"Object has no attribute '{node.attr}'")
def visit_Call(self, node: ast.Call) -> Any:
_check_timeout(self.deadline, self.timeout_ms)
# Only allow calling whitelisted functions
func = self.visit(node.func)
@@ -226,16 +313,24 @@ class SafeEvalVisitor(ast.NodeVisitor):
args = [self.visit(arg) for arg in node.args]
keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
_check_timeout(self.deadline, self.timeout_ms)
return func(*args, **keywords)
def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
def safe_eval(
expr: str,
context: dict[str, Any] | None = None,
*,
timeout_ms: int | None = DEFAULT_TIMEOUT_MS,
) -> Any:
"""
Safely evaluate a python expression string.
Args:
expr: The expression string to evaluate.
context: Dictionary of variables available in the expression.
timeout_ms: Maximum evaluation time in milliseconds. Use ``None`` to
disable the timeout.
Returns:
The result of the evaluation.
@@ -251,10 +346,18 @@ def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
full_context = context.copy()
full_context.update(SAFE_FUNCTIONS)
try:
tree = ast.parse(expr, mode="eval")
except SyntaxError as e:
raise SyntaxError(f"Invalid syntax in expression: {e}") from e
deadline = None if timeout_ms is None else time.perf_counter() + (timeout_ms / 1000)
visitor = SafeEvalVisitor(full_context)
return visitor.visit(tree)
with _execution_timeout(timeout_ms):
try:
tree = ast.parse(expr, mode="eval")
except SyntaxError as e:
raise SyntaxError(f"Invalid syntax in expression: {e}") from e
_check_timeout(deadline, timeout_ms)
visitor = SafeEvalVisitor(
full_context,
deadline=deadline,
timeout_ms=timeout_ms,
)
return visitor.visit(tree)
+115
View File
@@ -9,6 +9,7 @@ AST nodes, disallowed function calls).
import pytest
import framework.graph.safe_eval as safe_eval_module
from framework.graph.safe_eval import safe_eval
# ---------------------------------------------------------------------------
@@ -94,10 +95,124 @@ class TestArithmetic:
def test_power(self):
assert safe_eval("2 ** 10") == 1024
def test_power_large_exponent_blocked(self):
with pytest.raises(ValueError, match="Power exponent"):
safe_eval("2 ** 1001")
def test_power_large_result_blocked(self):
with pytest.raises(ValueError, match="Power operation"):
safe_eval("99 ** 1000")
def test_nested_power_blocked(self):
with pytest.raises(ValueError, match="Power exponent"):
safe_eval("2 ** 2 ** 20")
def test_complex_expression(self):
assert safe_eval("(2 + 3) * 4 - 1") == 19
class TestExecutionTimeout:
def test_default_timeout(self):
assert safe_eval_module.DEFAULT_TIMEOUT_MS == 100
def test_timeout_must_be_positive(self):
with pytest.raises(ValueError, match="timeout_ms"):
safe_eval("1 + 1", timeout_ms=0)
def test_timeout_can_be_disabled(self):
assert safe_eval("1 + 1", timeout_ms=None) == 2
def test_timeout_exceeded_raises(self, monkeypatch):
ticks = iter([0.0, 1.0])
monkeypatch.setattr(safe_eval_module.time, "perf_counter", lambda: next(ticks))
with pytest.raises(TimeoutError, match="1ms"):
safe_eval("1 + 1", timeout_ms=1)
def test_existing_process_timer_is_preserved(self, monkeypatch):
calls: list[tuple[str, object]] = []
main_thread = object()
monkeypatch.setattr(safe_eval_module.signal, "SIGALRM", object(), raising=False)
monkeypatch.setattr(safe_eval_module.signal, "ITIMER_REAL", object(), raising=False)
monkeypatch.setattr(
safe_eval_module.signal,
"getitimer",
lambda which: (5.0, 0.0),
raising=False,
)
monkeypatch.setattr(
safe_eval_module.signal,
"setitimer",
lambda *args: calls.append(("setitimer", args)),
raising=False,
)
monkeypatch.setattr(
safe_eval_module.signal,
"signal",
lambda *args: calls.append(("signal", args)),
)
monkeypatch.setattr(safe_eval_module.threading, "main_thread", lambda: main_thread)
monkeypatch.setattr(
safe_eval_module.threading,
"current_thread",
lambda: main_thread,
)
with safe_eval_module._execution_timeout(100):
pass
assert calls == []
def test_timeout_restores_alarm_state(self, monkeypatch):
calls: list[tuple[str, object]] = []
main_thread = object()
old_handler = object()
monkeypatch.setattr(safe_eval_module.signal, "SIGALRM", object(), raising=False)
monkeypatch.setattr(safe_eval_module.signal, "ITIMER_REAL", object(), raising=False)
monkeypatch.setattr(
safe_eval_module.signal,
"getitimer",
lambda which: (0.0, 0.0),
raising=False,
)
monkeypatch.setattr(
safe_eval_module.signal,
"getsignal",
lambda which: old_handler,
)
def fake_signal(which, handler):
calls.append(("signal", handler))
def fake_setitimer(which, delay, interval=0.0):
calls.append(("setitimer", (delay, interval)))
return (0.0, 0.0)
monkeypatch.setattr(safe_eval_module.signal, "signal", fake_signal)
monkeypatch.setattr(
safe_eval_module.signal,
"setitimer",
fake_setitimer,
raising=False,
)
monkeypatch.setattr(safe_eval_module.threading, "main_thread", lambda: main_thread)
monkeypatch.setattr(
safe_eval_module.threading,
"current_thread",
lambda: main_thread,
)
with safe_eval_module._execution_timeout(100):
pass
assert calls[0][0] == "signal"
assert calls[1] == ("setitimer", (0.1, 0.0))
assert calls[2] == ("signal", old_handler)
assert calls[3] == ("setitimer", (0.0, 0.0))
# ---------------------------------------------------------------------------
# Unary operators
# ---------------------------------------------------------------------------