Stream language-model weights lazily; buffer multimodal weights as clones.

vllm/model_executor/models/nano_nemotron_vl.py
  * "language_model.*" entries are no longer collected in a list; a generator
    yields them (prefix stripped) straight into
    self.language_model.load_weights().
  * Adapter, vision and sound entries are appended as w.detach().clone()
    copies while the generator runs, and are loaded after the language model.
  * The generator is drained after language_model.load_weights() returns, so
    the multimodal lists are complete even if that loader stops early.
  * When load_multimodal_weights is false, multimodal entries are skipped
    without being cloned.

tests/models/multimodal/test_nano_nemotron_vl.py
  * Adds _FakeTensor, whose detach() and clone() return self, and uses it for
    the vision weight in the "without sound encoder" test so the new
    .detach().clone() call chain works on the test double.

NOTE(review): leading indentation in the hunks below was reconstructed from
Python block structure and checked against the hunk line counts; confirm with
"git apply --check" against the target tree before applying.

diff --git a/tests/models/multimodal/test_nano_nemotron_vl.py b/tests/models/multimodal/test_nano_nemotron_vl.py
index 6922af79c08e..aa93ee31168d 100644
--- a/tests/models/multimodal/test_nano_nemotron_vl.py
+++ b/tests/models/multimodal/test_nano_nemotron_vl.py
@@ -53,6 +53,19 @@ def load_weights(self, weights):
         self.loaded_weights = list(weights)
 
 
+class _FakeTensor:
+    """Sentinel stand-in for torch.Tensor in load_weights tests. Supports the
+    .detach().clone() chain used by load_weights for buffered mm weights;
+    both methods return self so identity (and the existing equality
+    assertions) are preserved through cloning."""
+
+    def detach(self):
+        return self
+
+    def clone(self):
+        return self
+
+
 def test_nano_nemotron_vl_skips_multimodal_weights_in_text_only_mode():
     model = object.__new__(NemotronH_Nano_VL_V2)
     language_model = _LanguageModel()
@@ -86,7 +99,7 @@ def test_nano_nemotron_vl_loads_vision_weights_without_sound_encoder():
     object.__setattr__(model, "sound_encoder", None)
 
     language_weight = object()
-    vision_weight = object()
+    vision_weight = _FakeTensor()
     model.load_weights(
         [
             ("language_model.layers.0.weight", language_weight),
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 994b52606b18..64667503d578 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -1518,37 +1518,51 @@ def is_vision_weights(name: str) -> bool:
         def is_sound_weights(name: str) -> bool:
             return name.startswith("sound")
 
-        # Separate weights by component
-        llm_weights = []
-        vision_weights = []
-        sound_weights = []
-
-        for name, w in weights:
-            if is_llm(name):
-                # Strip 'language_model.' prefix for LLM weights
-                llm_weights.append((".".join(name.split(".")[1:]), w))
-            elif is_adapter_weights((name, w)):
-                if not load_multimodal_weights:
-                    continue
+        # LLM weights (the bulk of the model) are streamed lazily through a
+        # generator so each tensor is copied into its parameter before the
+        # iterator advances, avoiding stale-reference corruption with
+        # reusable-buffer streamers. The smaller mm components (mlp1, vision,
+        # sound) are detach+cloned on append so they are independent of any
+        # reusable buffer the streamer may use, then loaded after the LLM.
+        adapter_weights: list[tuple[str, torch.Tensor]] = []
+        vision_weights: list[tuple[str, torch.Tensor]] = []
+        sound_weights: list[tuple[str, torch.Tensor]] = []
+
+        def llm_weights_gen():
+            for name, w in weights:
+                if is_llm(name):
+                    # Strip 'language_model.' prefix for LLM weights
+                    yield ".".join(name.split(".")[1:]), w
+                elif is_adapter_weights((name, w)):
+                    if not load_multimodal_weights:
+                        continue
+                    trimmed_name = ".".join(name.split(".")[1:])
+                    adapter_weights.append((trimmed_name, w.detach().clone()))
+                elif is_vision_weights(name):
+                    if not load_multimodal_weights:
+                        continue
+                    # Convert: vision_model.radio_model.* → radio_model.*
+                    hf_key = name[len("vision_model.") :]
+                    vision_weights.append((hf_key, w.detach().clone()))
+                elif is_sound_weights(name):
+                    if not load_multimodal_weights:
+                        continue
+                    assert self.sound_encoder is not None
+                    sound_weights.append((name, w.detach().clone()))
+
+        # Fully drain the generator so every mm tensor is buffered, even if
+        # the LLM loader stops iterating early.
+        llm_weights_iter = llm_weights_gen()
+        self.language_model.load_weights(llm_weights_iter)
+        for _ in llm_weights_iter:
+            pass
+
+        if load_multimodal_weights:
+            for trimmed_name, w in adapter_weights:
                 # Load vision-language adapter weights directly
-                trimmed_name = ".".join(name.split(".")[1:])
                 param = adapter_dict[trimmed_name]
                 with torch.no_grad():
                     default_weight_loader(param, w)
-            elif is_vision_weights(name):
-                if not load_multimodal_weights:
-                    continue
-                # Convert: vision_model.radio_model.* → radio_model.*
-                hf_key = name[len("vision_model.") :]  # Remove "vision_model." prefix
-                vision_weights.append((hf_key, w))
-            elif is_sound_weights(name):
-                if not load_multimodal_weights:
-                    continue
-                assert self.sound_encoder is not None
-                sound_weights.append((name, w))
-
-        self.language_model.load_weights(llm_weights)
-        if load_multimodal_weights:
             self.vision_model.load_weights(vision_weights)
             if self.sound_encoder is not None and len(sound_weights) > 0:
                 self.sound_encoder.load_weights(sound_weights)