Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/features/multimodal_inputs.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.

!!! tip
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.

## Offline Inference

To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
Expand Down
6 changes: 6 additions & 0 deletions docs/usage/security.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ Key points from the PyTorch security guide:
- Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components

### 4. **Restrict Domains Access for Media URLs:**

Restrict domains that vLLM can access for media URLs by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)

## Security and Firewalls: Protecting Exposed vLLM Systems

While vLLM is designed to allow unsafe network services to be isolated to
Expand Down
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_lora_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class MockModelConfig:
logits_processor_pattern: Optional[str] = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False
Expand Down
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class MockModelConfig:
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
Expand Down
33 changes: 32 additions & 1 deletion tests/multimodal/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
raw_image_url: str, suffix: str):
connector = MediaConnector()
connector = MediaConnector(
# Domain restriction should not apply to data URLs.
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
url_image = url_images[raw_image_url]

try:
Expand Down Expand Up @@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
modality_idxs = argsort_mm_positions(mm_positions)

assert modality_idxs == expected_modality_idxs


@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}},
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])

video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async

disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
with pytest.raises(ValueError):
_, _ = connector.fetch_video(disallowed_url)

with pytest.raises(ValueError):
_, _ = await connector.fetch_video_async(disallowed_url)
3 changes: 3 additions & 0 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ class ModelConfig:
"""Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only
be enabled in trusted environments."""
allowed_media_domains: Optional[list[str]] = None
"""If set, only media URLs that belong to this domain can be used for
multi-modal inputs. """
revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version."""
Expand Down
2 changes: 2 additions & 0 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ def __post_init__(self):
trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
allowed_media_domains=self.target_model_config.
allowed_media_domains,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
Expand Down
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ class EngineArgs:
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
allowed_media_domains: Optional[
list[str]] = ModelConfig.allowed_media_domains
download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format
Expand Down Expand Up @@ -531,6 +533,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**model_kwargs["hf_config_path"])
model_group.add_argument("--allowed-local-media-path",
**model_kwargs["allowed_local_media_path"])
model_group.add_argument("--allowed-media-domains",
**model_kwargs["allowed_media_domains"])
model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision",
**model_kwargs["code_revision"])
Expand Down Expand Up @@ -997,6 +1001,7 @@ def create_model_config(self) -> ModelConfig:
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
allowed_media_domains=self.allowed_media_domains,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,10 @@ def model_cls(self) -> type[SupportsMultiModal]:
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path

@property
def allowed_media_domains(self):
return self._model_config.allowed_media_domains

@property
def mm_registry(self):
return MULTIMODAL_REGISTRY
Expand Down Expand Up @@ -832,6 +836,7 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
)

def parse_image(
Expand Down Expand Up @@ -916,6 +921,7 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
)

def parse_image(
Expand Down
4 changes: 4 additions & 0 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class LLM:
or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
allowed_media_domains: If set, only media URLs that belong to this
domain can be used for multi-modal inputs.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
Expand Down Expand Up @@ -169,6 +171,7 @@ def __init__(
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
tensor_parallel_size: int = 1,
dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None,
Expand Down Expand Up @@ -264,6 +267,7 @@ def __init__(
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
allowed_media_domains=allowed_media_domains,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
Expand Down
16 changes: 16 additions & 0 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
connection: HTTPConnection = global_http_connection,
*,
allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -82,6 +83,9 @@ def __init__(
allowed_local_media_path_ = None

self.allowed_local_media_path = allowed_local_media_path_
if allowed_media_domains is None:
allowed_media_domains = []
self.allowed_media_domains = allowed_media_domains
Comment on lines +86 to +88
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For performance, it's better to convert allowed_media_domains to a set during initialization. Checking for an item's existence is O(1) on average for a set, while it's O(n) for a list. This can be significant if the list of allowed domains is large.

Suggested change
if allowed_media_domains is None:
allowed_media_domains = []
self.allowed_media_domains = allowed_media_domains
if allowed_media_domains is None:
self.allowed_media_domains = set()
else:
self.allowed_media_domains = set(allowed_media_domains)


def _load_data_url(
self,
Expand Down Expand Up @@ -115,6 +119,14 @@ def _load_file_url(

return media_io.load_file(filepath)

def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
if self.allowed_media_domains and url_spec.hostname not in \
self.allowed_media_domains:
raise ValueError(
f"The URL must be from one of the allowed domains: "
f"{self.allowed_media_domains}. Input URL domain: "
f"{url_spec.hostname}")

def load_from_url(
self,
url: str,
Expand All @@ -125,6 +137,8 @@ def load_from_url(
url_spec = urlparse(url)

if url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
Comment on lines 127 to +140
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation checks the domain of the initial URL, but it does not prevent SSRF vulnerabilities that arise from HTTP redirects. If an allowed URL redirects to a URL on a disallowed domain (including internal network addresses), connection.get_bytes might still fetch it if it follows redirects by default, which is mentioned as the root cause in the security advisory GHSA-3f6c-7fw2-ppm4j.

To properly mitigate this, you should either disable redirects or verify the domain of the final URL after all redirects have been followed. Disabling redirects might be the safest option if they are not a required feature. If they are required, the HTTPConnection class should be modified to not follow redirects automatically, and instead, redirects should be handled manually within MediaConnector to ensure every URL in the redirect chain is validated against the allowed domains.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@huachenheli @DarkLight1337 this seems like a good point. I wouldn't block merging the current change over this, but it seems worth a follow-up change.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. This PR should be fine for immediate needs. We already have media_io_kwargs that we can use to control MediaConnector behavior, so we just need to pass that to the HttpConnection to disallow redirects.


connection = self.connection
data = connection.get_bytes(url, timeout=fetch_timeout)

Expand All @@ -150,6 +164,8 @@ async def load_from_url_async(
loop = asyncio.get_running_loop()

if url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)

connection = self.connection
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
future = loop.run_in_executor(global_thread_pool,
Expand Down