From 82ec5c029fb484c2799e1a3e54444387a3de9d76 Mon Sep 17 00:00:00 2001 From: Chenhui Zhang Date: Mon, 8 Jan 2024 17:17:08 +0800 Subject: [PATCH] fix weigit loading for GQA with TP --- vllm/model_executor/layers/linear.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5190de65d795..5e1d63a6a62e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -423,7 +423,10 @@ def weight_loader(self, shard_offset = shard_offset // param.pack_factor param_data = param_data.narrow(output_dim, shard_offset, shard_size) - shard_id = tp_rank // self.num_kv_head_replicas + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)