we can inference Meta's Llama 2 7B, yay
karpathy committed Jul 25, 2023
1 parent 133ad3f commit c3e0d73
Showing 3 changed files with 126 additions and 4 deletions.
README.md: 24 additions, 1 deletion
@@ -43,9 +43,32 @@ This still runs at interactive rates and samples more coherent and diverse stories:

*Once upon a time, there was a little girl named Lily. She loved playing with her toys on top of her bed. One day, she decided to have a tea party with her stuffed animals. She poured some tea into a tiny teapot and put it on top of the teapot. Suddenly, her little brother Max came into the room and wanted to join the tea party too. Lily didn't want to share her tea and she told Max to go away. Max started to cry and Lily felt bad. She decided to yield her tea party to Max and they both shared the teapot. But then, something unexpected happened. The teapot started to shake and wiggle. Lily and Max were scared and didn't know what to do. Suddenly, the teapot started to fly towards the ceiling and landed on the top of the bed. Lily and Max were amazed and they hugged each other. They realized that sharing was much more fun than being selfish. From that day on, they always shared their tea parties and toys.*

## Meta's Llama 2 models

As the neural net architecture is identical, we can also run inference on the Llama 2 models released by Meta. First you'll have to export those weights into the llama2.c format. Git clone the main repo from Meta, copy the `export_meta_llama_bin.py` file (from the root directory of this project) over to it, and run it:

```bash
git clone https://github.com/facebookresearch/llama.git
cd llama
cp /path/to/llama2.c/export_meta_llama_bin.py .
torchrun --nproc_per_node 1 export_meta_llama_bin.py
```

The export will take ~20 minutes and generate a 26GB file (the weights of the 7B model in float32: roughly 6.7e9 parameters × 4 bytes each ≈ 27GB) called `llama2_7b.bin` in the current directory. Go back to the root directory of llama2.c and run!

```bash
./run path/to/llama2_7b.bin
```
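
If you haven't built `run` yet, here is a minimal sketch of compiling with OpenMP enabled and pinning the thread count (the exact flags are an assumption, not something this commit specifies):

```bash
# assumed invocation; adjust the compiler and flags for your system
gcc -Ofast -fopenmp run.c -lm -o run
# pick a thread count appropriate for your machine
OMP_NUM_THREADS=96 ./run path/to/llama2_7b.bin
```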

This ran at about 4 tokens/s when compiled with OpenMP and run on 96 threads, on my CPU Linux box in the cloud. Example output:

*<s>The purpose of this document is to highlight the state-of-the-art of CoO generation technologies, both recent developments and those in commercial use. The focus is on the technologies with the highest merit to become the dominating processes of the future and therefore to be technologies of interest to S&T ... R&D. As such, CoO generation technologies developed in Russia, Japan and Europe are described in some depth. The document starts with an introduction to cobalt oxides as complex products and a short view on cobalt as an essential material. The document continues with the discussion of the available CoO generation processes with respect to energy and capital consumption as well as to environmental damage.*

base models... ¯\_(ツ)_/¯.

## models

- It looks like I will have multiple models that I will train on TinyStories, I will catalogue them here.
+ For the sake of examples of smaller, from-scratch models, I trained multiple models on TinyStories and catalogue them here:

| model | dim | n_layers | n_heads | max context length | parameters | download |
| --- | --- | --- | --- | --- | --- | --- |
export_meta_llama_bin.py: 91 additions, 0 deletions
@@ -0,0 +1,91 @@
"""
This script exports the Llama 2 weights in llama2c.bin format.
Place it into the root directory of:
https://github.com/facebookresearch/llama
And then run it similarly to their other examples, via torchrun sadly:
torchrun --nproc_per_node 1 export_meta_llama_bin.py
"""

from llama import Llama

# -----------------------------------------------------------------------------
def export(self, filepath='model.bin'):
    """export the model weights in fp32 into .bin file to be read from C"""
    import struct
    import numpy as np

    f = open(filepath, 'wb')

    def serialize(t):
        d = t.detach().cpu().view(-1).numpy().astype(np.float32)
        b = struct.pack(f'{len(d)}f', *d)
        f.write(b)

    # first write out the header
    hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
    p = self.params
    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
                         n_kv_heads, -p.vocab_size, p.max_seq_len)
    # NOTE ABOVE: a negative vocab_size indicates that the classifier weights are present
    # in the checkpoint and should be loaded.
    f.write(header)

    # next write out the embedding weights
    print("writing tok_embeddings...")
    serialize(self.tok_embeddings.weight)

    # now all the layers
    # attention weights
    for i, layer in enumerate(self.layers):
        print(f"writing attention_norm layer {i}...")
        serialize(layer.attention_norm.weight)
    for i, layer in enumerate(self.layers):
        print(f"writing attention.wq layer {i}...")
        serialize(layer.attention.wq.weight)
    for i, layer in enumerate(self.layers):
        print(f"writing attention.wk layer {i}...")
        serialize(layer.attention.wk.weight)
    for i, layer in enumerate(self.layers):
        print(f"writing attention.wv layer {i}...")
        serialize(layer.attention.wv.weight)
    for i, layer in enumerate(self.layers):
        print(f"writing attention.wo layer {i}...")
        serialize(layer.attention.wo.weight)
    # ffn weights
    for i, layer in enumerate(self.layers):
        print(f"writing ffn_norm layer {i}...")
        serialize(layer.ffn_norm.weight)
    for i, layer in enumerate(self.layers):
        print(f"writing feed_forward.w1 layer {i}...")
        serialize(layer.feed_forward.w1.weight)
    for i, layer in enumerate(self.layers):
        print(f"writing feed_forward.w2 layer {i}...")
        serialize(layer.feed_forward.w2.weight)
    for i, layer in enumerate(self.layers):
        print(f"writing feed_forward.w3 layer {i}...")
        serialize(layer.feed_forward.w3.weight)
    # final rmsnorm
    print("writing final rmsnorm, classifier and freq_cis...")
    serialize(self.norm.weight)
    # freqs_cis
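    # (real/imag each have shape (max_seq_len, head_size/2), matching what run.c's checkpoint_init_weights expects)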
    serialize(self.freqs_cis.real[:p.max_seq_len])
    serialize(self.freqs_cis.imag[:p.max_seq_len])
    # finally write the output weights
    serialize(self.output.weight)

    # done writing, close the file
    f.close()
    print(f"wrote {filepath}")
# -----------------------------------------------------------------------------

# init Llama as normal
generator = Llama.build(
    ckpt_dir="llama-2-7b",
    tokenizer_path="tokenizer.model",
    max_seq_len=4096,
    max_batch_size=1,
)
export(generator.model, "llama2_7b.bin")
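
As an aside, here is a minimal sketch for sanity-checking an exported file by reading back the 7-int header that `export()` above writes (this snippet is illustrative, not part of this commit; the file name assumes the run above):

```python
import struct

# read the 7 int32 header fields in the order export() packs them
with open("llama2_7b.bin", "rb") as f:
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
        struct.unpack('iiiiiii', f.read(7 * 4))

# a negative vocab_size signals that separate classifier weights follow
shared_weights = vocab_size > 0
print(dim, hidden_dim, n_layers, n_heads, n_kv_heads, abs(vocab_size), max_seq_len, shared_weights)
```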
run.c: 11 additions, 3 deletions
@@ -50,6 +50,8 @@ typedef struct {
// freq_cis for RoPE relative positional embeddings
float* freq_cis_real; // (seq_len, dim/2)
float* freq_cis_imag; // (seq_len, dim/2)
+ // (optional) classifier weights for the logits, on the last layer
+ float* wcls;
} TransformerWeights;

typedef struct {
@@ -110,7 +112,7 @@ void free_run_state(RunState* s) {
// ----------------------------------------------------------------------------
// initialization: read from checkpoint

- void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
+ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
float* ptr = f;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
@@ -138,6 +140,8 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f) {
int head_size = p->dim / p->n_heads;
ptr += p->seq_len * head_size / 2;
w->freq_cis_imag = ptr;
+ ptr += p->seq_len * head_size / 2;
+ w->wcls = shared_weights ? w->token_embedding_table : ptr;
}

// ----------------------------------------------------------------------------
@@ -319,7 +323,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
rmsnorm(x, x, w->rms_final_weight, dim);

// classifier into logits
- matmul(s->logits, x, w->token_embedding_table, p->dim, p->vocab_size);
+ matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
}

int sample(float* probabilities, int n) {
@@ -395,6 +399,9 @@ int main(int argc, char *argv[]) {
}
// read in the config header
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
+ // a negative vocab size is a hacky way of signaling unshared weights. bit yikes.
+ int shared_weights = config.vocab_size > 0 ? 1 : 0;
+ config.vocab_size = abs(config.vocab_size);
// figure out the file size
fseek(file, 0, SEEK_END); // move file pointer to end of file
file_size = ftell(file); // get the file size, in bytes
@@ -404,7 +411,8 @@
if (fd == -1) { printf("open failed!\n"); return 1; }
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
- checkpoint_init_weights(&weights, &config, data + sizeof(Config)/sizeof(float));
+ float* weights_ptr = data + sizeof(Config)/sizeof(float);
+ checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
}
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
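
For reference, the 7 ints packed by `struct.pack('iiiiiii', ...)` in the exporter land in run.c's `Config` struct, which is not shown in this diff; reconstructed from the fields above, it looks like:

```c
typedef struct {
    int dim;        // transformer dimension
    int hidden_dim; // hidden dimension of the ffn layers
    int n_layers;   // number of transformer layers
    int n_heads;    // number of attention query heads
    int n_kv_heads; // number of key/value heads
    int vocab_size; // vocabulary size (negative on disk signals unshared classifier weights)
    int seq_len;    // max sequence length
} Config;
```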
