diff --git a/examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py b/examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py index f0f1fddb73fa..db634d8be760 100644 --- a/examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py +++ b/examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py @@ -20,13 +20,15 @@ def main(): torch.set_default_dtype(torch.float16) image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 - img_prompt = dict( + img_data = dict( data=image_url, data_format="url", image_format="tiff", out_data_format="b64_json", ) + prompt = dict(data=img_data) + llm = LLM( model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", skip_tokenizer_init=True, @@ -41,7 +43,7 @@ def main(): enable_mm_embeds=True, ) - pooler_output = llm.encode(img_prompt, pooling_task="plugin") + pooler_output = llm.encode(prompt, pooling_task="plugin") output = pooler_output[0].outputs print(output) diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 4d0e7be0ef70..04cb19499296 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -120,13 +120,15 @@ async def test_prithvi_mae_plugin_online( def test_prithvi_mae_plugin_offline( vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str ): - img_prompt = dict( + img_data = dict( data=image_url, data_format="url", image_format="tiff", out_data_format="b64_json", ) + prompt = dict(data=img_data) + with vllm_runner( model_name, runner="pooling", @@ -139,7 +141,7 @@ def test_prithvi_mae_plugin_offline( io_processor_plugin=plugin, default_torch_num_threads=1, ) as llm_runner: - pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin") + pooler_output = llm_runner.get_llm().encode(prompt, pooling_task="plugin") output = pooler_output[0].outputs # verify the output is formatted as expected for this plugin diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d27fa7074c35..91b39f798858 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1135,7 +1135,15 @@ def encode( ) # Validate the request data is valid for the loaded plugin - validated_prompt = self.io_processor.parse_data(prompts) + prompt_data = prompts.get("data") + if prompt_data is None: + raise ValueError( + "The 'data' field of the prompt is expected to contain " + "the prompt data and it cannot be None. " + "Refer to the documentation of the IOProcessor " + "in use for more details." + ) + validated_prompt = self.io_processor.parse_data(prompt_data) # obtain the actual model prompts from the pre-processor prompts = self.io_processor.pre_process(prompt=validated_prompt)