Skip to content

Commit 29adb39

Browse files
MengqingCao authored and Yuqi Zhang committed
[Bugfix][Model] Fix baichuan model loader for tp (vllm-project#18597)
Signed-off-by: Mengqing Cao <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
1 parent c38604d commit 29adb39

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

vllm/model_executor/models/baichuan.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,8 @@
4242
from vllm.model_executor.layers.rotary_embedding import get_rope
4343
from vllm.model_executor.layers.vocab_parallel_embedding import (
4444
ParallelLMHead, VocabParallelEmbedding)
45-
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
45+
from vllm.model_executor.model_loader.weight_utils import (
46+
default_weight_loader, row_parallel_weight_loader)
4647
from vllm.model_executor.sampling_metadata import SamplingMetadata
4748
from vllm.sequence import IntermediateTensors
4849

@@ -384,7 +385,7 @@ def __init__(
384385
lora_config = vllm_config.lora_config
385386
self.config = config
386387
self.lora_config = lora_config
387-
388+
self.tp_size = get_tensor_model_parallel_world_size()
388389
self.quant_config = quant_config
389390
self.model = BaiChuanModel(vllm_config=vllm_config,
390391
prefix=prefix,
@@ -438,8 +439,10 @@ def lm_head_weight_loader(self, param: nn.Parameter,
438439
is_baichuan2 = self.config.vocab_size == 125696
439440
if is_baichuan2:
440441
loaded_weight = torch.nn.functional.normalize(loaded_weight)
441-
442-
default_weight_loader(param, loaded_weight)
442+
if self.tp_size > 1:
443+
row_parallel_weight_loader(param, loaded_weight)
444+
else:
445+
default_weight_loader(param, loaded_weight)
443446

444447

445448
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):

0 commit comments

Comments
 (0)