1 change: 1 addition & 0 deletions tests/models/fixtures/ministral_3b_chat.json

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions tests/models/multimodal/generation/test_pixtral.py
@@ -25,6 +25,7 @@

PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MINISTRAL_3B_ID = "mistralai/Ministral-3-3B-Instruct-2512"

MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]

@@ -116,6 +117,7 @@ def _create_engine_inputs_hf(urls: list[str]) -> TextPrompt:
FIXTURE_LOGPROBS_CHAT = {
    PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
    MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
    MINISTRAL_3B_ID: FIXTURES_PATH / "ministral_3b_chat.json",
}

OutputsLogprobs = list[tuple[list[int], str, SampleLogprobs | None]]
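For orientation, a tiny illustrative value of this alias (invented for this note; the fixture contents are not rendered above): each entry pairs a generation's token ids and decoded text with its sample logprobs, or None when no logprobs were recorded for that generation.

example: OutputsLogprobs = [
    ([3, 1043, 292], "A photo", None),
]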
@@ -209,3 +211,41 @@ def test_chat(
name_0="h100_ref",
name_1="output",
)


@large_gpu_test(min_gb=16)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat_consolidated(vllm_runner, dtype: str, local_asset_server) -> None:
    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(
        FIXTURE_LOGPROBS_CHAT[MINISTRAL_3B_ID]
    )
    with vllm_runner(
        MINISTRAL_3B_ID,
        dtype=dtype,
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
        max_model_len=8192,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as vllm_model:
        outputs = []
        urls_all = [local_asset_server.url_for(u) for u in IMG_URLS]
        msgs = [
            _create_msg_format(urls_all[:1]),
            _create_msg_format(urls_all[:2]),
            _create_msg_format(urls_all),
        ]
        for msg in msgs:
            output = vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS)
            outputs.extend(output)

    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
    for i in range(len(logprobs)):
        assert logprobs[i][-1] is None
        logprobs[i] = logprobs[i][:-1]
    check_logprobs_close(
        outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
        outputs_1_lst=logprobs,
        name_0="h100_ref",
        name_1="output",
    )
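For reference, the consolidated (Mistral-native) load path this test exercises can also be driven directly through vLLM's public LLM API. A minimal text-only sketch, assuming enough GPU memory for the 3B model and access to the checkpoint; the prompt is arbitrary:

from vllm import LLM, SamplingParams

# Consolidated checkpoints use Mistral-native config and tokenizer
# files, so all three "mistral" format flags are set, mirroring the
# vllm_runner arguments in the test above.
llm = LLM(
    model="mistralai/Ministral-3-3B-Instruct-2512",
    tokenizer_mode="mistral",
    load_format="mistral",
    config_format="mistral",
    max_model_len=8192,
)
messages = [{"role": "user", "content": "Say hello."}]
outputs = llm.chat(messages, sampling_params=SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)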
19 changes: 19 additions & 0 deletions vllm/model_executor/models/pixtral.py
@@ -458,13 +458,27 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        _vision_encoder_stacked_params = [
            # (param_name, shard_name, shard_id)
            # HF format
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            # Mistral native (consolidated) format
Review comment: The wo and w2 parameters are handled via _vision_encoder_name_remap rather than through _vision_encoder_stacked_params. Since they're not sharded across TP ranks like qkv/w1/w3, they don't appear in the stacked params list. Is there a reason they couldn't be added to the stacked params list with their shard_id, or is the remap approach more robust to variations in how these keys appear across different checkpoint formats? (A sketch contrasting the two mechanisms follows this hunk.)

(".qkv_proj", ".wq", "q"),
(".qkv_proj", ".wk", "k"),
(".qkv_proj", ".wv", "v"),
(".gate_up_proj", ".w1", 0),
(".gate_up_proj", ".w3", 1),
        ]

        # Remap Mistral native names to HF-style names
        # used by the vLLM vision encoder modules.
        _vision_encoder_name_remap = {
            ".wo.": ".o_proj.",
            ".w2.": ".down_proj.",
        }

        def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith(("vision_encoder", "vision_tower"))
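The review question above turns on the difference between the two mechanisms: a stacked entry fuses several checkpoint tensors into one vLLM parameter and therefore needs a shard_id, whereas wo and w2 map one-to-one onto o_proj and down_proj, so a plain rename suffices. A self-contained sketch of the rename step (the resolve helper is hypothetical, not vLLM code):

remap = {
    ".wo.": ".o_proj.",
    ".w2.": ".down_proj.",
}

def resolve(name: str) -> str:
    # A rename is a pure string substitution: each checkpoint tensor
    # maps to exactly one vLLM parameter, so no shard bookkeeping is
    # needed, unlike the qkv/gate_up fusion above.
    for old, new in remap.items():
        if old in name:
            return name.replace(old, new)
    return name

assert resolve("layers.0.attention.wo.weight") == "layers.0.attention.o_proj.weight"
assert resolve("layers.0.feed_forward.w2.weight") == "layers.0.feed_forward.down_proj.weight"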

@@ -518,6 +532,11 @@ def llm_weights_generator():
                            weight_loader(param, w, shard_id)
                            break
                    else:
                        for old, new in _vision_encoder_name_remap.items():
                            if old in trimmed_name:
                                trimmed_name = trimmed_name.replace(old, new)
                                break

                        param = vision_encoder_dict.get(trimmed_name)
                        if param is not None:
                            weight_loader = getattr(
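The hunk above is truncated by the diff viewer, but the resolution order it establishes can be restated in isolation: try the stacked-shard entries first, fall back to the rename map, then look the name up directly. A condensed sketch under those assumptions (resolve_param is a hypothetical helper; the param and shard names are taken from the diff):

STACKED = [
    (".qkv_proj", ".wq", "q"),
    (".qkv_proj", ".wk", "k"),
    (".qkv_proj", ".wv", "v"),
    (".gate_up_proj", ".w1", 0),
    (".gate_up_proj", ".w3", 1),
]
REMAP = {".wo.": ".o_proj.", ".w2.": ".down_proj."}

def resolve_param(name: str) -> tuple[str, str | int | None]:
    # Map a consolidated-checkpoint key to (vLLM param name, shard_id).
    for param_name, shard_name, shard_id in STACKED:
        if shard_name in name:
            # Fused parameter: shard_id tells the weight loader which
            # slice of the stacked tensor this checkpoint tensor fills.
            return name.replace(shard_name, param_name), shard_id
    for old, new in REMAP.items():
        if old in name:
            # One-to-one rename; no shard bookkeeping required.
            return name.replace(old, new), None
    return name, None

assert resolve_param("layers.0.attention.wk.weight") == (
    "layers.0.attention.qkv_proj.weight",
    "k",
)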