Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tests/models/multimodal/test_nano_nemotron_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ def load_weights(self, weights):
self.loaded_weights = list(weights)


class _FakeTensor:
"""Sentinel stand-in for torch.Tensor in load_weights tests. Supports the
.detach().clone() chain used by load_weights for buffered mm weights;
both methods return self so identity (and the existing equality
assertions) are preserved through cloning."""

def detach(self):
return self

def clone(self):
return self


def test_nano_nemotron_vl_skips_multimodal_weights_in_text_only_mode():
model = object.__new__(NemotronH_Nano_VL_V2)
language_model = _LanguageModel()
Expand Down Expand Up @@ -86,7 +99,7 @@ def test_nano_nemotron_vl_loads_vision_weights_without_sound_encoder():
object.__setattr__(model, "sound_encoder", None)

language_weight = object()
vision_weight = object()
vision_weight = _FakeTensor()
model.load_weights(
[
("language_model.layers.0.weight", language_weight),
Expand Down
68 changes: 41 additions & 27 deletions vllm/model_executor/models/nano_nemotron_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,37 +1518,51 @@ def is_vision_weights(name: str) -> bool:
def is_sound_weights(name: str) -> bool:
    """Return True when the weight key *name* begins with ``"sound"``."""
    sound_prefix = "sound"
    return name.startswith(sound_prefix)

# Separate weights by component
llm_weights = []
vision_weights = []
sound_weights = []

for name, w in weights:
if is_llm(name):
# Strip 'language_model.' prefix for LLM weights
llm_weights.append((".".join(name.split(".")[1:]), w))
elif is_adapter_weights((name, w)):
if not load_multimodal_weights:
continue
# LLM weights (the bulk of the model) are streamed lazily through a
# generator so each tensor is copied into its parameter before the
# iterator advances, avoiding stale-reference corruption with
# reusable-buffer streamers. The smaller mm components (mlp1, vision,
# sound) are detach+cloned on append so they are independent of any
# reusable buffer the streamer may use, then loaded after the LLM.
adapter_weights: list[tuple[str, torch.Tensor]] = []
vision_weights: list[tuple[str, torch.Tensor]] = []
sound_weights: list[tuple[str, torch.Tensor]] = []

def llm_weights_gen():
    """Yield language-model weights lazily while buffering the other components.

    Iterates ``weights`` once. Entries matching ``is_llm`` are yielded
    immediately with their first dotted name component removed. Adapter,
    vision and sound entries are never yielded: when
    ``load_multimodal_weights`` is true they are appended, as
    ``.detach().clone()`` copies, to ``adapter_weights``, ``vision_weights``
    and ``sound_weights`` respectively; otherwise they are dropped.

    Because the buffering happens as a side effect of iteration, the three
    lists are complete only once this generator has been fully consumed.
    """
    vision_prefix_len = len("vision_model.")
    for name, w in weights:
        if is_llm(name):
            # Drop the leading 'language_model.' component for LLM weights.
            yield name.partition(".")[2], w
            continue

        if is_adapter_weights((name, w)):
            if load_multimodal_weights:
                # Same first-component trim as for the LLM weights.
                adapter_key = name.partition(".")[2]
                adapter_weights.append((adapter_key, w.detach().clone()))
        elif is_vision_weights(name):
            if load_multimodal_weights:
                # Convert: vision_model.radio_model.* → radio_model.*
                vision_key = name[vision_prefix_len:]
                vision_weights.append((vision_key, w.detach().clone()))
        elif is_sound_weights(name):
            if load_multimodal_weights:
                assert self.sound_encoder is not None
                # Sound weights keep their full name.
                sound_weights.append((name, w.detach().clone()))
Comment on lines +1531 to +1551

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of llm_weights_gen relies on the assumption that self.language_model.load_weights will fully consume the generator. If for any reason the language model's weight loader stops early (e.g., it only looks for a subset of weights), the multimodal weights (adapter_weights, vision_weights, sound_weights) will be only partially collected, leading to incomplete weight loading for those components. While the current AutoWeightsLoader in vLLM does consume the full iterable, this creates a fragile temporal coupling between the LLM loading phase and the multimodal collection phase. A more robust approach would be to ensure the generator is fully exhausted before proceeding to load multimodal components, or to explicitly document this dependency.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix

After self.language_model.load_weights(...) returns, the generator is now explicitly drained:

# Fully drain the generator so every mm tensor is buffered, even if
# the LLM loader stops iterating early.
llm_weights_iter = llm_weights_gen()
self.language_model.load_weights(llm_weights_iter)                                                                                                                                                                                                                               
for _ in llm_weights_iter:
    pass                                                                                                                                                                                                                                                                         

Holding the generator in a named variable lets us iterate the remainder ourselves. Generators are stateful and resume from where the previous consumer left off, so this loop is a no-op when the LLM loader already consumed everything, and a safety net when it did not.

Why this resolves the concern

  • Removes the implicit dependency. Completeness of the multimodal (mm) buffers no longer relies on the LLM loader's iteration behavior. Whether it drains the iterable, stops after N items, or never iterates at all, the mm branches see every input tensor.
  • No new pre-allocation or buffering. LLM weights are still streamed lazily — the drain loop reads the same generator the LLM loader was reading, so there's no extra accumulation step and no change to peak memory.
  • Order of operations is preserved. mm components are still loaded after the LLM, on the same control-flow path. The only added work is finishing the iterator, which by construction yields at most the input weights that have not yet been processed.


# Fully drain the generator so every mm tensor is buffered, even if
# the LLM loader stops iterating early.
llm_weights_iter = llm_weights_gen()
self.language_model.load_weights(llm_weights_iter)
for _ in llm_weights_iter:
pass

if load_multimodal_weights:
for trimmed_name, w in adapter_weights:
# Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_weights(name):
if not load_multimodal_weights:
continue
# Convert: vision_model.radio_model.* → radio_model.*
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
vision_weights.append((hf_key, w))
elif is_sound_weights(name):
if not load_multimodal_weights:
continue
assert self.sound_encoder is not None
sound_weights.append((name, w))

self.language_model.load_weights(llm_weights)
if load_multimodal_weights:
self.vision_model.load_weights(vision_weights)
if self.sound_encoder is not None and len(sound_weights) > 0:
self.sound_encoder.load_weights(sound_weights)
Expand Down
Loading