Add LLAMA #1446
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import torch

from .model import ModelArgs, Transformer


class Model(BenchmarkModel):
    task = NLP.LANGUAGE_MODELING

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        self.model_args = ModelArgs(vocab_size=32)
        self.model = Transformer(self.model_args)

        if device == "cuda":
            torch.set_default_device("cuda")

        self.model.to(torch.device(device))
        self.example_inputs = (torch.tensor([[1, 1], [1, 1]], dtype=torch.int), 1)

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        error_msg = """
            As of March 6, 2023
            The weights for this model are not publicly available and require a valid research reason to use
            The publicly available github repo is inference only
            https://github.com/facebookresearch/llama
        """
        raise NotImplementedError(error_msg)

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            out = self.model(*self.example_inputs)
        return (out,)
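For reviewers, here is a minimal sketch of how this wrapper might be driven directly. It is an illustration, not part of the PR: the package path `torchbenchmark.models.llama` is assumed, a CUDA device is assumed (the attention KV caches in `model.py` are allocated with `.cuda()`), and the exact constructor arguments may need adjusting depending on what `BenchmarkModel.__init__` expects.

```python
# Hypothetical usage sketch, not part of this PR.
# Assumes the package path torchbenchmark.models.llama and a CUDA-capable machine,
# because the Attention KV caches in model.py are created with .cuda().
from torchbenchmark.models.llama import Model

m = Model(test="eval", device="cuda")            # arguments mirror the constructor above
module, (tokens, start_pos) = m.get_module()     # (Transformer, (example tokens, start position))
out = module(tokens, start_pos)
print(out.shape)                                 # expected: torch.Size([2, 33]) == (batch, vocab_size + 1)
```

With `vocab_size=32` and the default `ModelArgs`, this exercises a small randomly initialized Transformer, which is all the benchmark needs given that the real weights are gated (see `train()` above).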
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List

import torch

from .tokenizer import Tokenizer
from .model import Transformer


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
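The `sample_top_p` helper implements nucleus (top-p) sampling: keep the smallest set of highest-probability tokens whose cumulative mass reaches `p`, renormalize within that set, and sample from it. A standalone sketch of the same logic on a toy distribution (the helper is reproduced here so the snippet runs on its own; the numbers are illustrative, not from the PR):

```python
import torch

torch.manual_seed(0)

def sample_top_p(probs, p):
    # Same logic as the helper above, reproduced so this snippet is self-contained.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p          # drop tokens whose preceding cumulative mass already exceeds p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)

# Toy batch of one with a 5-token vocabulary.
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])
# With p=0.7 the nucleus is {token 0, token 1} (0.50 + 0.25 covers p), so only ids 0 or 1 can be drawn.
print(sample_top_p(probs, p=0.7))
```

In `generate`, this is applied to `softmax(logits / temperature)`, so lowering the temperature concentrates the nucleus on fewer tokens, while `temperature == 0` falls back to plain argmax.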
@@ -0,0 +1,8 @@
import subprocess
import sys

def pip_install_requirements():
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])

if __name__ == '__main__':
    pip_install_requirements()
@@ -0,0 +1,8 @@
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1024
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 1024


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads  # Basically we just assume world size of 1 // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        # TODO: RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 3
        # if mask is not None:
        #     scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(
            dim, hidden_dim, bias=False
        )
        self.w2 = nn.Linear(
            hidden_dim, dim, bias=False
        )
        self.w3 = nn.Linear(
            dim, hidden_dim, bias=False
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(
            params.vocab_size + 1, params.dim,
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            params.dim, params.vocab_size + 1, bias=False
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        _, seqlen = tokens.shape

        h = self.tok_embeddings(tokens)

        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None

        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()
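To make the `start_pos` / KV-cache contract of `Transformer.forward` explicit, here is a small greedy-decoding sketch in the spirit of `LLaMA.generate` above. It is an illustration only: the import path is assumed, and a CUDA device is assumed because the `Attention` caches are allocated with `.cuda()`.

```python
import torch

# Illustration only: prefill the prompt once, then decode one token per step.
# The import path below is an assumption about where this file lives in the tree.
from torchbenchmark.models.llama.model import ModelArgs, Transformer

torch.set_default_device("cuda")                 # the benchmark wrapper does the same for CUDA runs

model = Transformer(ModelArgs(vocab_size=32))
prompt = torch.randint(0, 32, (1, 4))            # (batch=1, prompt_len=4)

logits = model(prompt, start_pos=0)              # prefill: KV caches now hold positions 0..3
next_token = torch.argmax(logits, dim=-1, keepdim=True)

for pos in range(4, 8):                          # greedy-decode four more tokens
    logits = model(next_token, start_pos=pos)    # feed only the newest token; start_pos indexes the cache
    next_token = torch.argmax(logits, dim=-1, keepdim=True)
```

Note that with the mask addition still commented out in `Attention.forward` (the TODO above), the prefill step runs without causal masking, so this sketch only exercises shapes and the cache bookkeeping, not masked attention.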