Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ def main():
torch.set_default_dtype(torch.float16)
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501

img_prompt = dict(
img_data = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)

prompt = dict(data=img_data)

llm = LLM(
model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
skip_tokenizer_init=True,
Expand All @@ -41,7 +43,7 @@ def main():
enable_mm_embeds=True,
)

pooler_output = llm.encode(img_prompt, pooling_task="plugin")
pooler_output = llm.encode(prompt, pooling_task="plugin")
output = pooler_output[0].outputs

print(output)
Expand Down
6 changes: 4 additions & 2 deletions tests/plugins_tests/test_io_processor_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ async def test_prithvi_mae_plugin_online(
def test_prithvi_mae_plugin_offline(
vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
):
img_prompt = dict(
img_data = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)

prompt = dict(data=img_data)

with vllm_runner(
model_name,
runner="pooling",
Expand All @@ -139,7 +141,7 @@ def test_prithvi_mae_plugin_offline(
io_processor_plugin=plugin,
default_torch_num_threads=1,
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
pooler_output = llm_runner.get_llm().encode(prompt, pooling_task="plugin")
output = pooler_output[0].outputs

# verify the output is formatted as expected for this plugin
Expand Down
10 changes: 9 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,15 @@ def encode(
)

# Validate the request data is valid for the loaded plugin
validated_prompt = self.io_processor.parse_data(prompts)
prompt_data = prompts.get("data")
if prompt_data is None:
raise ValueError(
"The 'data' field of the prompt is expected to contain "
"the prompt data and it cannot be None. "
"Refer to the documentation of the IOProcessor "
"in use for more details."
)
validated_prompt = self.io_processor.parse_data(prompt_data)

# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
Expand Down