@@ -42,7 +42,8 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
@@ -384,7 +385,7 @@ def __init__(
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
-
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.model = BaiChuanModel(vllm_config=vllm_config,
                                    prefix=prefix,
@@ -438,8 +439,10 @@ def lm_head_weight_loader(self, param: nn.Parameter,
         is_baichuan2 = self.config.vocab_size == 125696
         if is_baichuan2:
             loaded_weight = torch.nn.functional.normalize(loaded_weight)
-
-        default_weight_loader(param, loaded_weight)
+        if self.tp_size > 1:
+            row_parallel_weight_loader(param, loaded_weight)
+        else:
+            default_weight_loader(param, loaded_weight)
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
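
Why the branch is needed: ParallelLMHead shards lm_head.weight along the vocab dimension, so under tensor parallelism each rank's parameter holds only vocab_size / tp_size rows while the checkpoint tensor is full-size; copying it with default_weight_loader would trip the shape check. Below is a minimal sketch of the per-rank load path, assuming row_parallel_weight_loader narrows dim 0 to the rank's shard; the names load_lm_head_shard, rank, and full_weight are illustrative stand-ins, not vLLM API.

    import torch


    def load_lm_head_shard(param: torch.Tensor,
                           full_weight: torch.Tensor,
                           rank: int,
                           tp_size: int,
                           is_baichuan2: bool) -> None:
        """Illustrative stand-in for lm_head_weight_loader under TP."""
        if is_baichuan2:
            # Baichuan2 checkpoints ship an unnormalized head; vLLM
            # L2-normalizes each row (token embedding) at load time.
            full_weight = torch.nn.functional.normalize(full_weight)
        if tp_size > 1:
            # The param holds a vocab shard: narrow dim 0 to this rank's
            # contiguous slice, analogous to row_parallel_weight_loader.
            shard_size = param.shape[0]
            full_weight = full_weight.narrow(0, rank * shard_size, shard_size)
        assert param.shape == full_weight.shape
        with torch.no_grad():
            param.copy_(full_weight)


    # Example: rank 0 of tp_size=2 with the Baichuan2 vocab size.
    full = torch.randn(125696, 4096)
    shard = torch.empty(125696 // 2, 4096)
    load_lm_head_shard(shard, full, rank=0, tp_size=2, is_baichuan2=True)

Since torch.nn.functional.normalize operates per row (each token embedding independently), normalizing before or after the narrow is equivalent; the diff normalizes the full tensor first and lets the loader slice it.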