diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index a4f7913ba161..569ead7bdf3f 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -91,7 +91,9 @@ def __init__(self, config):
             ),
         )
         self.register_buffer("masked_bias", torch.tensor(-1e9))
-        self.rotary_emb = RotaryEmbedding(self.rotary_ndims, base=config.rotary_emb_base)
+        self.rotary_emb = RotaryEmbedding(
+            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+        )
         self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -207,7 +209,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             query,
             key.transpose(1, 2),
             beta=1.0,
-            alpha=(1.0 / self.norm_factor),
+            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
 
@@ -238,17 +240,24 @@ def attention_mask_func(attention_scores, ltor_mask):
 
 
 class RotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
         self.register_buffer("inv_freq", inv_freq)
-        self.max_seq_len_cached = None
-        self.cos_cached = None
-        self.sin_cached = None
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.cos_cached = emb.cos()[None, None, :, :]
+        self.sin_cached = emb.sin()[None, None, :, :]
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
             self.max_seq_len_cached = seq_len
             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
@@ -256,7 +265,7 @@ def forward(self, x, seq_len=None):
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
             self.cos_cached = emb.cos()[None, None, :, :]
             self.sin_cached = emb.sin()[None, None, :, :]
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
 
 
 def rotate_half(x):
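
Note: the patch's comment says the cos/sin cache is built in `__init__` "to make `torch.jit.trace` work": with the cache precomputed, `forward` no longer needs its data-dependent rebuild branch for sequence lengths up to `max_position_embeddings`. The following is a minimal standalone sketch of that idea, not the library's own test; the names `RotaryEmbeddingSketch` and `TraceWrapper` and the example sizes are hypothetical, chosen only for illustration.

import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build the cache once at construction time, as in the patch, so forward()
        # never has to rebuild it for seq_len <= max_position_embeddings.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len):
        # The patch keeps a fallback rebuild branch "just in case"; it is omitted
        # here because the cache built in __init__ already covers seq_len.
        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)


class TraceWrapper(torch.nn.Module):
    # torch.jit.trace only takes tensor example inputs, so derive seq_len from x
    # inside the traced forward instead of passing it as a Python int.
    def __init__(self, rotary):
        super().__init__()
        self.rotary = rotary

    def forward(self, x):
        return self.rotary(x, seq_len=x.shape[-2])


rotary = TraceWrapper(RotaryEmbeddingSketch(dim=64, max_position_embeddings=2048))
dummy = torch.randn(1, 8, 16, 64)  # [bs, num_attention_heads, seq_len, head_size]
traced = torch.jit.trace(rotary, dummy)

# The traced module reproduces the eager outputs for inputs of the traced shape.
cos_eager, sin_eager = rotary(dummy)
cos_traced, sin_traced = traced(dummy)
print(torch.allclose(cos_eager, cos_traced), torch.allclose(sin_eager, sin_traced))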