Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions examples/offline_inference/vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,40 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")


def run_minimax(questions: list[str], modality: str) -> ModelRequestData:
    """Build the engine arguments and chat prompts for MiniMax-VL-01.

    Every question is wrapped in a single-turn user conversation made of one
    image placeholder followed by the question text. The conversations are
    then passed through the tokenizer's chat template with
    ``tokenize=False`` and a generation prompt appended.

    Args:
        questions: Question texts, one prompt is produced per entry.
        modality: Input modality; only ``"image"`` is accepted.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments and the
        templated prompts.
    """
    assert modality == "image"

    checkpoint = "MiniMaxAI/MiniMax-VL-01"

    # The configuration below has been confirmed to launch on a single L40 GPU.
    # NOTE(review): statement carried over unchanged from the original
    # author; confirm it still holds for this checkpoint.
    engine_args = EngineArgs(
        model=checkpoint,
        max_model_len=14336,
        max_num_seqs=2,
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
    )

    chat_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # One conversation (a list holding a single user turn) per question.
    conversations = []
    for question in questions:
        user_turn = {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": question}],
        }
        conversations.append([user_turn])

    rendered_prompts = chat_tokenizer.apply_chat_template(
        conversations,
        add_generation_prompt=True,
        tokenize=False,
    )

    return ModelRequestData(engine_args=engine_args, prompts=rendered_prompts)


# Mistral-3 HF-format
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
Expand Down Expand Up @@ -1412,6 +1446,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"mantis": run_mantis,
"minicpmo": run_minicpmo,
"minicpmv": run_minicpmv,
"minimax": run_minimax,
"mistral3": run_mistral3,
"mllama": run_mllama,
"molmo": run_molmo,
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/models/minimax_vl_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class MiniMaxVL01ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Shape:
`(batch_size * num_images * num_patches, num_channels, height, width)`

Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
Expand Down Expand Up @@ -312,7 +313,7 @@ def _parse_and_validate_image_input(
return MiniMaxVL01ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
flatten_bn(flatten_bn(pixel_values), concat=True)),
)

if image_embeds is not None:
Expand Down