Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/models/fixtures/ministral_3b_chat.json

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions tests/models/multimodal/generation/test_pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MINISTRAL_3B_ID = "mistralai/Ministral-3-3B-Instruct-2512"

MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]

Expand Down Expand Up @@ -116,6 +117,7 @@ def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt:
FIXTURE_LOGPROBS_CHAT = {
PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
MINISTRAL_3B_ID: FIXTURES_PATH / "ministral_3b_chat.json",
}

OutputsLogprobs = list[tuple[list[int], str, SampleLogprobs | None]]
Expand Down Expand Up @@ -209,3 +211,41 @@ def test_chat(
name_0="h100_ref",
name_1="output",
)


@large_gpu_test(min_gb=16)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat_consolidated(vllm_runner, dtype: str, local_asset_server) -> None:
Comment thread
juliendenize marked this conversation as resolved.
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(
FIXTURE_LOGPROBS_CHAT[MINISTRAL_3B_ID]
)
with vllm_runner(
MINISTRAL_3B_ID,
dtype=dtype,
tokenizer_mode="mistral",
load_format="mistral",
config_format="mistral",
max_model_len=8192,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model:
outputs = []
urls_all = [local_asset_server.url_for(u) for u in IMG_URLS]
msgs = [
_create_msg_format(urls_all[:1]),
_create_msg_format(urls_all[:2]),
_create_msg_format(urls_all),
]
for msg in msgs:
output = vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS)
outputs.extend(output)

logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
for i in range(len(logprobs)):
assert logprobs[i][-1] is None
logprobs[i] = logprobs[i][:-1]
check_logprobs_close(
outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output",
)
19 changes: 19 additions & 0 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,27 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
_vision_encoder_stacked_params = [
# (param_name, shard_name, shard_id)
# HF format
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
# Mistral native (consolidated) format

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wo and w2 parameters are handled via _vision_encoder_name_remap rather than through _vision_encoder_stacked_params. Since they're not sharded across TP ranks like qkv/w1/w3, they don't appear in the stacked params list. Is there a reason they couldn't be added to the stacked params list with their shard_id, or is the remap approach more robust to variations in how these keys appear across different checkpoint formats?

(".qkv_proj", ".wq", "q"),
(".qkv_proj", ".wk", "k"),
(".qkv_proj", ".wv", "v"),
(".gate_up_proj", ".w1", 0),
(".gate_up_proj", ".w3", 1),
Comment thread
juliendenize marked this conversation as resolved.
]

# Remap Mistral native names to HF-style names
# used by the vLLM vision encoder modules.
_vision_encoder_name_remap = {
".wo.": ".o_proj.",
".w2.": ".down_proj.",
}

def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith(("vision_encoder", "vision_tower"))

Expand Down Expand Up @@ -518,6 +532,11 @@ def llm_weights_generator():
weight_loader(param, w, shard_id)
break
else:
for old, new in _vision_encoder_name_remap.items():
if old in trimmed_name:
trimmed_name = trimmed_name.replace(old, new)
break

param = vision_encoder_dict.get(trimmed_name)
if param is not None:
weight_loader = getattr(
Expand Down
Loading