Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 33 additions & 38 deletions tests/test_mcp_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,23 +232,23 @@ def test_injection_in_url_blocked(self):
assert "dangerous pattern" in str(exc_info.value)


class TestUnsafeMode:
"""Tests for unsafe mode (development only)."""
class TestUnsafeModeRemoved:
"""Tests that unsafe mode has been removed (security hardening)."""

def test_unsafe_mode_allows_any_command(self):
"""Test that unsafe mode allows any command (with warning)."""
validator = MCPCommandValidator(allow_unsafe=True)
def test_allow_unsafe_parameter_rejected(self):
"""Test that allow_unsafe parameter is no longer accepted."""
with pytest.raises(TypeError):
MCPCommandValidator(allow_unsafe=True)

# These should not raise even though they're dangerous
validator.validate_command("bash", "test")
validator.validate_command("/bin/sh -c 'dangerous'", "test")
def test_dangerous_commands_always_blocked(self):
"""Test that dangerous commands are always blocked (no bypass)."""
validator = MCPCommandValidator(check_path_exists=False)

def test_unsafe_mode_allows_any_args(self):
"""Test that unsafe mode allows any arguments."""
validator = MCPCommandValidator(allow_unsafe=True)
with pytest.raises(MCPSecurityError):
validator.validate_command("bash", "test")

# These should not raise
validator.validate_args(["; rm -rf /"], "test")
with pytest.raises(MCPSecurityError):
validator.validate_args(["; rm -rf /"], "test")


class TestCustomWhitelist:
Expand Down Expand Up @@ -328,17 +328,16 @@ def test_valid_sse_config(self):
)
assert config.url == "https://api.example.com/mcp"

def test_skip_security_validation(self):
"""Test that skip_security_validation allows any command (with warning)."""
# This should not raise even with dangerous command
config = MCPServerConfig(
name="unsafe-server",
transport=MCPTransport.STDIO,
command="bash",
args=["-c", "echo hello"],
skip_security_validation=True,
)
assert config.command == "bash"
def test_skip_security_validation_removed(self):
"""Test that skip_security_validation field has been removed."""
with pytest.raises(TypeError):
MCPServerConfig(
name="unsafe-server",
transport=MCPTransport.STDIO,
command="bash",
args=["-c", "echo hello"],
skip_security_validation=True,
)


class TestDefaultWhitelist:
Expand Down Expand Up @@ -697,38 +696,34 @@ def test_clear_audit_log(self):


class TestToolSandboxHighRiskTools:
"""Tests for high-risk tool detection."""

def test_high_risk_tool_warning(self, caplog):
"""Test that high-risk tools trigger warning."""
import logging
"""Tests for high-risk tool blocking."""

def test_high_risk_tool_blocked(self):
"""Test that high-risk tools are blocked with MCPSecurityError."""
sandbox = ToolSandbox()

with caplog.at_level(logging.WARNING):
with pytest.raises(MCPSecurityError) as exc_info:
sandbox.validate_tool_execution(
tool_name="execute_command",
server_name="test",
arguments={"cmd": "ls"},
)

assert "High-risk tool detected" in caplog.text
assert "execute" in caplog.text

def test_high_risk_shell_tool(self, caplog):
"""Test that shell tools trigger warning."""
import logging
assert "High-risk tool blocked" in str(exc_info.value)
assert "execute" in str(exc_info.value)

def test_high_risk_shell_tool_blocked(self):
"""Test that shell tools are blocked with MCPSecurityError."""
sandbox = ToolSandbox()

with caplog.at_level(logging.WARNING):
with pytest.raises(MCPSecurityError) as exc_info:
sandbox.validate_tool_execution(
tool_name="run_shell",
server_name="test",
arguments={},
)

assert "High-risk tool detected" in caplog.text
assert "High-risk tool blocked" in str(exc_info.value)


class TestCustomBlockedPatterns:
Expand Down
34 changes: 16 additions & 18 deletions tests/test_mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,39 +184,37 @@ def test_save_frames_to_temp(self, test_video_path):
class TestImageProcessing:
"""Test image processing functions."""

def test_process_image_input_local_file(self, test_image_path):
"""Test processing local image file."""
def test_process_image_input_local_file_rejected(self, test_image_path):
"""Test that local file paths are rejected (path traversal prevention)."""
from vllm_mlx.models.mllm import process_image_input

