diff --git a/llm/llama2/position_embeddings.py b/llm/llama2/position_embeddings.py
index 08e7aa9148..b385403f64 100644
--- a/llm/llama2/position_embeddings.py
+++ b/llm/llama2/position_embeddings.py
@@ -29,6 +29,9 @@ class RotaryPositionalEmbeddings(nn.Module):
         Args:
             x (tensor): input tensor to which rope is applied
 
+        Returns:
+            torch.Tensor: output tensor with RoPE applied
+
     """
 
     def __init__(
@@ -64,9 +67,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         seq_len = x.size(1)
         rope_cache = self.cache[:seq_len]
 
-        # cast because the reference does
+        # reshape input; the last dimension is used for computing the output
+        # cast to float to match the reference implementation
         xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+
+        # reshape the cache for broadcasting
         rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+
         x_out2 = torch.stack(
             [
                 xshaped[..., 0] * rope_cache[..., 0]
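For context, here is a minimal, self-contained sketch of the RoPE forward pass that the new comments describe: split the head dimension into (real, imag) pairs, reshape the cached (cos, sin) table so it broadcasts over batch and head dimensions, and apply the rotation as a real-valued complex multiply. The cache layout (`[max_seq_len, head_dim // 2, 2]`) and the helper names `build_rope_cache` / `apply_rope` are assumptions for illustration, not the exact code in this file.

```python
import torch

def build_rope_cache(max_seq_len: int, dim: int, base: int = 10000) -> torch.Tensor:
    # theta_i = base^(-2i / dim) for each channel pair (assumed cache layout)
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(max_seq_len).float()
    # angles m * theta_i: [max_seq_len, dim // 2]
    idx_theta = torch.outer(pos, theta)
    # stack (cos, sin) pairs: [max_seq_len, dim // 2, 2]
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

def apply_rope(x: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq_len, num_heads, head_dim]
    seq_len = x.size(1)
    rope_cache = cache[:seq_len]

    # reshape input; the last dimension is split into (real, imag) pairs,
    # cast to float so the rotation is computed in full precision
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)

    # reshape the cache so it broadcasts over batch and heads:
    # [1, seq_len, 1, head_dim // 2, 2]
    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)

    # (a + ib) * (cos + i sin), written out in real arithmetic
    x_out = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        dim=-1,
    )
    # merge the pair dimension back into head_dim and restore the input dtype
    return x_out.flatten(3).type_as(x)

# usage: batch of 2, 16 tokens, 4 heads of dim 64
cache = build_rope_cache(max_seq_len=128, dim=64)
x = torch.randn(2, 16, 4, 64)
print(apply_rope(x, cache).shape)  # torch.Size([2, 16, 4, 64])
```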