414 lines
11 KiB
Python
414 lines
11 KiB
Python
"""
|
|
Code Sandbox for Safe Execution of Dynamic Code.

Provides a restricted execution environment for code generated by
the external planner. This is critical for open-ended planning where
the planner can create arbitrary code actions.

Security measures:
1. Restricted builtins (no file I/O, no imports of dangerous modules)
2. Timeout enforcement (SIGALRM-based, so only on Unix-like systems)
3. Memory limits (via resource module on Unix)
   NOTE: no memory limit appears to be applied by the code in this
   module; confirm before relying on it.
4. Namespace isolation
|
|
"""
|
|
|
|
import ast
|
|
import signal
|
|
import sys
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
# Whitelist of builtin names, mapped to the objects they resolve to.
# The sandbox copies this mapping to use as the ``__builtins__`` of the code
# it runs; a name that is not a key here is not supplied by this mapping.
SAFE_BUILTINS = {
    # Constants
    "True": True, "False": False, "None": None,
    # Type constructors
    "bool": bool, "int": int, "float": float, "str": str,
    "list": list, "dict": dict, "set": set, "tuple": tuple,
    "frozenset": frozenset,
    # Pure helper functions
    "abs": abs, "all": all, "any": any, "bin": bin, "chr": chr,
    "divmod": divmod, "enumerate": enumerate, "filter": filter,
    "format": format, "hex": hex, "isinstance": isinstance,
    "issubclass": issubclass, "iter": iter, "len": len, "map": map,
    "max": max, "min": min, "next": next, "oct": oct, "ord": ord,
    "pow": pow, "range": range, "repr": repr, "reversed": reversed,
    "round": round, "slice": slice, "sorted": sorted, "sum": sum,
    "zip": zip,
}
|
|
|
|
# Default whitelist of module names handed to RestrictedImporter.
ALLOWED_MODULES = {
    "collections", "datetime", "decimal", "fractions", "functools",
    "itertools", "json", "math", "operator", "random", "re",
    "statistics", "string",
}
|
|
|
|
# AST node types that CodeValidator reports as blocked operations by default:
# both import statement forms plus ``global`` / ``nonlocal`` declarations.
BLOCKED_AST_NODES = {ast.Import, ast.ImportFrom, ast.Global, ast.Nonlocal}
|
|
|
|
|
|
class CodeSandboxError(Exception):
    """Base class for every error raised by the sandbox itself."""
|
|
|
|
|
|
class TimeoutError(CodeSandboxError):
    """Raised when sandboxed code runs past its time budget.

    Within this module the name shadows the builtin ``TimeoutError``.
    """
|
|
|
|
|
|
class SecurityError(CodeSandboxError):
    """Raised when code attempts an operation the sandbox does not permit."""
|
|
|
|
|
|
@dataclass
|
|
class SandboxResult:
|
|
"""Result of sandboxed code execution."""
|
|
|
|
success: bool
|
|
result: Any = None
|
|
error: str | None = None
|
|
stdout: str = ""
|
|
variables: dict[str, Any] = field(default_factory=dict)
|
|
execution_time_ms: int = 0
|
|
|
|
|
|
class RestrictedImporter:
    """Callable stand-in for ``__import__`` that only serves whitelisted modules.

    Instances accept the same call shape as the builtin
    ``__import__(name, globals, locals, fromlist, level)``. Only an exact
    match of ``name`` against ``allowed_modules`` is accepted; submodules of
    an allowed package are not implicitly allowed.
    """

    def __init__(self, allowed_modules: set[str]):
        """Store the whitelist and start with an empty module cache."""
        self.allowed_modules = allowed_modules
        # Modules already handed out by this importer, keyed by name.
        self._cache: dict[str, Any] = {}

    def __call__(self, name: str, *args, **kwargs):
        """Return the module called ``name`` if it is whitelisted.

        Args:
            name: Absolute module name being imported.
            *args: Remaining positional ``__import__`` arguments
                (globals, locals, fromlist, level).
            **kwargs: The same arguments passed by keyword.

        Returns:
            The imported module object.

        Raises:
            SecurityError: If the import is relative, ``name`` is not a
                string, or ``name`` is not in the whitelist.
        """
        # In the __import__ protocol ``level`` is the 5th argument overall,
        # i.e. index 3 of the extra positionals. A non-zero level means a
        # relative import, which must not silently resolve to an absolute one.
        level = kwargs.get("level", args[3] if len(args) > 3 else 0)
        if level:
            raise SecurityError(f"Relative import of module '{name}' is not allowed")

        # The isinstance guard keeps unhashable names from raising TypeError
        # during the set membership test.
        if not isinstance(name, str) or name not in self.allowed_modules:
            raise SecurityError(f"Import of module '{name}' is not allowed")

        try:
            return self._cache[name]
        except KeyError:
            import importlib

            module = self._cache[name] = importlib.import_module(name)
            return module
