From dd2a6a82e3f41b4673b1dbb24b2e99230ea96981 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 2 Sep 2024 23:48:56 +0800
Subject: [PATCH] [Bugfix] Fix internlm2 tensor parallel inference (#8055)

---
 vllm/model_executor/models/internlm2.py | 47 ++++++++++++++++++-------
 1 file changed, 34 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 9b7cada187ce1..23669b540f561 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -7,7 +8,10 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -70,20 +74,21 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -122,11 +127,27 @@ def __init__(
             quant_config=quant_config)
 
     def split_qkv(self, qkv: torch.Tensor):
-        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
-        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
-        q = q.reshape(-1, self.q_size)
-        k = k.reshape(-1, self.kv_size)
-        v = v.reshape(-1, self.kv_size)
+        seq_len = qkv.shape[0]
+        if self.tp_size > 1:
+            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
+            qkv = tensor_model_parallel_all_gather(qkv)
+            qkv = torch.split(qkv, qkv_map, dim=-1)
+            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
+            qkv = torch.cat(qkv, dim=-1)
+
+        qkv = qkv.view(seq_len, self.total_num_kv_heads,
+                       self.key_value_groups + 2, self.head_dim)
+        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
+        q = q.reshape(seq_len, self.q_size * self.tp_size)
+        k = k.reshape(seq_len, self.kv_size * self.tp_size)
+        v = v.reshape(seq_len, self.kv_size * self.tp_size)
+
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
         return q, k, v
 
     def forward(
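
Note on the fix: InternLM2's fused wqkv projection stores q, k, and v columns interleaved, so each tensor-parallel rank's column-parallel shard contains a mix of q, k, and v. The patch therefore all-gathers the shards, regroups the pieces so all q columns come first, then k, then v, and finally re-splits along the last dimension so each rank keeps its own contiguous slice. Below is a minimal CPU sketch of just the regrouping step. All sizes (tp_size=2, q_size=4, kv_size=2, seq_len=3) are illustrative assumptions, torch.cat over per-rank shards stands in for tensor_model_parallel_all_gather, and each rank's shard is simplified to a contiguous [q | k | v] block; in the real layer the shard is further interleaved per KV-head group, which the subsequent view/split in split_qkv handles.

import torch

tp_size, q_size, kv_size, seq_len = 2, 4, 2, 3

# Each rank's fused projection output is an interleaved
# [q_shard | k_shard | v_shard] slice. Fill values encode (kind, rank)
# so the regrouping is easy to verify by eye.
shards = [
    torch.cat([
        torch.full((seq_len, q_size), float(rank)),    # q shard of this rank
        torch.full((seq_len, kv_size), 10.0 + rank),   # k shard of this rank
        torch.full((seq_len, kv_size), 20.0 + rank),   # v shard of this rank
    ], dim=-1)
    for rank in range(tp_size)
]
qkv = torch.cat(shards, dim=-1)  # emulated all-gather result

# Same regrouping as the patch: split into per-rank (q, k, v) pieces,
# then collect all q pieces first, then k, then v.
qkv_map = [q_size, kv_size, kv_size] * tp_size
pieces = torch.split(qkv, qkv_map, dim=-1)
pieces = pieces[::3] + pieces[1::3] + pieces[2::3]
regrouped = torch.cat(pieces, dim=-1)

q, k, v = torch.split(
    regrouped,
    [q_size * tp_size, kv_size * tp_size, kv_size * tp_size],
    dim=-1)
assert (q[0] == torch.tensor([0., 0., 0., 0., 1., 1., 1., 1.])).all()
assert (k[0] == torch.tensor([10., 10., 11., 11.])).all()
assert (v[0] == torch.tensor([20., 20., 21., 21.])).all()

The patch also replaces the hardcoded head dimension of 128 with self.head_dim and derives the reshape sizes from total_num_kv_heads, so the path works for any head size and for models whose KV heads are replicated rather than partitioned across ranks.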