Add LLAMA #1446
New file (`@@ -0,0 +1,51 @@`) — the torchbenchmark wrapper:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

import torch

from ...util.model import BenchmarkModel
from torchbenchmark.tasks import NLP
from .model import ModelArgs, Transformer


class Model(BenchmarkModel):
    DEFAULT_EVAL_BSIZE = 32
    task = NLP.LANGUAGE_MODELING

    def __init__(self, test, device, jit=False, batch_size=DEFAULT_EVAL_BSIZE, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        # TODO: Configuring max_batch_size=batch_size, max_seq_len=1032 is breaking stuff
        self.model_args = ModelArgs(vocab_size=32)
        self.model = Transformer(self.model_args)

        # TODO: Implement batching
        if device == "cuda":
            torch.set_default_device("cuda")
            self.model.to(torch.device("cuda"))
        self.example_inputs = [(torch.tensor([[1, 1], [1, 1]], dtype=torch.int), 1)]

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        error_msg = """
            As of March 6, 2023
            The weights for this model are not publicly available and require a valid research reason to use
            The publicly available github repo is inference only
            https://github.com/facebookresearch/llama
        """
        # Raise (not return) the exception so callers actually see the failure.
        raise NotImplementedError(error_msg)

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            for example_input in self.example_inputs:
                out = self.model(*example_input)
        return (out,)
```
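For reviewers trying the patch locally, a minimal smoke test of the wrapper might look like the sketch below. The `torchbenchmark.models.llama` import path and the `test`/`device` values are assumptions based on how other torchbenchmark models are laid out, not part of this diff:

```python
# Hypothetical smoke test; assumes the file above lands as
# torchbenchmark/models/llama/__init__.py.
from torchbenchmark.models.llama import Model

m = Model(test="eval", device="cpu")  # use device="cuda" if a GPU is available
module, example_inputs = m.get_module()
(out,) = m.eval()
print(out.shape)
```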
New file (`@@ -0,0 +1,77 @@`) — the generation helper:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List

import torch

from .tokenizer import Tokenizer
from .model import Transformer


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # Right-pad every prompt to total_len with pad_id, then fill in real tokens.
        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[: len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
    # Nucleus (top-p) sampling: keep the smallest set of tokens whose
    # cumulative probability exceeds p, renormalize, then sample one.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
```
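To make `sample_top_p` concrete, here is a tiny worked example (the probability values are made up for illustration, and it assumes `sample_top_p` from the file above is in scope). With `p=0.7`, the cumulative mass cuts off after the first two sorted tokens, so sampling can never return the tail entries:

```python
import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])  # already normalized
# cumsum - probs_sort = [0.0, 0.5, 0.8, 0.95]; entries > 0.7 are masked,
# so only indices 0 and 1 survive (renormalized to [0.625, 0.375]).
tok = sample_top_p(probs, p=0.7)
print(tok)  # tensor([[0]]) or tensor([[1]])
```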
New file (`@@ -0,0 +1,8 @@`) — the dependency installer:

```python
import subprocess
import sys


def pip_install_requirements():
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])


if __name__ == '__main__':
    pip_install_requirements()
```
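This follows the usual torchbenchmark installer pattern: run it as `python install.py` from the model's own directory so the relative `requirements.txt` path resolves (the `install.py` file name is an assumption; the diff view does not show it).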
New file (`@@ -0,0 +1,8 @@`) — benchmark metadata:

```yaml
devices:
  NVIDIA A100-SXM4-40GB:
    eval_batch_size: 1024
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
```
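Reading these fields by analogy with other torchbenchmark models (an inference, not something documented in this diff): the device-specific entry overrides the eval batch size on an A100-40GB, and `eval_nograd: true` is consistent with the `torch.no_grad()` wrapper in `eval()` above.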