Merge pull request #6898 from sundaram2021/fix/ast_pow_ddos_mitigation
micro-fix(security): mitigate ast.Pow DoS and enforce safe_eval timeout
This commit is contained in:
@@ -1,7 +1,84 @@
|
||||
import ast
|
||||
import operator
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
# Power operations can allocate extremely large integers. Keep conservative
|
||||
# limits here so untrusted edge conditions cannot exhaust CPU or memory.
|
||||
MAX_POWER_ABS_EXPONENT = 1_000
|
||||
MAX_POWER_RESULT_BITS = 4_096
|
||||
# Typical edge-condition evaluations in this repo complete well under 1ms.
|
||||
# 100ms leaves ample headroom for legitimate checks while failing fast on abuse.
|
||||
DEFAULT_TIMEOUT_MS = 100
|
||||
|
||||
|
||||
def _safe_pow(base: Any, exp: Any) -> Any:
|
||||
if isinstance(exp, (int, float)) and abs(exp) > MAX_POWER_ABS_EXPONENT:
|
||||
raise ValueError(f"Power exponent exceeds safe limit ({MAX_POWER_ABS_EXPONENT})")
|
||||
|
||||
if isinstance(base, int) and isinstance(exp, int) and exp > 0:
|
||||
abs_base = abs(base)
|
||||
if abs_base > 1:
|
||||
# Estimate bit growth instead of materializing a huge integer.
|
||||
estimated_bits = exp * abs_base.bit_length()
|
||||
if estimated_bits > MAX_POWER_RESULT_BITS:
|
||||
raise ValueError("Power operation exceeds safe size limit")
|
||||
|
||||
return operator.pow(base, exp)
|
||||
|
||||
|
||||
def _timeout_message(timeout_ms: int) -> str:
|
||||
return f"safe_eval exceeded {timeout_ms}ms execution timeout"
|
||||
|
||||
|
||||
def _check_timeout(deadline: float | None, timeout_ms: int | None) -> None:
|
||||
if deadline is not None and timeout_ms is not None and time.perf_counter() >= deadline:
|
||||
raise TimeoutError(_timeout_message(timeout_ms))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _execution_timeout(timeout_ms: int | None):
|
||||
if timeout_ms is None:
|
||||
yield
|
||||
return
|
||||
|
||||
if timeout_ms <= 0:
|
||||
raise ValueError("timeout_ms must be greater than 0")
|
||||
|
||||
can_use_alarm = (
|
||||
hasattr(signal, "SIGALRM")
|
||||
and hasattr(signal, "ITIMER_REAL")
|
||||
and hasattr(signal, "getitimer")
|
||||
and hasattr(signal, "setitimer")
|
||||
and threading.current_thread() is threading.main_thread()
|
||||
)
|
||||
if not can_use_alarm:
|
||||
yield
|
||||
return
|
||||
|
||||
current_delay, current_interval = signal.getitimer(signal.ITIMER_REAL)
|
||||
if current_delay > 0 or current_interval > 0:
|
||||
# safe_eval runs inside a shared framework process, so it must not
|
||||
# replace a timer another subsystem already owns.
|
||||
yield
|
||||
return
|
||||
|
||||
def _handle_timeout(signum, frame):
|
||||
raise TimeoutError(_timeout_message(timeout_ms))
|
||||
|
||||
old_handler = signal.getsignal(signal.SIGALRM)
|
||||
signal.signal(signal.SIGALRM, _handle_timeout)
|
||||
old_delay, old_interval = signal.setitimer(signal.ITIMER_REAL, timeout_ms / 1000)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal(signal.SIGALRM, old_handler)
|
||||
signal.setitimer(signal.ITIMER_REAL, old_delay, old_interval)
|
||||
|
||||
|
||||
# Safe operators whitelist
|
||||
SAFE_OPERATORS = {
|
||||
ast.Add: operator.add,
|
||||
@@ -10,7 +87,7 @@ SAFE_OPERATORS = {
|
||||
ast.Div: operator.truediv,
|
||||
ast.FloorDiv: operator.floordiv,
|
||||
ast.Mod: operator.mod,
|
||||
ast.Pow: operator.pow,
|
||||
ast.Pow: _safe_pow,
|
||||
ast.LShift: operator.lshift,
|
||||
ast.RShift: operator.rshift,
|
||||
ast.BitOr: operator.or_,
|
||||
@@ -54,10 +131,19 @@ SAFE_FUNCTIONS = {
|
||||
|
||||
|
||||
class SafeEvalVisitor(ast.NodeVisitor):
|
||||
def __init__(self, context: dict[str, Any]):
|
||||
def __init__(
|
||||
self,
|
||||
context: dict[str, Any],
|
||||
*,
|
||||
deadline: float | None = None,
|
||||
timeout_ms: int | None = None,
|
||||
):
|
||||
self.context = context
|
||||
self.deadline = deadline
|
||||
self.timeout_ms = timeout_ms
|
||||
|
||||
def visit(self, node: ast.AST) -> Any:
|
||||
_check_timeout(self.deadline, self.timeout_ms)
|
||||
# Override visit to prevent default behavior and ensure only explicitly allowed nodes work
|
||||
method = "visit_" + node.__class__.__name__
|
||||
visitor = getattr(self, method, self.generic_visit)
|
||||
@@ -183,6 +269,7 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
raise AttributeError(f"Object has no attribute '{node.attr}'")
|
||||
|
||||
def visit_Call(self, node: ast.Call) -> Any:
|
||||
_check_timeout(self.deadline, self.timeout_ms)
|
||||
# Only allow calling whitelisted functions
|
||||
func = self.visit(node.func)
|
||||
|
||||
@@ -226,16 +313,24 @@ class SafeEvalVisitor(ast.NodeVisitor):
|
||||
args = [self.visit(arg) for arg in node.args]
|
||||
keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
|
||||
|
||||
_check_timeout(self.deadline, self.timeout_ms)
|
||||
return func(*args, **keywords)
|
||||
|
||||
|
||||
def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
|
||||
def safe_eval(
|
||||
expr: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
*,
|
||||
timeout_ms: int | None = DEFAULT_TIMEOUT_MS,
|
||||
) -> Any:
|
||||
"""
|
||||
Safely evaluate a python expression string.
|
||||
|
||||
Args:
|
||||
expr: The expression string to evaluate.
|
||||
context: Dictionary of variables available in the expression.
|
||||
timeout_ms: Maximum evaluation time in milliseconds. Use ``None`` to
|
||||
disable the timeout.
|
||||
|
||||
Returns:
|
||||
The result of the evaluation.
|
||||
@@ -251,10 +346,18 @@ def safe_eval(expr: str, context: dict[str, Any] | None = None) -> Any:
|
||||
full_context = context.copy()
|
||||
full_context.update(SAFE_FUNCTIONS)
|
||||
|
||||
try:
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(f"Invalid syntax in expression: {e}") from e
|
||||
deadline = None if timeout_ms is None else time.perf_counter() + (timeout_ms / 1000)
|
||||
|
||||
visitor = SafeEvalVisitor(full_context)
|
||||
return visitor.visit(tree)
|
||||
with _execution_timeout(timeout_ms):
|
||||
try:
|
||||
tree = ast.parse(expr, mode="eval")
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(f"Invalid syntax in expression: {e}") from e
|
||||
|
||||
_check_timeout(deadline, timeout_ms)
|
||||
visitor = SafeEvalVisitor(
|
||||
full_context,
|
||||
deadline=deadline,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
return visitor.visit(tree)
|
||||
|
||||
@@ -9,6 +9,7 @@ AST nodes, disallowed function calls).
|
||||
|
||||
import pytest
|
||||
|
||||
import framework.graph.safe_eval as safe_eval_module
|
||||
from framework.graph.safe_eval import safe_eval
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -94,10 +95,124 @@ class TestArithmetic:
|
||||
def test_power(self):
|
||||
assert safe_eval("2 ** 10") == 1024
|
||||
|
||||
def test_power_large_exponent_blocked(self):
|
||||
with pytest.raises(ValueError, match="Power exponent"):
|
||||
safe_eval("2 ** 1001")
|
||||
|
||||
def test_power_large_result_blocked(self):
|
||||
with pytest.raises(ValueError, match="Power operation"):
|
||||
safe_eval("99 ** 1000")
|
||||
|
||||
def test_nested_power_blocked(self):
|
||||
with pytest.raises(ValueError, match="Power exponent"):
|
||||
safe_eval("2 ** 2 ** 20")
|
||||
|
||||
def test_complex_expression(self):
|
||||
assert safe_eval("(2 + 3) * 4 - 1") == 19
|
||||
|
||||
|
||||
class TestExecutionTimeout:
|
||||
def test_default_timeout(self):
|
||||
assert safe_eval_module.DEFAULT_TIMEOUT_MS == 100
|
||||
|
||||
def test_timeout_must_be_positive(self):
|
||||
with pytest.raises(ValueError, match="timeout_ms"):
|
||||
safe_eval("1 + 1", timeout_ms=0)
|
||||
|
||||
def test_timeout_can_be_disabled(self):
|
||||
assert safe_eval("1 + 1", timeout_ms=None) == 2
|
||||
|
||||
def test_timeout_exceeded_raises(self, monkeypatch):
|
||||
ticks = iter([0.0, 1.0])
|
||||
monkeypatch.setattr(safe_eval_module.time, "perf_counter", lambda: next(ticks))
|
||||
|
||||
with pytest.raises(TimeoutError, match="1ms"):
|
||||
safe_eval("1 + 1", timeout_ms=1)
|
||||
|
||||
def test_existing_process_timer_is_preserved(self, monkeypatch):
|
||||
calls: list[tuple[str, object]] = []
|
||||
main_thread = object()
|
||||
|
||||
monkeypatch.setattr(safe_eval_module.signal, "SIGALRM", object(), raising=False)
|
||||
monkeypatch.setattr(safe_eval_module.signal, "ITIMER_REAL", object(), raising=False)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.signal,
|
||||
"getitimer",
|
||||
lambda which: (5.0, 0.0),
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.signal,
|
||||
"setitimer",
|
||||
lambda *args: calls.append(("setitimer", args)),
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.signal,
|
||||
"signal",
|
||||
lambda *args: calls.append(("signal", args)),
|
||||
)
|
||||
monkeypatch.setattr(safe_eval_module.threading, "main_thread", lambda: main_thread)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.threading,
|
||||
"current_thread",
|
||||
lambda: main_thread,
|
||||
)
|
||||
|
||||
with safe_eval_module._execution_timeout(100):
|
||||
pass
|
||||
|
||||
assert calls == []
|
||||
|
||||
def test_timeout_restores_alarm_state(self, monkeypatch):
|
||||
calls: list[tuple[str, object]] = []
|
||||
main_thread = object()
|
||||
old_handler = object()
|
||||
|
||||
monkeypatch.setattr(safe_eval_module.signal, "SIGALRM", object(), raising=False)
|
||||
monkeypatch.setattr(safe_eval_module.signal, "ITIMER_REAL", object(), raising=False)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.signal,
|
||||
"getitimer",
|
||||
lambda which: (0.0, 0.0),
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.signal,
|
||||
"getsignal",
|
||||
lambda which: old_handler,
|
||||
)
|
||||
|
||||
def fake_signal(which, handler):
|
||||
calls.append(("signal", handler))
|
||||
|
||||
def fake_setitimer(which, delay, interval=0.0):
|
||||
calls.append(("setitimer", (delay, interval)))
|
||||
return (0.0, 0.0)
|
||||
|
||||
monkeypatch.setattr(safe_eval_module.signal, "signal", fake_signal)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.signal,
|
||||
"setitimer",
|
||||
fake_setitimer,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(safe_eval_module.threading, "main_thread", lambda: main_thread)
|
||||
monkeypatch.setattr(
|
||||
safe_eval_module.threading,
|
||||
"current_thread",
|
||||
lambda: main_thread,
|
||||
)
|
||||
|
||||
with safe_eval_module._execution_timeout(100):
|
||||
pass
|
||||
|
||||
assert calls[0][0] == "signal"
|
||||
assert calls[1] == ("setitimer", (0.1, 0.0))
|
||||
assert calls[2] == ("signal", old_handler)
|
||||
assert calls[3] == ("setitimer", (0.0, 0.0))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unary operators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user