
Commit 69ff99f

NickLucche authored
[Core] Optimizing cross-attention QKVParallelLinear computation (vllm-project#12325)
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: NickLucche <[email protected]>
1 parent 5d80252 commit 69ff99f

File tree

vllm/model_executor/layers/linear.py
vllm/model_executor/models/bart.py
vllm/model_executor/models/mllama.py
vllm/model_executor/models/utils.py

4 files changed: +121 -44 lines changed

vllm/model_executor/layers/linear.py (+95)
@@ -1227,3 +1227,98 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class QKVCrossParallelLinear(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 head_size: int,
+                 total_num_heads: int,
+                 total_num_kv_heads: Optional[int] = None,
+                 bias: bool = True,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        # Empty placeholders for loading as a single module.
+        self.weight = torch.nn.Parameter()
+        set_weight_attrs(self.weight, {
+            "weight_loader": self.weight_loader_weight,
+        })
+        # Use a dictionary to avoid submodules parameters auto-registration:
+        # drop-in replacement for a `QKVParallelLinear` module.
+        self.proj = dict()
+        self.proj["q_proj_decoder"] = ColumnParallelLinear(
+            input_size=hidden_size,
+            output_size=total_num_heads * head_size,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.q_proj_decoder")
+
+        self.proj["kv_proj_encoder"] = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=head_size,
+            total_num_heads=0,
+            total_num_kv_heads=total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            prefix=f"{prefix}.kv_proj_encoder")
+
+        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
+        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size
+
+        if bias:
+            self.bias = torch.nn.Parameter()
+            set_weight_attrs(self.bias, {
+                "weight_loader": self.weight_loader_bias,
+            })
+
+    @property
+    def q_proj_decoder(self):
+        return self.proj["q_proj_decoder"]
+
+    @property
+    def kv_proj_encoder(self):
+        return self.proj["kv_proj_encoder"]
+
+    def forward(self, decoder_hidden_states, encoder_hidden_states):
+        q, _ = self.q_proj_decoder(decoder_hidden_states)
+        if encoder_hidden_states is None:
+            # Encoder KV already cached.
+            k = None
+            v = None
+        else:
+            # Prefill phase, encoder KV cached here.
+            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
+            # Split kv in half
+            k, v = kv_enc.split(self.kv_size, dim=-1)
+        return q, k, v
+
+    def weight_loader_weight(self,
+                             param: torch.nn.Parameter,
+                             loaded_weight: torch.Tensor,
+                             loaded_shard_id: Optional[str] = None):
+        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
+        param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.weight
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)
+
+    def weight_loader_bias(self,
+                           param: torch.nn.Parameter,
+                           loaded_weight: torch.Tensor,
+                           loaded_shard_id: Optional[str] = None):
+        param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
+            else self.kv_proj_encoder.bias
+        param.weight_loader(
+            param,
+            loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
+                param, loaded_weight, loaded_shard_id)

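To make the behavior of the new layer easier to follow outside of the diff, here is a small, self-contained PyTorch sketch of the same idea: Q is projected from decoder states only and K/V from encoder states only, so neither input pays for the unused slices of a fused QKV projection. This is an illustration with made-up sizes, not the vLLM class above (which additionally handles tensor parallelism, quantization, and checkpoint loading).

import torch
from typing import Optional, Tuple


class ToyCrossQKV(torch.nn.Module):
    """Toy stand-in for QKVCrossParallelLinear (no TP, no quantization)."""

    def __init__(self, hidden_size: int, head_size: int, num_heads: int,
                 num_kv_heads: int):
        super().__init__()
        # Decoder-side Q projection and encoder-side fused KV projection.
        self.q_proj = torch.nn.Linear(hidden_size, num_heads * head_size)
        self.kv_proj = torch.nn.Linear(hidden_size, 2 * num_kv_heads * head_size)
        self.kv_size = num_kv_heads * head_size

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode steps: encoder K/V already live in the KV cache.
            return q, None, None
        # Prefill: compute encoder K/V once and split the fused output in half.
        kv = self.kv_proj(encoder_hidden_states)
        k, v = kv.split(self.kv_size, dim=-1)
        return q, k, v


# Made-up shapes purely for demonstration.
layer = ToyCrossQKV(hidden_size=512, head_size=64, num_heads=8, num_kv_heads=8)
q, k, v = layer(torch.randn(4, 512), torch.randn(6, 512))  # prefill
q, k, v = layer(torch.randn(4, 512), None)                 # decode: k and v are None
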
vllm/model_executor/models/bart.py (+13 -26)
@@ -31,6 +31,7 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor

@@ -169,7 +170,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim

@@ -248,7 +249,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.num_kv_heads = self.num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

@@ -299,14 +300,14 @@ def __init__(
                              f" and `num_heads`: {num_heads}).")
         self.scaling = self.head_dim**-0.5

-        self.qkv_proj = QKVParallelLinear(
-            self.d_model,
-            self.d_model // self.total_num_heads,
-            self.total_num_heads,
-            self.total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-        )
+        # TP sharding sizes is accounted for within "*Parallel" layers.
+        self.qkv_proj = QKVCrossParallelLinear(self.d_model,
+                                               self.d_model //
+                                               self.total_num_heads,
+                                               self.total_num_heads,
+                                               self.total_num_kv_heads,
+                                               bias,
+                                               quant_config=quant_config)

         self.out_proj = RowParallelLinear(
             embed_dim,

@@ -327,10 +328,7 @@ def __init__(
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_world_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-
+        self.num_kv_heads = self.num_heads  # No GQA in bart
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,

@@ -347,18 +345,7 @@ def forward(
     ) -> torch.Tensor:
         """Input shape: Batch x Time x Channel"""

-        # (afeldman-nm 2024/07/22) TODO:
-        # Need a more efficient solution for q/k/v
-        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
-        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-        if encoder_hidden_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
-            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
-                                    dim=-1)
+        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)

         attn_output = self.attn(q, k, v)

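A rough, back-of-the-envelope account of why the rewritten forward above is cheaper (my arithmetic, not a measurement from the PR): previously both the decoder and the encoder hidden states went through the full fused QKV weight and the unused slices were discarded; now each input only multiplies the weights it actually needs. With illustrative, BART-large-like sizes:

# Hypothetical sizes chosen only for illustration.
d_model, head_size, n_heads, n_kv_heads = 1024, 64, 16, 16
q_out = n_heads * head_size
kv_out = 2 * n_kv_heads * head_size
n_dec_tokens, n_enc_tokens = 4, 1500  # prefill: short decoder prompt, long encoder input

# Multiply-accumulates of a linear layer: tokens * in_features * out_features.
old_macs = (n_dec_tokens + n_enc_tokens) * d_model * (q_out + kv_out)
new_macs = n_dec_tokens * d_model * q_out + n_enc_tokens * d_model * kv_out
print(f"fused-QKV vs split projections: {old_macs / new_macs:.2f}x MACs")  # ~1.5x here
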
vllm/model_executor/models/mllama.py (+12 -17)
@@ -43,6 +43,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVCrossParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor

@@ -798,21 +799,22 @@ def __init__(
         self.config = config
         self.pipeline_parallel_rank = get_pp_group().rank_in_group
         self.tensor_parallel_size = get_tp_group().world_size
-        self.num_heads = self.config.num_attention_heads
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+
         self.num_local_heads = self.num_heads // self.tensor_parallel_size
-        self.num_key_value_heads = self.config.num_key_value_heads
         self.num_local_key_value_heads = \
             self.num_key_value_heads // self.tensor_parallel_size
-        self.dropout = config.dropout
         self.hidden_size = config.hidden_size
         self.head_dim = config.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+
         self.layer_idx = layer_idx
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.q_local_size = self.num_local_heads * self.head_dim
         self.kv_local_size = self.num_local_key_value_heads * self.head_dim

-        # TODO: change to Q/KV separate linear after #7448 is merged
-        self.qkv_proj = QKVParallelLinear(
+        self.qkv_proj = QKVCrossParallelLinear(
             self.hidden_size,
             self.head_dim,
             self.num_heads,

@@ -821,6 +823,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
+
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.hidden_size,

@@ -851,21 +854,12 @@ def forward(
         kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        qkv_dec, _ = self.qkv_proj(hidden_states)
-        q, _, _ = qkv_dec.split(
-            [self.q_local_size, self.kv_local_size, self.kv_local_size],
-            dim=-1)
-        if cross_attention_states is None:
-            k = None
-            v = None
-        else:
-            qkv_enc, _ = self.qkv_proj(cross_attention_states)
-            _, k, v = qkv_enc.split(
-                [self.q_local_size, self.kv_local_size, self.kv_local_size],
-                dim=-1)
+        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
+        if cross_attention_states is not None:
             k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
             v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
             k = self.k_norm(k)
+
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)

@@ -889,6 +883,7 @@ def _attention_with_mask(
         kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
         # Skip writing kv-cache for the initial profiling run.
+        # TODO (NickLucche) replace with custom attn bias and use standard attn
         if len(kv_cache.shape) > 1:
             i = torch.ones(1, dtype=torch.float32)
             if self.attn.backend in (_Backend.FLASH_ATTN,

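One non-obvious detail in the linear.py change is that the two sub-layers are stored in a plain dict (exposed through properties) rather than assigned as attributes. The standalone snippet below, which is not vLLM code, illustrates the effect: attribute assignment auto-registers a submodule's parameters, while dict storage keeps the wrapper's parameter list down to the placeholder weight that the checkpoint loader targets.

import torch


class Registered(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)  # attribute: auto-registered submodule


class Hidden(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter()            # empty placeholder
        self.proj = {"inner": torch.nn.Linear(4, 4)}  # dict: not registered

    @property
    def inner(self):
        return self.proj["inner"]


print([n for n, _ in Registered().named_parameters()])  # ['inner.weight', 'inner.bias']
print([n for n, _ in Hidden().named_parameters()])      # ['weight'] only
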
vllm/model_executor/models/utils.py (+1 -1)
@@ -650,4 +650,4 @@ def cast_overflow_tensors(
     if tensors.isinf().any() or tensors.isnan().any():
         clamp_value = torch.finfo(tensors.dtype).max - offset
         tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
-    return tensors
+    return tensors

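For context on the file touched last: cast_overflow_tensors clamps tensors that have overflowed to inf/NaN into a finite range just below the dtype maximum. A standalone approximation, reconstructed from the context lines above (the offset default is an assumption), behaves like this:

import torch


def cast_overflow_tensors_sketch(tensors: torch.Tensor,
                                 offset: float = 1000.0) -> torch.Tensor:
    # Mirror of the helper shown above: clamp into a finite range slightly
    # below the dtype's maximum whenever any element is inf or NaN.
    if tensors.isinf().any() or tensors.isnan().any():
        clamp_value = torch.finfo(tensors.dtype).max - offset
        tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
    return tensors


x = torch.tensor([1.0, float("inf"), -float("inf")], dtype=torch.float16)
print(cast_overflow_tensors_sketch(x))  # infs clamped to roughly +/-(float16 max - offset)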