Skip to content

Commit cfb22e0

Browse files
authored
Support Clip QKV for MPT (#31307)
1 parent b767282 commit cfb22e0

File tree

1 file changed

+4 -0 lines changed

1 file changed

+4 -0 lines changed

src/transformers/models/mpt/modeling_mpt.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,7 @@ def __init__(self, config: MptConfig):
82 | 82
self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
83 | 83

84 | 84
self.attn_dropout_p = config.attn_config.attn_pdrop
85+
self.clip_qkv = config.attn_config.clip_qkv
85 | 86
self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
86 | 87
self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
87 | 88

@@ -95,6 +96,9 @@ def forward(
95 | 96
batch_size, seq_length = hidden_states.shape[:2]
96 | 97

97 | 98
mixed_qkv = self.Wqkv(hidden_states)
99+
if self.clip_qkv:
100+
mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
101+
98 | 102
query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
99 | 103
query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
100 | 104
key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

0 commit comments

Comments (0)