result = process_image_input(test_image_path)
assert result == test_image_path
with pytest.raises(ValueError):
process_image_input(test_image_path)

def test_process_image_input_dict_format(self, test_image_path):
"""Test processing image in dict format."""
def test_process_image_input_dict_local_file_rejected(self, test_image_path):
"""Test that local file paths in dict format are rejected."""
from vllm_mlx.models.mllm import process_image_input

# OpenAI format
result = process_image_input({"url": test_image_path})
assert Path(result).exists()
with pytest.raises(ValueError):
process_image_input({"url": test_image_path})


class TestVideoProcessing:
"""Test video processing functions."""

def test_process_video_input_local_file(self, test_video_path):
"""Test processing local video file."""
def test_process_video_input_local_file_rejected(self, test_video_path):
"""Test that local file paths are rejected (path traversal prevention)."""
from vllm_mlx.models.mllm import process_video_input

result = process_video_input(test_video_path)
assert result == test_video_path
with pytest.raises(ValueError):
process_video_input(test_video_path)

def test_process_video_input_dict_format(self, test_video_path):
"""Test processing video in dict format."""
def test_process_video_input_dict_local_file_rejected(self, test_video_path):
"""Test that local file paths in dict format are rejected."""
from vllm_mlx.models.mllm import process_video_input

# OpenAI format
result = process_video_input({"url": test_video_path})
assert Path(result).exists()
with pytest.raises(ValueError):
process_video_input({"url": test_video_path})

def test_process_video_input_empty_raises(self):
"""Test that empty input raises error."""
Expand Down
2 changes: 1 addition & 1 deletion vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class BatchedEngine(BaseEngine):
def __init__(
self,
model_name: str,
trust_remote_code: bool = True,
trust_remote_code: bool = False,
scheduler_config: Any | None = None,
stream_interval: int = 1,
force_mllm: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SimpleEngine(BaseEngine):
def __init__(
self,
model_name: str,
trust_remote_code: bool = True,
trust_remote_code: bool = False,
enable_cache: bool = True,
force_mllm: bool = False,
mtp: bool = False,
Expand Down
2 changes: 0 additions & 2 deletions vllm_mlx/mcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# Default config search paths
CONFIG_SEARCH_PATHS = [
"./mcp.json",
"./mcp.yaml",
"~/.config/vllm-mlx/mcp.json",
"~/.config/vllm-mlx/mcp.yaml",
]
Expand Down
50 changes: 20 additions & 30 deletions vllm_mlx/mcp/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@
re.compile(r"\|\|\s*"), # Command chaining with ||
re.compile(r"`"), # Backtick command substitution
re.compile(r"\$\("), # $() command substitution
re.compile(r"\$\{"), # ${} variable expansion
re.compile(r">\s*"), # Output redirection
re.compile(r"<\s*"), # Input redirection
re.compile(r"\.\./"), # Path traversal
re.compile(r"~"), # Home directory expansion (can be abused)
re.compile(r"[\n\r]"), # Newline command injection
]

# Dangerous argument patterns
Expand All @@ -73,8 +75,13 @@
re.compile(r"\$\{"),
re.compile(r">\s*/"), # Redirect to absolute path
re.compile(r"<\s*/"), # Read from absolute path
re.compile(r"[\n\r]"), # Newline command injection
]

# Interpreter commands that can execute inline code
INTERPRETER_COMMANDS = {"python", "python3", "node", "ruby", "perl"}
DANGEROUS_INTERPRETER_ARGS = {"-c", "--command", "-e", "--eval", "exec", "eval"}


