@@ -13,15 +13,14 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData

 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 # We use this internally as placeholders since there is no image token
@@ -687,35 +686,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load query tokens
-        for name, loaded_weight in weights_group["query_tokens"]:
-            assert name == ""
-            param = self.query_tokens
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load qformer
-        qformer_params_dict = dict(self.qformer.named_parameters())
-        for name, loaded_weight in weights_group["qformer"]:
-            param = qformer_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.language_projection.named_parameters())
-        for name, loaded_weight in weights_group["language_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
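
For context on what this refactor buys: the deleted code dispatched checkpoint weights to each component (`vision_model`, `query_tokens`, `qformer`, `language_projection`, `language_model`) by hand. The sketch below illustrates the general prefix-dispatch idea that a helper like `AutoWeightsLoader` centralizes; it is an assumption-laden approximation, not vLLM's actual implementation, and the function name `load_weights_by_prefix` is hypothetical.

```python
from typing import Iterable, Tuple

import torch
import torch.nn as nn


def load_weights_by_prefix(model: nn.Module,
                           weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    """Illustrative sketch only: route checkpoint tensors to submodules
    by the first component of their dotted name, mirroring the idea
    behind AutoWeightsLoader rather than reproducing it."""
    params = dict(model.named_parameters())
    for name, loaded_weight in weights:
        # e.g. "vision_model.encoder.layers.0.w" -> ("vision_model", "encoder...")
        prefix, _, rest = name.partition(".")
        child = getattr(model, prefix, None)
        if isinstance(child, nn.Module) and hasattr(child, "load_weights"):
            # Delegate to the submodule's own loader for this weight.
            child.load_weights([(rest, loaded_weight)])
        else:
            # Fall back to a plain copy, honoring a custom per-parameter
            # `weight_loader` attribute if one was registered.
            param = params[name]
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None:
                weight_loader(param, loaded_weight)
            else:
                param.data.copy_(loaded_weight)
```

With the bookkeeping factored out this way, a composite model's `load_weights` collapses to the two lines added in this diff, and every multimodal model that follows the same attribute-naming convention can share the loader instead of maintaining its own per-component loop.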