diff --git a/tests/test_mcp_security.py b/tests/test_mcp_security.py index 4db122e89..f84195ddb 100644 --- a/tests/test_mcp_security.py +++ b/tests/test_mcp_security.py @@ -232,23 +232,23 @@ def test_injection_in_url_blocked(self): assert "dangerous pattern" in str(exc_info.value) -class TestUnsafeMode: - """Tests for unsafe mode (development only).""" +class TestUnsafeModeRemoved: + """Tests that unsafe mode has been removed (security hardening).""" - def test_unsafe_mode_allows_any_command(self): - """Test that unsafe mode allows any command (with warning).""" - validator = MCPCommandValidator(allow_unsafe=True) + def test_allow_unsafe_parameter_rejected(self): + """Test that allow_unsafe parameter is no longer accepted.""" + with pytest.raises(TypeError): + MCPCommandValidator(allow_unsafe=True) - # These should not raise even though they're dangerous - validator.validate_command("bash", "test") - validator.validate_command("/bin/sh -c 'dangerous'", "test") + def test_dangerous_commands_always_blocked(self): + """Test that dangerous commands are always blocked (no bypass).""" + validator = MCPCommandValidator(check_path_exists=False) - def test_unsafe_mode_allows_any_args(self): - """Test that unsafe mode allows any arguments.""" - validator = MCPCommandValidator(allow_unsafe=True) + with pytest.raises(MCPSecurityError): + validator.validate_command("bash", "test") - # These should not raise - validator.validate_args(["; rm -rf /"], "test") + with pytest.raises(MCPSecurityError): + validator.validate_args(["; rm -rf /"], "test") class TestCustomWhitelist: @@ -328,17 +328,16 @@ def test_valid_sse_config(self): ) assert config.url == "https://api.example.com/mcp" - def test_skip_security_validation(self): - """Test that skip_security_validation allows any command (with warning).""" - # This should not raise even with dangerous command - config = MCPServerConfig( - name="unsafe-server", - transport=MCPTransport.STDIO, - command="bash", - args=["-c", "echo hello"], - skip_security_validation=True, - ) - assert config.command == "bash" + def test_skip_security_validation_removed(self): + """Test that skip_security_validation field has been removed.""" + with pytest.raises(TypeError): + MCPServerConfig( + name="unsafe-server", + transport=MCPTransport.STDIO, + command="bash", + args=["-c", "echo hello"], + skip_security_validation=True, + ) class TestDefaultWhitelist: @@ -697,38 +696,34 @@ def test_clear_audit_log(self): class TestToolSandboxHighRiskTools: - """Tests for high-risk tool detection.""" - - def test_high_risk_tool_warning(self, caplog): - """Test that high-risk tools trigger warning.""" - import logging + """Tests for high-risk tool blocking.""" + def test_high_risk_tool_blocked(self): + """Test that high-risk tools are blocked with MCPSecurityError.""" sandbox = ToolSandbox() - with caplog.at_level(logging.WARNING): + with pytest.raises(MCPSecurityError) as exc_info: sandbox.validate_tool_execution( tool_name="execute_command", server_name="test", arguments={"cmd": "ls"}, ) - assert "High-risk tool detected" in caplog.text - assert "execute" in caplog.text - - def test_high_risk_shell_tool(self, caplog): - """Test that shell tools trigger warning.""" - import logging + assert "High-risk tool blocked" in str(exc_info.value) + assert "execute" in str(exc_info.value) + def test_high_risk_shell_tool_blocked(self): + """Test that shell tools are blocked with MCPSecurityError.""" sandbox = ToolSandbox() - with caplog.at_level(logging.WARNING): + with pytest.raises(MCPSecurityError) as exc_info: sandbox.validate_tool_execution( tool_name="run_shell", server_name="test", arguments={}, ) - assert "High-risk tool detected" in caplog.text + assert "High-risk tool blocked" in str(exc_info.value) class TestCustomBlockedPatterns: diff --git a/tests/test_mllm.py b/tests/test_mllm.py index ea9ee6593..a41fc1839 100644 --- a/tests/test_mllm.py +++ b/tests/test_mllm.py @@ -184,39 +184,37 @@ def test_save_frames_to_temp(self, test_video_path): class TestImageProcessing: """Test image processing functions.""" - def test_process_image_input_local_file(self, test_image_path): - """Test processing local image file.""" + def test_process_image_input_local_file_rejected(self, test_image_path): + """Test that local file paths are rejected (path traversal prevention).""" from vllm_mlx.models.mllm import process_image_input - result = process_image_input(test_image_path) - assert result == test_image_path + with pytest.raises(ValueError): + process_image_input(test_image_path) - def test_process_image_input_dict_format(self, test_image_path): - """Test processing image in dict format.""" + def test_process_image_input_dict_local_file_rejected(self, test_image_path): + """Test that local file paths in dict format are rejected.""" from vllm_mlx.models.mllm import process_image_input - # OpenAI format - result = process_image_input({"url": test_image_path}) - assert Path(result).exists() + with pytest.raises(ValueError): + process_image_input({"url": test_image_path}) class TestVideoProcessing: """Test video processing functions.""" - def test_process_video_input_local_file(self, test_video_path): - """Test processing local video file.""" + def test_process_video_input_local_file_rejected(self, test_video_path): + """Test that local file paths are rejected (path traversal prevention).""" from vllm_mlx.models.mllm import process_video_input - result = process_video_input(test_video_path) - assert result == test_video_path + with pytest.raises(ValueError): + process_video_input(test_video_path) - def test_process_video_input_dict_format(self, test_video_path): - """Test processing video in dict format.""" + def test_process_video_input_dict_local_file_rejected(self, test_video_path): + """Test that local file paths in dict format are rejected.""" from vllm_mlx.models.mllm import process_video_input - # OpenAI format - result = process_video_input({"url": test_video_path}) - assert Path(result).exists() + with pytest.raises(ValueError): + process_video_input({"url": test_video_path}) def test_process_video_input_empty_raises(self): """Test that empty input raises error.""" diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ce33e628e..b2b244130 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -133,7 +133,7 @@ class BatchedEngine(BaseEngine): def __init__( self, model_name: str, - trust_remote_code: bool = True, + trust_remote_code: bool = False, scheduler_config: Any | None = None, stream_interval: int = 1, force_mllm: bool = False, diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index e96317ef0..85c408b70 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -52,7 +52,7 @@ class SimpleEngine(BaseEngine): def __init__( self, model_name: str, - trust_remote_code: bool = True, + trust_remote_code: bool = False, enable_cache: bool = True, force_mllm: bool = False, mtp: bool = False, diff --git a/vllm_mlx/mcp/config.py b/vllm_mlx/mcp/config.py index cdf3cc878..12ce75f1b 100644 --- a/vllm_mlx/mcp/config.py +++ b/vllm_mlx/mcp/config.py @@ -15,8 +15,6 @@ # Default config search paths CONFIG_SEARCH_PATHS = [ - "./mcp.json", - "./mcp.yaml", "~/.config/vllm-mlx/mcp.json", "~/.config/vllm-mlx/mcp.yaml", ] diff --git a/vllm_mlx/mcp/security.py b/vllm_mlx/mcp/security.py index 9d98006c8..ecfeac07e 100644 --- a/vllm_mlx/mcp/security.py +++ b/vllm_mlx/mcp/security.py @@ -56,10 +56,12 @@ re.compile(r"\|\|\s*"), # Command chaining with || re.compile(r"`"), # Backtick command substitution re.compile(r"\$\("), # $() command substitution + re.compile(r"\$\{"), # ${} variable expansion re.compile(r">\s*"), # Output redirection re.compile(r"<\s*"), # Input redirection re.compile(r"\.\./"), # Path traversal re.compile(r"~"), # Home directory expansion (can be abused) + re.compile(r"[\n\r]"), # Newline command injection ] # Dangerous argument patterns @@ -73,8 +75,13 @@ re.compile(r"\$\{"), re.compile(r">\s*/"), # Redirect to absolute path re.compile(r"<\s*/"), # Read from absolute path + re.compile(r"[\n\r]"), # Newline command injection ] +# Interpreter commands that can execute inline code +INTERPRETER_COMMANDS = {"python", "python3", "node", "ruby", "perl"} +DANGEROUS_INTERPRETER_ARGS = {"-c", "--command", "-e", "--eval", "exec", "eval"} + class MCPSecurityError(Exception): """Raised when MCP security validation fails.""" @@ -93,7 +100,6 @@ class MCPCommandValidator: def __init__( self, allowed_commands: Optional[Set[str]] = None, - allow_unsafe: bool = False, custom_whitelist: Optional[Set[str]] = None, check_path_exists: bool = True, ): @@ -102,25 +108,15 @@ def __init__( Args: allowed_commands: Set of allowed command names. If None, uses default whitelist. - allow_unsafe: If True, allows any command (for development only). - WARNING: This disables security checks! custom_whitelist: Additional commands to allow beyond the default whitelist. check_path_exists: If True, verify command exists in PATH. Set to False for testing. """ - self.allow_unsafe = allow_unsafe self.allowed_commands = allowed_commands or ALLOWED_COMMANDS.copy() self.check_path_exists = check_path_exists if custom_whitelist: self.allowed_commands.update(custom_whitelist) - if allow_unsafe: - logger.warning( - "MCP SECURITY WARNING: Unsafe mode enabled. " - "All commands will be allowed without validation. " - "This should NEVER be used in production!" - ) - def validate_command(self, command: str, server_name: str) -> None: """ Validate that a command is safe to execute. @@ -132,13 +128,6 @@ def validate_command(self, command: str, server_name: str) -> None: Raises: MCPSecurityError: If the command is not allowed """ - if self.allow_unsafe: - logger.warning( - f"MCP security bypassed for server '{server_name}': " - f"allowing command '{command}' (unsafe mode)" - ) - return - # Check for dangerous patterns in command for pattern in DANGEROUS_PATTERNS: if pattern.search(command): @@ -193,9 +182,6 @@ def validate_args(self, args: List[str], server_name: str) -> None: Raises: MCPSecurityError: If any argument contains dangerous patterns """ - if self.allow_unsafe: - return - for i, arg in enumerate(args): for pattern in DANGEROUS_ARG_PATTERNS: if pattern.search(arg): @@ -204,6 +190,14 @@ def validate_args(self, args: List[str], server_name: str) -> None: f"pattern: '{arg}'. Potential command injection blocked." ) + # Block dangerous interpreter flags (e.g. python -c, node -e) + for arg in args: + if arg in DANGEROUS_INTERPRETER_ARGS: + raise MCPSecurityError( + f"MCP server '{server_name}': Argument '{arg}' blocked — " + f"inline code execution not allowed" + ) + logger.debug( f"MCP server '{server_name}': {len(args)} arguments validated successfully" ) @@ -219,7 +213,7 @@ def validate_env(self, env: Optional[Dict[str, str]], server_name: str) -> None: Raises: MCPSecurityError: If any env var contains dangerous patterns """ - if self.allow_unsafe or not env: + if not env: return # Dangerous environment variables that could affect execution @@ -264,9 +258,6 @@ def validate_url(self, url: str, server_name: str) -> None: Raises: MCPSecurityError: If the URL is not safe """ - if self.allow_unsafe: - return - # Must be http or https if not url.startswith(("http://", "https://")): raise MCPSecurityError( @@ -497,15 +488,14 @@ def _is_blocked(self, tool_name: str, full_name: str) -> bool: ) def _check_high_risk_tool(self, tool_name: str) -> None: - """Check if tool matches high-risk patterns.""" + """Check if tool matches high-risk patterns and block it.""" tool_lower = tool_name.lower() for pattern in HIGH_RISK_TOOL_PATTERNS: if pattern in tool_lower: - logger.warning( - f"High-risk tool detected: '{tool_name}' matches pattern '{pattern}'. " - f"Ensure this tool is from a trusted MCP server." + raise MCPSecurityError( + f"High-risk tool blocked: '{tool_name}' matches pattern '{pattern}'. " + f"Add to allowed_tools explicitly to permit." ) - break def _validate_arguments(self, tool_name: str, arguments: Dict[str, Any]) -> None: """Validate tool arguments for dangerous patterns.""" diff --git a/vllm_mlx/mcp/types.py b/vllm_mlx/mcp/types.py index 4eb370b2f..8160da85b 100644 --- a/vllm_mlx/mcp/types.py +++ b/vllm_mlx/mcp/types.py @@ -43,8 +43,7 @@ class MCPServerConfig: enabled: bool = True timeout: float = 30.0 - # Security options - skip_security_validation: bool = False # WARNING: Only for development! + # Security options removed: skip_security_validation was a bypass risk def __post_init__(self): """Validate configuration.""" @@ -67,16 +66,7 @@ def __post_init__(self): def _validate_security(self) -> None: """Validate security of the configuration.""" - from .security import validate_mcp_server_config, MCPSecurityError - - if self.skip_security_validation: - import logging - - logging.getLogger(__name__).warning( - f"MCP server '{self.name}': Security validation SKIPPED. " - f"This is dangerous and should only be used in development!" - ) - return + from .security import MCPSecurityError, validate_mcp_server_config try: validate_mcp_server_config( diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index 3a9090b1e..ce5083472 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -33,6 +33,30 @@ logger = logging.getLogger(__name__) +def _validate_url_safety(url: str) -> None: + """Block requests to private/internal IPs and cloud metadata endpoints.""" + import ipaddress + import socket + + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + raise ValueError(f"Invalid URL: {url}") + + # Block known dangerous hostnames + if hostname in ("localhost", "localhost.localdomain"): + raise ValueError(f"Blocked request to localhost: {url}") + + # Resolve and check IP + try: + ip = ipaddress.ip_address(socket.gethostbyname(hostname)) + except (socket.gaierror, ValueError): + return # Can't resolve — let requests handle it + + if ip.is_private or ip.is_loopback or ip.is_link_local: + raise ValueError(f"Blocked request to private/internal IP ({ip}): {url}") + + class TempFileManager: """Thread-safe manager for tracking and cleaning up temporary files.""" @@ -196,7 +220,10 @@ def download_image(url: str, timeout: int = 30, max_size: int = MAX_IMAGE_SIZE) Raises: FileSizeExceededError: If image exceeds max_size + ValueError: If URL targets private/internal network """ + _validate_url_safety(url) + headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" } @@ -284,7 +311,10 @@ def download_video(url: str, timeout: int = 120, max_size: int = MAX_VIDEO_SIZE) Raises: FileSizeExceededError: If video exceeds max_size + ValueError: If URL targets private/internal network """ + _validate_url_safety(url) + headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" } @@ -442,10 +472,6 @@ def process_video_input(video: str | dict) -> str: if not video: raise ValueError("Empty video input") - # Check if it's a local file - if Path(video).exists(): - return video - # Check if it's a URL if is_url(video): return download_video(video) @@ -525,10 +551,6 @@ def process_image_input(image: str | dict) -> str: if is_url(image): return download_image(image) - # Check if it's a local file (only for short strings that could be paths) - if len(image) < 4096 and Path(image).exists(): - return image - raise ValueError(f"Cannot process image: {image[:50]}...") @@ -687,7 +709,7 @@ class MLXMultimodalLM: def __init__( self, model_name: str, - trust_remote_code: bool = True, + trust_remote_code: bool = False, enable_cache: bool = True, cache_size: int = 50, ): diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py index 41679c0ba..7c0d5db76 100644 --- a/vllm_mlx/request.py +++ b/vllm_mlx/request.py @@ -48,6 +48,9 @@ def get_finish_reason(status: "RequestStatus") -> Optional[str]: return None +MAX_TOKENS_LIMIT = 131072 # 128K hard upper bound + + @dataclass class SamplingParams: """Sampling parameters for text generation.""" @@ -66,6 +69,8 @@ def __post_init__(self): self.stop = [] if self.stop_token_ids is None: self.stop_token_ids = [] + if self.max_tokens > MAX_TOKENS_LIMIT: + self.max_tokens = MAX_TOKENS_LIMIT @dataclass diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index cf3e66596..4de407176 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -260,6 +260,8 @@ def __init__(self, requests_per_minute: int = 60, enabled: bool = False): self.window_size = 60.0 # 1 minute window self._requests: dict[str, list[float]] = defaultdict(list) self._lock = threading.Lock() + self._last_cleanup = time.time() + self._cleanup_interval = 300 # 5 minutes def is_allowed(self, client_id: str) -> tuple[bool, int]: """ @@ -275,6 +277,17 @@ def is_allowed(self, client_id: str) -> tuple[bool, int]: window_start = current_time - self.window_size with self._lock: + # Periodic cleanup of stale clients + if current_time - self._last_cleanup > self._cleanup_interval: + self._last_cleanup = current_time + stale = [ + k + for k, v in self._requests.items() + if not v or max(v) < window_start + ] + for k in stale: + del self._requests[k] + # Clean old requests outside window self._requests[client_id] = [ t for t in self._requests[client_id] if t > window_start @@ -295,6 +308,10 @@ def is_allowed(self, client_id: str) -> tuple[bool, int]: # Global rate limiter (disabled by default) _rate_limiter = RateLimiter(requests_per_minute=60, enabled=False) +# Audio/TTS limits +MAX_AUDIO_UPLOAD_SIZE = 100 * 1024 * 1024 # 100MB +MAX_TTS_INPUT_LENGTH = 10_000 # characters + async def check_rate_limit(request: Request): """Rate limiting dependency.""" @@ -490,6 +507,7 @@ def load_model( specprefill_threshold: int = 8192, specprefill_keep_pct: float = 0.3, specprefill_draft_model: str = None, + trust_remote_code: bool = False, ): """ Load a model (auto-detects MLLM vs LLM). @@ -507,6 +525,7 @@ def load_model( specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill (default: 8192) specprefill_keep_pct: Fraction of tokens to keep (default: 0.3) specprefill_draft_model: Path to small draft model for SpecPrefill scoring + trust_remote_code: Allow execution of custom code from model repos """ global _engine, _model_name, _model_path, _default_max_tokens, _tool_parser_instance @@ -523,6 +542,7 @@ def load_model( logger.info(f"Loading model with BatchedEngine: {model_name}") _engine = BatchedEngine( model_name=model_name, + trust_remote_code=trust_remote_code, scheduler_config=scheduler_config, stream_interval=stream_interval, force_mllm=force_mllm, @@ -534,6 +554,7 @@ def load_model( logger.info(f"Loading model with SimpleEngine: {model_name}") _engine = SimpleEngine( model_name=model_name, + trust_remote_code=trust_remote_code, force_mllm=force_mllm, mtp=mtp, prefill_step_size=prefill_step_size, @@ -601,7 +622,7 @@ async def health(): } -@app.get("/v1/status") +@app.get("/v1/status", dependencies=[Depends(verify_api_key)]) async def status(): """Real-time status with per-request details for debugging and monitoring.""" if _engine is None: @@ -631,7 +652,7 @@ async def status(): } -@app.get("/v1/cache/stats") +@app.get("/v1/cache/stats", dependencies=[Depends(verify_api_key)]) async def cache_stats(): """Get cache statistics for debugging and monitoring.""" try: @@ -650,7 +671,7 @@ async def cache_stats(): return {"error": "Cache stats not available (mlx_vlm not loaded)"} -@app.delete("/v1/cache") +@app.delete("/v1/cache", dependencies=[Depends(verify_api_key)]) async def clear_cache(): """Clear all caches.""" try: @@ -802,8 +823,8 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: except HTTPException: raise except Exception as e: - logger.error(f"Embedding generation failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) + logger.error(f"Embedding generation failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") # ============================================================================= @@ -860,6 +881,24 @@ async def execute_mcp_tool(request: MCPExecuteRequest) -> MCPExecuteResponse: status_code=503, detail="MCP not configured. Start server with --mcp-config" ) + # Validate through sandbox before execution + from .mcp.security import MCPSecurityError, get_sandbox + + sandbox = get_sandbox() + try: + # Extract server name from tool_name (format: server__tool) + parts = request.tool_name.split("__", 1) + server_name = parts[0] if len(parts) > 1 else "unknown" + tool_name = parts[1] if len(parts) > 1 else request.tool_name + + sandbox.validate_tool_execution( + tool_name=tool_name, + server_name=server_name, + arguments=request.arguments or {}, + ) + except MCPSecurityError as e: + raise HTTPException(status_code=403, detail=str(e)) + result = await _mcp_manager.execute_tool( request.tool_name, request.arguments, @@ -912,16 +951,26 @@ async def create_transcription( "parakeet": "mlx-community/parakeet-tdt-0.6b-v2", "parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3", } - model_name = model_map.get(model, model) + model_name = model_map.get(model) + if model_name is None: + raise HTTPException( + status_code=400, + detail=f"Unknown STT model '{model}'. Available: {list(model_map.keys())}", + ) # Load engine if needed if _stt_engine is None or _stt_engine.model_name != model_name: _stt_engine = STTEngine(model_name) _stt_engine.load() - # Save uploaded file temporarily + # Save uploaded file temporarily (with size limit) + content = await file.read() + if len(content) > MAX_AUDIO_UPLOAD_SIZE: + raise HTTPException( + status_code=413, + detail=f"Audio file too large (max {MAX_AUDIO_UPLOAD_SIZE // 1024 // 1024}MB)", + ) with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - content = await file.read() tmp.write(content) tmp_path = tmp.name @@ -945,8 +994,8 @@ async def create_transcription( detail="mlx-audio not installed. Install with: pip install mlx-audio", ) except Exception as e: - logger.error(f"Transcription failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) + logger.error(f"Transcription failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") @app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)]) @@ -968,6 +1017,13 @@ async def create_speech( """ global _tts_engine + # Validate TTS input length + if len(input) > MAX_TTS_INPUT_LENGTH: + raise HTTPException( + status_code=400, + detail=f"TTS input too long (max {MAX_TTS_INPUT_LENGTH} chars)", + ) + try: from .audio.tts import TTSEngine # Lazy import - optional feature @@ -980,7 +1036,12 @@ async def create_speech( "vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit", "voxcpm": "mlx-community/VoxCPM1.5", } - model_name = model_map.get(model, model) + model_name = model_map.get(model) + if model_name is None: + raise HTTPException( + status_code=400, + detail=f"Unknown TTS model '{model}'. Available: {list(model_map.keys())}", + ) # Load engine if needed if _tts_engine is None or _tts_engine.model_name != model_name: @@ -1001,8 +1062,8 @@ async def create_speech( detail="mlx-audio not installed. Install with: pip install mlx-audio", ) except Exception as e: - logger.error(f"TTS generation failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) + logger.error(f"TTS generation failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") @app.get("/v1/audio/voices", dependencies=[Depends(verify_api_key)]) @@ -1518,7 +1579,10 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: # ============================================================================= -@app.post("/v1/messages") +@app.post( + "/v1/messages", + dependencies=[Depends(verify_api_key), Depends(check_rate_limit)], +) async def create_anthropic_message( request: Request, ): @@ -1645,7 +1709,10 @@ async def create_anthropic_message( ) -@app.post("/v1/messages/count_tokens") +@app.post( + "/v1/messages/count_tokens", + dependencies=[Depends(verify_api_key), Depends(check_rate_limit)], +) async def count_anthropic_tokens(request: Request): """ Count tokens for an Anthropic Messages API request. @@ -2195,8 +2262,8 @@ def main(): parser.add_argument( "--host", type=str, - default="0.0.0.0", - help="Host to bind to", + default="127.0.0.1", + help="Host to bind to (default: 127.0.0.1; use 0.0.0.0 to expose to network)", ) parser.add_argument( "--port", @@ -2276,6 +2343,11 @@ def main(): default=None, help="Default top_p for generation when not specified in request", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Allow execution of custom code from HuggingFace model repos (security risk, use only with trusted models)", + ) args = parser.parse_args() @@ -2296,6 +2368,13 @@ def main(): f"Rate limiting enabled: {args.rate_limit} requests/minute per client" ) + # Warn about network exposure without auth + if args.host == "0.0.0.0" and not args.api_key: + logger.warning( + "Server binding to 0.0.0.0 (all interfaces) without --api-key. " + "This exposes the server to the network without authentication." + ) + # Security summary at startup logger.info("=" * 60) logger.info("SECURITY CONFIGURATION") @@ -2309,6 +2388,7 @@ def main(): else: logger.warning(" Rate limiting: DISABLED - Use --rate-limit to enable") logger.info(f" Request timeout: {args.timeout}s") + logger.info(f" Trust remote code: {args.trust_remote_code}") logger.info("=" * 60) # Set MCP config for lifespan @@ -2333,6 +2413,7 @@ def main(): use_batching=args.continuous_batching, max_tokens=args.max_tokens, force_mllm=args.mllm, + trust_remote_code=args.trust_remote_code, ) # Start server