Add LLAMA #1446 (Closed)
Commits (28):
0332c34 LLAMA (msaroufim)
2242fc7 Add LLAMA (msaroufim)
c0468f4 add intall.py (msaroufim)
54b82af Fixed some stuff (msaroufim)
b451954 test now runs (msaroufim)
b95af59 fix model (msaroufim)
1a7833a upd (msaroufim)
5ca98df updat docs (msaroufim)
aa0d4cb add stuff (msaroufim)
1a5f4d2 minor fix (msaroufim)
7a7627f flatten (msaroufim)
0ca561f fixed CI issues (msaroufim)
4b6a2ec Merge branch 'main' into llama (msaroufim)
2033cdd made sure model runs on GPU (msaroufim)
6b617f0 pass (msaroufim)
6d36574 push (msaroufim)
2be2c07 update (msaroufim)
db7690c upd (msaroufim)
d039c6d fixed test_llama_example_cuda (msaroufim)
4ee4b71 clarify batching limitation (msaroufim)
17051bd Address Xu feedback (msaroufim)
2fc7143 Added support for batching (msaroufim)
89d3724 push (msaroufim)
21e93c0 Update torchbenchmark/models/llama/__init__.py (msaroufim)
3854311 update (msaroufim)
2ec84a1 Merge branch 'llama' of https://github.com/pytorch/benchmark into llama (msaroufim)
faf928f push (msaroufim)
d827288 Update __init__.py (msaroufim)
torchbenchmark/models/llama/__init__.py (new file, 40 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
import torch

from .model import ModelArgs, Transformer


class Model(BenchmarkModel):
    task = NLP.LANGUAGE_MODELING

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        self.model_args = ModelArgs(vocab_size=32)
        self.model = Transformer(self.model_args)

        # See the review discussion below about explicit device placement.
        torch.set_default_device(device)
        self.example_inputs = (torch.tensor([[1, 1], [1, 1]], dtype=torch.int), 1)

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        # Training is unsupported: the LLaMA weights are gated, so signal NotImplementedError.
        error_msg = """
            As of March 6, 2023
            The weights for this model are not publicly available and require a valid research reason to use
            The publicly available github repo is inference only
            https://github.com/facebookresearch/llama
        """
        raise NotImplementedError(error_msg)

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            out = self.model(*self.example_inputs)
        return (out,)
```
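For context, here is a minimal sketch of the eval path this wrapper exercises, built directly on `ModelArgs`/`Transformer` rather than on the TorchBench harness (which may impose extra setup not shown in this diff). The import path is assumed from the PR's file layout, and a CUDA device is assumed because the attention KV caches are allocated with `.cuda()`.

```python
import torch
from torchbenchmark.models.llama.model import ModelArgs, Transformer  # path assumed from this PR

device = "cuda"
# Set the default device before constructing the model so all parameters land on the GPU.
torch.set_default_device(device)

model_args = ModelArgs(vocab_size=32)
model = Transformer(model_args).eval()

# Same toy input the wrapper uses: a (2, 2) batch of token ids and start_pos=1.
example_inputs = (torch.tensor([[1, 1], [1, 1]], dtype=torch.int), 1)

with torch.no_grad():
    out = model(*example_inputs)
print(out.shape)  # (2, vocab_size + 1): logits for the last position of each sequence
```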
Generation code, likely generation.py copied from facebookresearch/llama (new file, 77 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List

import torch

from .tokenizer import Tokenizer
from .model import Transformer


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        # Decode one position at a time; the model's KV cache makes each step incremental.
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    # Nucleus (top-p) sampling: sample from the smallest set of tokens whose
    # cumulative probability exceeds p, after renormalizing that set.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```
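To make `sample_top_p` concrete, here is a small toy example. The probability values are made up for illustration, and `sample_top_p` is assumed to be importable from the module above.

```python
import torch

# Toy distribution over 5 tokens for a batch of one sequence.
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])

# With p=0.8 the sorted cumulative sums are [0.50, 0.75, 0.90, 0.97, 1.00].
# A token is masked only once the cumulative sum *excluding* it exceeds p,
# so the tokens with probabilities 0.50, 0.25 and 0.15 survive, are
# renormalized, and one of their indices (0, 1, or 2) is drawn.
torch.manual_seed(0)
next_token = sample_top_p(probs, p=0.8)  # sample_top_p as defined above
print(next_token.shape)  # torch.Size([1, 1])
```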
install.py (new file, 8 lines):

```python
import subprocess
import sys


def pip_install_requirements():
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])


if __name__ == '__main__':
    pip_install_requirements()
```
Benchmark metadata, likely metadata.yaml (new file, 8 lines; nesting reconstructed):

```yaml
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1024
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
```
model.py (new file, 232 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math

import torch
from torch import nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    norm_eps: float = 1e-5

    max_batch_size: int = 32
    max_seq_len: int = 1024


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Precompute rotary embedding frequencies as complex exponentials.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        # Assume a model-parallel world size of 1 (upstream divides by fs_init.get_model_parallel_world_size()).
        self.n_local_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # KV cache for incremental decoding.
        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        # TODO: RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 3
        # if mask is not None:
        #     scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU activation.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size + 1, params.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size + 1, bias=False)

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        _, seqlen = tokens.shape

        h = self.tok_embeddings(tokens)

        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None

        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()
```
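To make the `start_pos` / KV-cache interplay concrete, here is a hedged sketch of greedy incremental decoding with this `Transformer`. The import path is assumed from the PR's layout, the tiny `ModelArgs` values are chosen only to keep the example cheap, and a CUDA device is assumed because the caches are allocated with `.cuda()`.

```python
import torch
from torchbenchmark.models.llama.model import ModelArgs, Transformer  # path assumed from this PR

torch.set_default_device("cuda")

args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=32, max_seq_len=64)
model = Transformer(args).eval()

# Prompt of 4 tokens for a batch of 1; generate 8 more tokens greedily.
tokens = torch.randint(0, args.vocab_size, (1, 4))
prev_pos = 0
with torch.no_grad():
    for cur_pos in range(tokens.shape[1], tokens.shape[1] + 8):
        # Only the not-yet-cached positions are fed in; the attention layers
        # read earlier keys/values back from cache_k / cache_v.
        logits = model(tokens[:, prev_pos:cur_pos], prev_pos)
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        prev_pos = cur_pos
print(tokens)
```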
origin (new file, 1 line):

```text
origin https://github.com/facebookresearch/llama
```
requirements.txt (new file, 1 line):

```text
sentencepiece
```
Review discussion on `torchbenchmark/models/llama/__init__.py`:

> I am curious why we don't need to explicitly move `self.model` and `self.example_inputs` to the device here? For example:

Reply: Fixed
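The reviewer's example snippet was not preserved in this extract. As an illustration only, not the reviewer's actual suggestion, explicit placement inside `Model.__init__` could look like the fragment below; the merged code instead relies on `torch.set_default_device(device)`.

```python
# Hypothetical fragment for Model.__init__: move the module and the example
# inputs to the target device explicitly rather than via set_default_device.
self.model = Transformer(self.model_args).to(device)
self.example_inputs = (
    torch.tensor([[1, 1], [1, 1]], dtype=torch.int, device=device),
    1,  # start_pos
)
```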