|
|
|
|
|
|
class CodeValidator:
|
|
"""Validates code for safety before execution."""
|
|
|
|
def __init__(self, blocked_nodes: set[type] | None = None):
|
|
self.blocked_nodes = blocked_nodes or BLOCKED_AST_NODES
|
|
|
|
def validate(self, code: str) -> list[str]:
|
|
"""
|
|
Validate code and return list of issues.
|
|
|
|
Returns empty list if code is safe.
|
|
"""
|
|
issues = []
|
|
|
|
try:
|
|
tree = ast.parse(code)
|
|
except SyntaxError as e:
|
|
return [f"Syntax error: {e}"]
|
|
|
|
for node in ast.walk(tree):
|
|
# Check for blocked node types
|
|
if type(node) in self.blocked_nodes:
|
|
lineno = getattr(node, "lineno", "?")
|
|
issues.append(f"Blocked operation: {type(node).__name__} at line {lineno}")
|
|
|
|
# Check for dangerous attribute access
|
|
if isinstance(node, ast.Attribute):
|
|
if node.attr.startswith("_"):
|
|
issues.append(
|
|
f"Access to private attribute '{node.attr}' at line {node.lineno}"
|
|
)
|
|
|
|
# Check for exec/eval calls
|
|
if isinstance(node, ast.Call):
|
|
if isinstance(node.func, ast.Name):
|
|
if node.func.id in ("exec", "eval", "compile", "__import__"):
|
|
issues.append(
|
|
f"Blocked function call: {node.func.id} at line {node.lineno}"
|
|
)
|
|
|
|
return issues
|
|
|
|
|
|
class CodeSandbox:
    """
    Sandboxed environment for executing dynamic code.

    Code is checked statically by ``CodeValidator`` and then run with a
    restricted ``__builtins__`` mapping under a SIGALRM-based timeout.

    Usage:
        sandbox = CodeSandbox(timeout_seconds=5)
        result = sandbox.execute(
            code="x = 1 + 2\\nresult = x * 3",
            inputs={"multiplier": 2},
        )
        if result.success:
            print(result.variables["result"])  # 9

    Notes:
        - The timeout relies on ``signal.SIGALRM``. Where that signal does
          not exist, code runs without a time limit. ``signal.signal`` can
          only be called from the main thread, so elsewhere a run fails with
          a ``ValueError`` reported in the result.
        - ``execute`` temporarily replaces the process-wide ``sys.stdout``.
    """

    def __init__(
        self,
        timeout_seconds: int = 10,
        allowed_modules: set[str] | None = None,
        safe_builtins: dict[str, Any] | None = None,
    ):
        """Configure the sandbox.

        Args:
            timeout_seconds: Time budget for each run, in whole seconds.
            allowed_modules: Module whitelist; ``None`` selects
                ``ALLOWED_MODULES``. An explicit empty set allows nothing.
            safe_builtins: Builtins mapping; ``None`` selects
                ``SAFE_BUILTINS``. An explicit empty dict exposes nothing.
        """
        self.timeout_seconds = timeout_seconds
        # ``is None`` rather than ``or``: an empty whitelist is the strictest
        # configuration and must not fall back to the permissive defaults.
        self.allowed_modules = ALLOWED_MODULES if allowed_modules is None else allowed_modules
        self.safe_builtins = SAFE_BUILTINS if safe_builtins is None else safe_builtins
        self.validator = CodeValidator()
        self.importer = RestrictedImporter(self.allowed_modules)

    @contextmanager
    def _timeout_context(self, seconds: int):
        """Context manager that raises ``TimeoutError`` after ``seconds``.

        The timeout arrives as an exception raised inside the running code,
        so code that catches every exception (a bare ``except:``) can
        swallow it.
        """

        def handler(signum, frame):
            raise TimeoutError(f"Code execution timed out after {seconds} seconds")

        if not hasattr(signal, "SIGALRM"):
            # No SIGALRM on this platform (e.g. Windows): run without timeout.
            yield
            return

        old_handler = signal.signal(signal.SIGALRM, handler)
        try:
            # Arm inside the try so the previous handler is restored even if
            # arming the alarm itself fails.
            signal.alarm(seconds)
            yield
        finally:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old_handler)

    def _create_namespace(self, inputs: dict[str, Any]) -> dict[str, Any]:
        """Build the globals dict that sandboxed code runs in.

        Args:
            inputs: Caller-supplied variables to expose to the code.

        Returns:
            A new dict holding ``inputs`` plus a private copy of the safe
            builtins under ``__builtins__``.
        """
        restricted_builtins = dict(self.safe_builtins)
        # Import statements look up ``__import__`` in the builtins mapping,
        # not in globals, so the restricted importer has to live here.
        restricted_builtins["__import__"] = self.importer

        namespace = dict(inputs)
        # Assigned after the inputs so they can never replace the builtins.
        namespace["__builtins__"] = restricted_builtins
        return namespace

    def execute(
        self,
        code: str,
        inputs: dict[str, Any] | None = None,
        extract_vars: list[str] | None = None,
    ) -> SandboxResult:
        """
        Execute code in sandbox.

        Args:
            code: Python code to execute
            inputs: Variables to inject into namespace
            extract_vars: Variable names to extract from namespace after
                execution, in addition to the automatically collected ones

        Returns:
            SandboxResult with execution outcome. On success ``variables``
            holds the requested names plus every name that is not an input,
            not a safe builtin and does not start with an underscore;
            ``result`` is the value bound to the name ``result``, if any.
        """
        import io
        import time

        inputs = inputs or {}
        extract_vars = extract_vars or []

        # Static validation first: nothing runs if any issue is reported.
        issues = self.validator.validate(code)
        if issues:
            return SandboxResult(
                success=False,
                error=f"Code validation failed: {'; '.join(issues)}",
            )

        namespace = self._create_namespace(inputs)

        # Redirect stdout for the duration of the run.
        captured_stdout = io.StringIO()
        old_stdout = sys.stdout
        sys.stdout = captured_stdout

        # Monotonic clock: unaffected by wall-clock adjustments.
        start_time = time.perf_counter()

        def elapsed_ms() -> int:
            return int((time.perf_counter() - start_time) * 1000)

        try:
            with self._timeout_context(self.timeout_seconds):
                exec(compile(code, "<sandbox>", "exec"), namespace)

            execution_time_ms = elapsed_ms()

            # Explicitly requested names first ...
            extracted = {var: namespace[var] for var in extract_vars if var in namespace}
            # ... then every variable the code created itself.
            extracted.update(
                {
                    key: value
                    for key, value in namespace.items()
                    if key not in inputs
                    and key not in self.safe_builtins
                    and not key.startswith("_")
                }
            )

            return SandboxResult(
                success=True,
                result=namespace.get("result"),  # Convention: 'result' is the return value
                stdout=captured_stdout.getvalue(),
                variables=extracted,
                execution_time_ms=execution_time_ms,
            )

        except TimeoutError as e:
            return SandboxResult(
                success=False,
                error=str(e),
                stdout=captured_stdout.getvalue(),
                execution_time_ms=self.timeout_seconds * 1000,
            )

        except SecurityError as e:
            return SandboxResult(
                success=False,
                error=f"Security violation: {e}",
                stdout=captured_stdout.getvalue(),
                execution_time_ms=elapsed_ms(),
            )

        except Exception as e:
            # Boundary handler: any failure of the sandboxed code is reported
            # in the result instead of propagating to the caller.
            return SandboxResult(
                success=False,
                error=f"{type(e).__name__}: {e}",
                stdout=captured_stdout.getvalue(),
                execution_time_ms=elapsed_ms(),
            )

        finally:
            sys.stdout = old_stdout

    def execute_expression(
        self,
        expression: str,
        inputs: dict[str, Any] | None = None,
    ) -> SandboxResult:
        """
        Execute a single expression and return its value.

        Simpler than execute() - just evaluates one expression. The
        expression goes through the same static validation as execute().

        Args:
            expression: Python expression to evaluate
            inputs: Variables to inject into namespace

        Returns:
            SandboxResult whose ``result`` is the expression's value on
            success, or whose ``error`` describes the failure.
        """
        import time

        inputs = inputs or {}

        # Must be a single expression, not a statement.
        try:
            ast.parse(expression, mode="eval")
        except (SyntaxError, ValueError) as e:
            return SandboxResult(success=False, error=f"Syntax error: {e}")

        # Same security checks as execute(); without them an expression such
        # as ``().__class__`` would be evaluated unchecked.
        issues = self.validator.validate(expression)
        if issues:
            return SandboxResult(
                success=False,
                error=f"Code validation failed: {'; '.join(issues)}",
            )

        namespace = self._create_namespace(inputs)
        start_time = time.perf_counter()

        try:
            with self._timeout_context(self.timeout_seconds):
                result = eval(expression, namespace)

            return SandboxResult(
                success=True,
                result=result,
                execution_time_ms=int((time.perf_counter() - start_time) * 1000),
            )

        except Exception as e:
            return SandboxResult(
                success=False,
                error=f"{type(e).__name__}: {e}",
                execution_time_ms=int((time.perf_counter() - start_time) * 1000),
            )
