-
Notifications
You must be signed in to change notification settings - Fork 62
/
Qwen2.py
300 lines (240 loc) · 14.9 KB
/
Qwen2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
# Copied from the implementation for Mistral.
# transformers version 4.38.2
# Sliding window attention is not supported. See our paper for the reasons why we don't use it.
import math
import warnings
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from flash_attn import flash_attn_func, flash_attn_varlen_func
from transformers.cache_utils import Cache

from .selfextend_flash_attn import self_extend_flash_forward
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Produces the same result as torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1): an input of
    shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When `n_rep` is 1 the input tensor itself is
    returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embedding to the query and/or key tensor.

    `cos` and `sin` are squeezed on dim 1 and then dim 0 (which removes them when they are singleton),
    indexed with `position_ids`, and given a broadcast axis at dim 1. The query only uses the trailing
    `q.shape[2]` entries of the gathered tables; the key uses all of them. Either `q` or `k` may be None,
    in which case None is returned in its slot.

    Returns:
        (q_embed, k_embed)
    """
    # Gather the cos/sin rows for the requested positions and add a head axis for broadcasting.
    cos_at_pos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    sin_at_pos = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)

    q_embed = None
    if q is not None:
        n_query = q.shape[2]
        q_cos = cos_at_pos[:, :, -n_query:]
        q_sin = sin_at_pos[:, :, -n_query:]
        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)

    k_embed = None
    if k is not None:
        k_embed = (k * cos_at_pos) + (rotate_half(k) * sin_at_pos)

    return q_embed, k_embed
def apply_grouped_rotary_pos_emb(q, k, cos, sin, position_ids, g_size_1=1, g_size_2=4096):
    """
    Apply rotary position embedding using grouped (floor-divided) positions.

    Keys are rotated at position `position_ids // g_size_1`. Queries are rotated at
    `position_ids // g_size_1 + g_size_2 - g_size_2 // g_size_1`. `cos` and `sin` are squeezed on dim 1
    and then dim 0 before being indexed with those positions. Either `q` or `k` may be None, in which
    case None is returned in its slot.

    Returns:
        (q_embed, k_embed)
    """
    grouped_pos = position_ids // g_size_1
    q_pos = grouped_pos + g_size_2 - g_size_2 // g_size_1
    k_pos = grouped_pos

    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)

    # Gather per-position rows and add a head axis for broadcasting.
    q_cos = cos_table[q_pos].unsqueeze(1)
    q_sin = sin_table[q_pos].unsqueeze(1)
    k_cos = cos_table[k_pos].unsqueeze(1)
    k_sin = sin_table[k_pos].unsqueeze(1)

    q_embed = None
    if q is not None:
        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)

    k_embed = None
    if k is not None:
        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)

    return q_embed, k_embed
def self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    group_size_1: int = 8,
    group_size_2: int = 2048,
    scale_base: Optional[float] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Self-Extend attention forward pass (eager matmul/softmax implementation).

    Two sets of attention scores are computed and merged:
      * "neighbor" scores, with queries/keys rotated at their regular positions;
      * "group" scores, with queries/keys rotated at grouped positions
        (see `apply_grouped_rotary_pos_emb`).
    Neighbor scores are used where 0 <= query_pos - key_pos < group_size_2; group scores are used
    everywhere else. The merged scores go through softmax, dropout and the output projection.

    Args:
        hidden_states: input of shape (bsz, q_len, hidden).
        attention_mask: optional additive mask; must be of shape (bsz, 1, q_len, kv_seq_len).
        position_ids: positions of the query tokens. Required (used for RoPE and scaling).
        past_key_value: optional KV cache; updated in place via `update`.
        output_attentions: when False, None is returned instead of the attention weights.
        use_cache: unused here.
        padding_mask: unused here.
        group_size_1: divisor applied to positions for the grouped attention.
        group_size_2: width of the neighbor window.
        scale_base: when > 0, queries are scaled by max(1, log(position + 1) / log(scale_base)).

    Returns:
        (attn_output, attn_weights or None, past_key_value)

    Raises:
        ValueError: if a cache is given but `self.layer_idx` is None, if a score/mask/output tensor
            has an unexpected shape, or if q_len is neither 1 nor kv_seq_len.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    if scale_base > 0:
        # Log scale. A YaRN-style alternative would be
        # ((0.1 * log((position + 1) / scale_base) + 1) ** 2).clip(1).
        scaled_query = query_states * (
            (position_ids + 1)[:, None, :, None].log() / np.log(scale_base)
        ).clip(1).to(query_states.dtype)
    else:
        scaled_query = query_states

    # Keys are cached *without* RoPE so they can be re-rotated at both regular and grouped positions.
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    query_position = position_ids
    # Only bsz=1 is considered for now: key positions are simply 0..kv_seq_len-1.
    # Created directly on the target device to avoid a CPU allocation followed by a copy.
    key_position = torch.arange(
        kv_seq_len, dtype=position_ids.dtype, device=query_position.device
    ).view(1, kv_seq_len)

    neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, cos, sin, query_position)
    _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, cos, sin, key_position)

    # Drop the query offset while every position is still below group_size_2, in case the
    # smallest shifted query position (g2 - g2 // g1) would exceed the max position.
    _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2
    group_query_states, _ = apply_grouped_rotary_pos_emb(
        scaled_query, None, cos, sin, query_position, g_size_1=group_size_1, g_size_2=_re_group_size_2
    )
    _, group_key_states = apply_grouped_rotary_pos_emb(
        None, key_states, cos, sin, key_position, g_size_1=group_size_1, g_size_2=_re_group_size_2
    )

    group_key_states = repeat_kv(group_key_states, self.num_key_value_groups)
    neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    neighbor_attn_weights = torch.matmul(neighbor_query_states, neighbor_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    group_attn_weights = torch.matmul(group_query_states, group_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if group_attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {group_attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        group_attn_weights = group_attn_weights + attention_mask
        neighbor_attn_weights = neighbor_attn_weights + attention_mask

    # Boolean (q_len, kv_seq_len) mask: True where the neighbor scores must be used.
    mask_device = neighbor_attn_weights.device
    if q_len == 1:
        # Decoding: the single query sees the last `group_size_2` keys as neighbors.
        neighbor_attention_mask = torch.zeros((q_len, kv_seq_len), dtype=torch.bool, device=mask_device)
        neighbor_attention_mask[:, -group_size_2:] = True
    elif q_len == kv_seq_len:
        # Prefill: causal band of width `group_size_2`, i.e. 0 <= q_idx - k_idx < group_size_2.
        q_idx = torch.arange(q_len, device=mask_device).unsqueeze(1)
        k_idx = torch.arange(kv_seq_len, device=mask_device).unsqueeze(0)
        distance = q_idx - k_idx
        neighbor_attention_mask = (distance >= 0) & (distance < group_size_2)
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    attn_weights = torch.where(neighbor_attention_mask, neighbor_attn_weights, group_attn_weights)
    # Softmax in fp32 for stability, then cast back to the working dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
def flash_self_extend_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    group_size_1: int = 8,
    group_size_2: int = 2048,
    scale_base: Optional[float] = -1,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Self-Extend attention forward pass using flash-attention kernels.

    Two cases are handled:
      * q_len == 1 (decoding): each key is rotated by its (negated) relative distance to the query,
        using grouped distances for far keys and exact distances for the last `group_size_2` keys,
        and the query is left un-rotated. A single `flash_attn_func` call then computes attention.
      * q_len == kv_seq_len (prefill): neighbor and grouped query/key states are built and handed
        to `self_extend_flash_forward`, which is defined outside this block.

    Args:
        hidden_states: input of shape (bsz, q_len, hidden).
        attention_mask: forwarded to `self_extend_flash_forward` in the prefill case only.
        position_ids: positions of the query tokens. Required.
        past_key_value: optional KV cache; updated in place via `update`.
        output_attentions: ignored; attention weights are never produced by this implementation.
        use_cache: unused here.
        padding_mask: unused here.
        group_size_1: divisor applied to positions for the grouped attention.
        group_size_2: width of the neighbor window.
        scale_base: when > 0, queries are scaled by max(1, log(position + 1) / log(scale_base)).

    Returns:
        (attn_output, None, past_key_value). The second element is always None.

    Raises:
        ValueError: if a cache is given but `self.layer_idx` is None, or if q_len is neither 1 nor
            kv_seq_len.
    """
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    if scale_base > 0:
        # Log scale. A YaRN-style alternative would be
        # ((0.1 * log((position + 1) / scale_base) + 1) ** 2).clip(1).
        scaled_query = query_states * (
            (position_ids + 1)[:, None, :, None].log() / np.log(scale_base)
        ).clip(1).to(query_states.dtype)
    else:
        scaled_query = query_states

    # Keys are cached *without* RoPE so they can be re-rotated on every call.
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    query_position = position_ids
    # Only batch=1 is supported for now: key positions are simply 0..kv_seq_len-1.
    # Created directly on the target device to avoid a CPU allocation followed by a copy.
    key_position = torch.arange(
        kv_seq_len, dtype=position_ids.dtype, device=query_position.device
    ).view(1, kv_seq_len)

    attn_dropout = self.config.attention_dropout if self.training else 0.0

    if q_len == 1:
        # Drop the grouped offset while every position is still below group_size_2.
        _re_group_size_2 = 0 if position_ids.max() < group_size_2 else group_size_2

        # Relative distance from the query to every key, exact and grouped.
        neighbor_key_position = position_ids[:, -1] - key_position
        group_key_position = (
            position_ids[:, -1] // group_size_1
            - key_position // group_size_1
            + (_re_group_size_2 - _re_group_size_2 // group_size_1)
        )
        # Far keys use the grouped distance, the last `group_size_2` keys use the exact distance.
        decode_key_position = torch.cat(
            [group_key_position[:, :-group_size_2], neighbor_key_position[:, -group_size_2:]], dim=1
        )

        # The query is left un-rotated, i.e. treated as position 0 (cos 0 = 1, sin 0 = 0);
        # keys are rotated with `-sin` at their relative distance instead.
        decode_query_states = scaled_query.transpose(1, 2).contiguous()
        _, decode_key_states = apply_rotary_pos_emb(None, key_states, cos, -sin, decode_key_position)

        # flash-attn takes (bsz, seq_len, heads, head_dim) tensors.
        decode_key_states = repeat_kv(decode_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        decode_value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()

        attn_output = flash_attn_func(
            decode_query_states,
            decode_key_states,
            decode_value_states,
            attn_dropout,
            softmax_scale=None,
            causal=True,
        )
    elif q_len == kv_seq_len:
        # Set correct position ids and apply RoPE.
        # Drop the query offset while every position is still below group_size_2, in case the
        # smallest shifted query position (g2 - g2 // g1) would exceed the max position.
        _re_group_size_2 = 0 if query_position.max() < group_size_2 else group_size_2

        neighbor_query_states, _ = apply_rotary_pos_emb(scaled_query, None, cos, sin, query_position)
        _, neighbor_key_states = apply_rotary_pos_emb(None, key_states, cos, sin, key_position)
        group_query_states, _ = apply_grouped_rotary_pos_emb(
            scaled_query, None, cos, sin, query_position, g_size_1=group_size_1, g_size_2=_re_group_size_2
        )
        _, group_key_states = apply_grouped_rotary_pos_emb(
            None, key_states, cos, sin, key_position, g_size_1=group_size_1, g_size_2=_re_group_size_2
        )

        # Move to (bsz, seq_len, heads, head_dim) and expand KV heads to match the query heads.
        neighbor_query_states = neighbor_query_states.transpose(1, 2).contiguous()
        neighbor_key_states = repeat_kv(neighbor_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        group_query_states = group_query_states.transpose(1, 2).contiguous()
        group_key_states = repeat_kv(group_key_states, self.num_key_value_groups).transpose(1, 2).contiguous()
        value_states = repeat_kv(value_states, self.num_key_value_groups).transpose(1, 2).contiguous()

        attn_output = self_extend_flash_forward(
            self,
            query_position,
            group_size_2,
            neighbor_query_states,
            neighbor_key_states,
            group_query_states,
            group_key_states,
            value_states,
            attention_mask,
            bsz,
            q_len,
            kv_seq_len,
            attn_dropout,
        )
    else:
        raise ValueError("q_len should be 1 or seq_len.")

    attn_output = attn_output.contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    # The flash kernels never materialize the attention matrix, so there are no weights to return.
    # (Previously `attn_weights` was unbound when `output_attentions=True`.)
    attn_weights = None

    return attn_output, attn_weights, past_key_value