diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 077e36176430..bcff6eb3fd31 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -42,7 +42,8 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

@@ -384,7 +385,7 @@ def __init__(
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
-
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.model = BaiChuanModel(vllm_config=vllm_config,
                                    prefix=prefix,
@@ -438,8 +439,10 @@ def lm_head_weight_loader(self, param: nn.Parameter,
         is_baichuan2 = self.config.vocab_size == 125696
         if is_baichuan2:
             loaded_weight = torch.nn.functional.normalize(loaded_weight)
-
-        default_weight_loader(param, loaded_weight)
+        if self.tp_size > 1:
+            row_parallel_weight_loader(param, loaded_weight)
+        else:
+            default_weight_loader(param, loaded_weight)


 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
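
For reference, below is a minimal sketch of what the row-parallel path is doing at load time. This is an illustration under stated assumptions, not vLLM's actual row_parallel_weight_loader: it assumes the lm_head weight is sharded evenly along dim 0 (the vocab dimension, which ParallelLMHead partitions across tensor-parallel ranks), and the sketch's name and explicit tp_rank/tp_size parameters are hypothetical.

    import torch

    # Hypothetical sketch of a row-parallel weight loader; the real vLLM
    # helper is imported from model_loader.weight_utils in the diff above.
    def row_parallel_weight_loader_sketch(param: torch.nn.Parameter,
                                          loaded_weight: torch.Tensor,
                                          tp_rank: int,
                                          tp_size: int) -> None:
        # Normalize-then-shard: the caller has already applied F.normalize
        # to the full checkpoint tensor (the Baichuan2 case), so here we
        # only narrow it to the slice owned by this tensor-parallel rank.
        shard_size = loaded_weight.shape[0] // tp_size
        shard = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
        assert param.data.shape == shard.shape
        param.data.copy_(shard)

The tp_size > 1 guard keeps the single-GPU path on default_weight_loader, which copies the full tensor directly. With TP > 1 the lm_head parameter only holds this rank's shard, so copying the full normalized weight would not fit; the loaded tensor has to be narrowed to the rank's slice first, which is what the row-parallel loader is used for here.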