|
|
|
|
|
|
# Singleton instance with default settings
|
|
default_sandbox = CodeSandbox()  # built at import time with all defaults; safe_exec/safe_eval create their own instances instead of using it
|
|
|
|
|
|
def safe_exec(
    code: str,
    inputs: dict[str, Any] | None = None,
    timeout_seconds: int = 10,
) -> SandboxResult:
    """
    Run a block of code in a freshly created sandbox.

    Args:
        code: Python code to execute
        inputs: Variables made available to the code
        timeout_seconds: Time budget passed to the sandbox

    Returns:
        The SandboxResult produced by ``CodeSandbox.execute``
    """
    # A new sandbox per call, so the timeout applies to this call only.
    return CodeSandbox(timeout_seconds=timeout_seconds).execute(code, inputs)
|
|
|
|
|
|
def safe_eval(
    expression: str,
    inputs: dict[str, Any] | None = None,
    timeout_seconds: int = 5,
) -> SandboxResult:
    """
    Evaluate one expression in a freshly created sandbox.

    Args:
        expression: Python expression to evaluate
        inputs: Variables made available to the expression
        timeout_seconds: Time budget passed to the sandbox

    Returns:
        The SandboxResult produced by ``CodeSandbox.execute_expression``
    """
    # A new sandbox per call, so the timeout applies to this call only.
    return CodeSandbox(timeout_seconds=timeout_seconds).execute_expression(expression, inputs)
|