class MCPSecurityError(Exception):
"""Raised when MCP security validation fails."""
Expand All @@ -93,7 +100,6 @@ class MCPCommandValidator:
def __init__(
self,
allowed_commands: Optional[Set[str]] = None,
allow_unsafe: bool = False,
custom_whitelist: Optional[Set[str]] = None,
check_path_exists: bool = True,
):
Expand All @@ -102,25 +108,15 @@ def __init__(

Args:
allowed_commands: Set of allowed command names. If None, uses default whitelist.
allow_unsafe: If True, allows any command (for development only).
WARNING: This disables security checks!
custom_whitelist: Additional commands to allow beyond the default whitelist.
check_path_exists: If True, verify command exists in PATH. Set to False for testing.
"""
self.allow_unsafe = allow_unsafe
self.allowed_commands = allowed_commands or ALLOWED_COMMANDS.copy()
self.check_path_exists = check_path_exists

if custom_whitelist:
self.allowed_commands.update(custom_whitelist)

if allow_unsafe:
logger.warning(
"MCP SECURITY WARNING: Unsafe mode enabled. "
"All commands will be allowed without validation. "
"This should NEVER be used in production!"
)

def validate_command(self, command: str, server_name: str) -> None:
"""
Validate that a command is safe to execute.
Expand All @@ -132,13 +128,6 @@ def validate_command(self, command: str, server_name: str) -> None:
Raises:
MCPSecurityError: If the command is not allowed
"""
if self.allow_unsafe:
logger.warning(
f"MCP security bypassed for server '{server_name}': "
f"allowing command '{command}' (unsafe mode)"
)
return

# Check for dangerous patterns in command
for pattern in DANGEROUS_PATTERNS:
if pattern.search(command):
Expand Down Expand Up @@ -193,9 +182,6 @@ def validate_args(self, args: List[str], server_name: str) -> None:
Raises:
MCPSecurityError: If any argument contains dangerous patterns
"""
if self.allow_unsafe:
return

for i, arg in enumerate(args):
for pattern in DANGEROUS_ARG_PATTERNS:
if pattern.search(arg):
Expand All @@ -204,6 +190,14 @@ def validate_args(self, args: List[str], server_name: str) -> None:
f"pattern: '{arg}'. Potential command injection blocked."
)

# Block dangerous interpreter flags (e.g. python -c, node -e)
for arg in args:
if arg in DANGEROUS_INTERPRETER_ARGS:
raise MCPSecurityError(
f"MCP server '{server_name}': Argument '{arg}' blocked — "
f"inline code execution not allowed"
)

logger.debug(
f"MCP server '{server_name}': {len(args)} arguments validated successfully"
)
Expand All @@ -219,7 +213,7 @@ def validate_env(self, env: Optional[Dict[str, str]], server_name: str) -> None:
Raises:
MCPSecurityError: If any env var contains dangerous patterns
"""
if self.allow_unsafe or not env:
if not env:
return

# Dangerous environment variables that could affect execution
Expand Down Expand Up @@ -264,9 +258,6 @@ def validate_url(self, url: str, server_name: str) -> None:
Raises:
MCPSecurityError: If the URL is not safe
"""
if self.allow_unsafe:
return

# Must be http or https
if not url.startswith(("http://", "https://")):
raise MCPSecurityError(
Expand Down Expand Up @@ -497,15 +488,14 @@ def _is_blocked(self, tool_name: str, full_name: str) -> bool:
)

def _check_high_risk_tool(self, tool_name: str) -> None:
"""Check if tool matches high-risk patterns."""
"""Check if tool matches high-risk patterns and block it."""
tool_lower = tool_name.lower()
for pattern in HIGH_RISK_TOOL_PATTERNS:
if pattern in tool_lower:
logger.warning(
f"High-risk tool detected: '{tool_name}' matches pattern '{pattern}'. "
f"Ensure this tool is from a trusted MCP server."
raise MCPSecurityError(
f"High-risk tool blocked: '{tool_name}' matches pattern '{pattern}'. "
f"Add to allowed_tools explicitly to permit."
)
break

def _validate_arguments(self, tool_name: str, arguments: Dict[str, Any]) -> None:
"""Validate tool arguments for dangerous patterns."""
Expand Down
14 changes: 2 additions & 12 deletions vllm_mlx/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ class MCPServerConfig:
enabled: bool = True
timeout: float = 30.0

# Security options
skip_security_validation: bool = False # WARNING: Only for development!
# Security options removed: skip_security_validation was a bypass risk

def __post_init__(self):
"""Validate configuration."""
Expand All @@ -67,16 +66,7 @@ def __post_init__(self):

def _validate_security(self) -> None:
"""Validate security of the configuration."""
from .security import validate_mcp_server_config, MCPSecurityError

if self.skip_security_validation:
import logging

logging.getLogger(__name__).warning(
f"MCP server '{self.name}': Security validation SKIPPED. "
f"This is dangerous and should only be used in development!"
)
return
from .security import MCPSecurityError, validate_mcp_server_config

try:
validate_mcp_server_config(
Expand Down
Loading
Loading