src/transformers/models/gpt_neox/modeling_gpt_neox.py (25 changes: 17 additions & 8 deletions)
@@ -91,7 +91,9 @@ def __init__(self, config):
             ),
         )
         self.register_buffer("masked_bias", torch.tensor(-1e9))
-        self.rotary_emb = RotaryEmbedding(self.rotary_ndims, base=config.rotary_emb_base)
+        self.rotary_emb = RotaryEmbedding(
+            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+        )
         self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
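With this change, the attention module passes `config.max_position_embeddings` as the second positional argument, so the rotary sin/cos cache is sized to the model's maximum sequence length at construction time. A minimal sketch of the new call with hypothetical values, assuming `RotaryEmbedding` is importable from this version of modeling_gpt_neox.py (the real module derives `rotary_ndims` from `head_size` and `config.rotary_pct`):

# Hypothetical config values, for illustration only.
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding

class DummyConfig:
    max_position_embeddings = 2048
    rotary_emb_base = 10000

config = DummyConfig()
rotary_ndims = 24  # stand-in for int(head_size * config.rotary_pct)

rotary_emb = RotaryEmbedding(
    rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
)
# The cos/sin caches now cover 2048 positions as soon as the module is built.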
@@ -207,7 +209,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             query,
             key.transpose(1, 2),
             beta=1.0,
-            alpha=(1.0 / self.norm_factor),
+            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
 
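The only functional change in this hunk is how the 1/sqrt(head_size) scaling factor passed to `torch.baddbmm` is built: the numerator is now an explicit tensor with `norm_factor`'s dtype and device rather than a Python float, which pins the result's dtype and device instead of relying on scalar-tensor promotion, presumably in service of the same `torch.jit.trace` goal as the rest of this PR. A small worked sketch, assuming head_size = 64:

import torch

head_size = 64  # assumed value for illustration
norm_factor = torch.sqrt(torch.tensor(head_size, dtype=torch.float32)).to(torch.get_default_dtype())

alpha_old = 1.0 / norm_factor  # old form: Python float divided by a tensor
alpha_new = torch.tensor(1.0, dtype=norm_factor.dtype, device=norm_factor.device) / norm_factor  # new form

print(alpha_old.item(), alpha_new.item())  # both 0.125, i.e. 1 / sqrt(64)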
@@ -238,25 +240,32 @@ def attention_mask_func(attention_scores, ltor_mask):
 
 
 class RotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
         self.register_buffer("inv_freq", inv_freq)
-        self.max_seq_len_cached = None
-        self.cos_cached = None
-        self.sin_cached = None
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.cos_cached = emb.cos()[None, None, :, :]
+        self.sin_cached = emb.sin()[None, None, :, :]
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
             self.max_seq_len_cached = seq_len
             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
             self.cos_cached = emb.cos()[None, None, :, :]
             self.sin_cached = emb.sin()[None, None, :, :]
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
 
 
 def rotate_half(x):
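A minimal usage sketch of the rewritten class (assumed values: dim=24, max_position_embeddings=2048; assumes `RotaryEmbedding` is importable from this version of modeling_gpt_neox.py):

import torch
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding

rotary = RotaryEmbedding(dim=24, max_position_embeddings=2048, base=10000)

# x: [bs, num_attention_heads, seq_len, head_size]
x = torch.randn(1, 8, 128, 24)
cos, sin = rotary(x, seq_len=128)

# seq_len (128) is below max_seq_len_cached (2048), so the rebuild branch in
# forward() is never taken: cos and sin come straight from the tensors that
# __init__ pre-computed, which is the point of building them there (see the
# `torch.jit.trace` comment in the diff above).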