diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 7a525ad8e494..ac95b65fd03f 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -27,7 +27,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType
 
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -154,6 +154,26 @@ def forward(
 
         return hidden_states
 
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
                        SupportsV0Only):
@@ -257,20 +277,5 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-            # Skip loading extra bias for GPTQ models.
-            if name.endswith(".bias") and name not in params_dict:
-                continue
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
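
For context, the net effect of this refactor is that MambaForCausalLM.load_weights no longer walks the parameter dict itself; it hands the checkpoint stream to AutoWeightsLoader, which dispatches to submodules that define their own load_weights (here the inner model, which keeps the A_log -> A rename and the GPTQ bias skip). The sketch below is a minimal, self-contained illustration of that delegation pattern, not vLLM's actual AutoWeightsLoader implementation; the _ToyInnerModel, _ToyCausalLM, and "backbone" names are hypothetical stand-ins.

    from typing import Iterable, Set, Tuple

    import torch
    import torch.nn as nn


    class _ToyInnerModel(nn.Module):
        """Stand-in for the inner model: owns the parameters and the rename logic."""

        def __init__(self) -> None:
            super().__init__()
            self.A = nn.Parameter(torch.zeros(4))

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
            params = dict(self.named_parameters())
            loaded: Set[str] = set()
            for name, tensor in weights:
                # Checkpoints store A_log; the module exposes the parameter as "A".
                name = name.replace("A_log", "A")
                if name not in params:
                    continue
                params[name].data.copy_(tensor)
                loaded.add(name)
            return loaded


    class _ToyCausalLM(nn.Module):
        """Stand-in for the outer model: delegates loading to its submodule."""

        def __init__(self) -> None:
            super().__init__()
            self.backbone = _ToyInnerModel()

        def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
            # Delegation in the spirit of AutoWeightsLoader: strip the submodule
            # prefix and let the submodule's own load_weights handle the rest.
            prefix = "backbone."
            sub = ((n[len(prefix):], w) for n, w in weights if n.startswith(prefix))
            return {prefix + n for n in self.backbone.load_weights(sub)}


    model = _ToyCausalLM()
    loaded = model.load_weights([("backbone.A_log", torch.ones(4))])
    print(loaded)  # {"backbone.A"}

Keeping the model-specific quirks (parameter renames, bias skipping, pipeline-parallel filtering) in the inner module's load_weights means the causal-LM wrapper stays a thin shell, which is the pattern this diff moves Mamba onto.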