Implement two types of RoPE & add support for TinyLlama-1.1B-Chat-v0.2 #408

Closed
wants to merge 1 commit into from
38 changes: 38 additions & 0 deletions README.md
@@ -62,6 +62,44 @@ There is also an even better 110M param model available, see [models](#models).

Quick note on sampling, the recommendation for ~best results is to sample with `-t 1.0 -p 0.9`, i.e. temperature 1.0 (default) but also top-p sampling at 0.9 (default). Intuitively, top-p ensures that tokens with tiny probabilities do not get sampled, so we can't get "unlucky" during sampling, and we are less likely to go "off the rails" afterwards. More generally, to control the diversity of samples use either the temperature (i.e. vary `-t` between 0 and 1 and keep top-p off with `-p 0`) or the top-p value (i.e. vary `-p` between 0 and 1 and keep `-t 1`), but not both. Nice explainers on LLM sampling strategies include [this](https://peterchng.com/blog/2023/05/02/token-selection-strategies-top-k-top-p-and-temperature/), [this](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p) or [this](https://huggingface.co/blog/how-to-generate).
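
To make the top-p idea concrete, here is a tiny self-contained C sketch of temperature scaling followed by top-p (nucleus) truncation. It only illustrates the idea — it is not the sampler in run.c, and the helper names `softmax_temp` and `sample_top_p` are invented for this example.

```c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

// softmax with temperature: t -> 0 approaches greedy argmax, t = 1 uses the raw logits
void softmax_temp(float* p, const float* logits, int n, float temperature) {
    float max = logits[0];
    for (int i = 1; i < n; i++) if (logits[i] > max) max = logits[i];
    float sum = 0.0f;
    for (int i = 0; i < n; i++) { p[i] = expf((logits[i] - max) / temperature); sum += p[i]; }
    for (int i = 0; i < n; i++) p[i] /= sum;
}

// top-p: keep only the smallest set of tokens whose cumulative probability exceeds topp,
// then sample from that truncated, renormalized set ("coin" is a uniform random in [0,1))
int sample_top_p(const float* p, int n, float topp, float coin) {
    int* idx = malloc(n * sizeof(int));
    for (int i = 0; i < n; i++) idx[i] = i;
    // sort indices by probability, descending (O(n^2), fine for a sketch)
    for (int i = 0; i < n; i++)
        for (int j = i + 1; j < n; j++)
            if (p[idx[j]] > p[idx[i]]) { int t = idx[i]; idx[i] = idx[j]; idx[j] = t; }
    float cum = 0.0f; int last = n - 1;
    for (int i = 0; i < n; i++) { cum += p[idx[i]]; if (cum > topp) { last = i; break; } }
    float r = coin * cum, acc = 0.0f; int pick = idx[last];
    for (int i = 0; i <= last; i++) { acc += p[idx[i]]; if (r < acc) { pick = idx[i]; break; } }
    free(idx);
    return pick;
}

int main(void) {
    float logits[5] = {2.0f, 1.0f, 0.5f, -1.0f, -3.0f};
    float p[5];
    softmax_temp(p, logits, 5, 1.0f);          // -t 1.0
    int tok = sample_top_p(p, 5, 0.9f, 0.42f); // -p 0.9, fixed "coin" instead of a RNG draw
    printf("sampled token id: %d\n", tok);
    return 0;
}
```

On this toy distribution `-p 0.9` keeps only the three most likely tokens, so the two tail tokens can never be drawn regardless of the random coin.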

## Tiny Llama 1.1B model
[TinyLlama](https://github.com/jzhang38/TinyLlama) is a 1.1B-parameter Llama model trained on 3 trillion tokens. Its compact size makes it a good fit for applications with tight compute and memory budgets, which is why we pick it as the first billion-parameter model to support.

Let's download the model and the tokenizer from Hugging Face: https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin.

```bash
wget https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin/resolve/main/tok_tl-chat.bin
wget https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin/resolve/main/tl-chat.bin
```

Run the model.
```bash
./run tl-chat.bin -z tok_tl-chat.bin \
-n 512 -t 0.0 -s 100 \
-i "<|im_start|>user\nExplain huggingface.<|im_end|>\n<|im_start|>assistant\n"
```

Sample output:
```
<|im_start|>user
Explain huggingface.<|im_end|>
<|im_start|>assistant
Huggingface is a software platform that provides tools and resources for building and hosting large-scale machine learning models and datasets. It is designed to make it easier and faster to build, train, and deploy models for a wide range of applications, including natural language processing, computer vision, and generative models.

Huggingface provides a set of tools and resources, including:

1. A framework for building and hosting large-scale machine learning models and datasets.
2. A set of pre-trained models and datasets that can be used with your Huggingface model.
3. A set of tools for data preparation, cleaning, and formatting.
4. A set of tools for model training, evaluation, and inference.
5. A set of metrics and tools for measuring the performance of your models.

Huggingface also provides a library of pre-built components and utilities that can be used with your Huggingface model. These components and utilities include:

1. A library of pre-trained
achieved tok/s: 4.200850
```


## Meta's Llama 2 models

As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format.
97 changes: 66 additions & 31 deletions run.c
@@ -72,6 +72,7 @@ typedef struct {
int fd; // file descriptor for memory mapping
float* data; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
    void (*rope)(Config *, RunState *, int, int); // which RoPE rotation to apply (set in build_transformer)
} Transformer;

void malloc_run_state(RunState* s, Config* p) {
@@ -83,16 +84,13 @@ void malloc_run_state(RunState* s, Config* p) {
s->hb = calloc(p->hidden_dim, sizeof(float));
s->hb2 = calloc(p->hidden_dim, sizeof(float));
s->q = calloc(p->dim, sizeof(float));
s->k = calloc(kv_dim, sizeof(float));
s->v = calloc(kv_dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
fprintf(stderr, "malloc failed!\n");
exit(EXIT_FAILURE);
}
@@ -105,8 +103,6 @@ void free_run_state(RunState* s) {
free(s->hb);
free(s->hb2);
free(s->q);
free(s->k);
free(s->v);
free(s->att);
free(s->logits);
free(s->key_cache);
@@ -166,11 +162,60 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
memory_map_weights(weights, config, weights_ptr, shared_weights);
}

// falcon / GPT-NeoX-style RoPE: rotate pairs split across the two halves of each head, (j, j + head_size/2)
void rope_falcon(Config *p, RunState *s, int head_size, int pos) {
for (int i = 0; i < p->n_heads; i++) {
for (int j = 0; j < head_size / 2; j++) {
float freq = 1.0f / powf(10000.0f, 2.0f * (float)j / (float)head_size);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
float q0 = s->q[i * head_size + j];
float q1 = s->q[i * head_size + j + head_size / 2];
s->q[i * head_size + j] = q0 * fcr - q1 * fci;
s->q[i * head_size + j + head_size / 2] = q0 * fci + q1 * fcr;
if (i < p->n_kv_heads) {
float k0 = s->k[i * head_size + j];
float k1 = s->k[i * head_size + j + head_size / 2];
s->k[i * head_size + j] = k0 * fcr - k1 * fci;
s->k[i * head_size + j + head_size / 2] = k0 * fci + k1 * fcr;
}
}
}
}

// llama-style RoPE (as in the original llama2.c): rotate adjacent interleaved pairs within each head, (j, j+1)
void rope_llama(Config *p, RunState *s, int head_size, int pos) {
for (int i = 0; i < p->n_heads; i++) {
for (int j = 0; j < head_size; j += 2) {
float freq = 1.0f / powf(10000.0f, (float)j / (float)head_size);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
float q0 = s->q[i * head_size + j];
float q1 = s->q[i * head_size + j + 1];
s->q[i * head_size + j] = q0 * fcr - q1 * fci;
s->q[i * head_size + j + 1] = q0 * fci + q1 * fcr;
if (i < p->n_kv_heads) {
float k0 = s->k[i * head_size + j];
float k1 = s->k[i * head_size + j + 1];
s->k[i * head_size + j] = k0 * fcr - k1 * fci;
s->k[i * head_size + j + 1] = k0 * fci + k1 * fcr;
}
}
}
}
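
// Example: with head_size = 8, the two variants rotate different element pairs.
//   rope_falcon (neox/hf-style, split halves):    (0,4) (1,5) (2,6) (3,7)
//   rope_llama  (original llama2.c, interleaved): (0,1) (2,3) (4,5) (6,7)
// Both draw from the same frequency set 10000^(-2j/head_size); only the pairing differs,
// so a checkpoint must be rotated with the same scheme its weights were exported for.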

void build_transformer(Transformer *t, char* checkpoint_path) {
// read in the Config and the Weights from the checkpoint
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
// allocate the RunState buffers
malloc_run_state(&t->state, &t->config);
    // select the RoPE variant: the stories* checkpoints use llama-style rotation, other checkpoints (e.g. tl-chat.bin) use falcon-style
if (strstr(checkpoint_path, "stories") != NULL)
t->rope = rope_llama;
else
t->rope = rope_falcon;
}

void free_transformer(Transformer* t) {
@@ -239,6 +284,7 @@ float* forward(Transformer* transformer, int token, int pos) {
Config* p = &transformer->config;
TransformerWeights* w = &transformer->weights;
RunState* s = &transformer->state;
    void (*rope)(Config *, RunState *, int, int) = transformer->rope; // RoPE variant chosen at load time
float *x = s->x;
int dim = p->dim;
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
@@ -256,34 +302,18 @@ float* forward(Transformer* transformer, int token, int pos) {
// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

// key and value point to the kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
s->k = s->key_cache + loff + pos*kv_dim;
s->v = s->value_cache + loff + pos*kv_dim;

// qkv matmuls for this position
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);

// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
float v0 = vec[i];
float v1 = vec[i+1];
vec[i] = v0 * fcr - v1 * fci;
vec[i+1] = v0 * fci + v1 * fcr;
}
}

// save key,value at this time step (pos) to our kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
        // RoPE relative positional encoding: rotate q and k in each head, using the variant selected in build_transformer
        (*rope)(p, s, head_size, pos);

// multihead attention. iterate over all heads
int h;
@@ -451,7 +481,12 @@ void safe_printf(char *piece) {

int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
TokenIndex tok = { .str = str }; // acts as the key to search for
    // hack: map the literal two-character sequence "\n" (as typed in a -i prompt, which the shell
    // does not expand) to the newline byte token <0x0A>
    char *input = "<0x0A>";
    if (strcmp(str, "\\n") != 0) {
        input = str;
    }
    TokenIndex tok = { .str = input }; // acts as the key to search for
TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
return res != NULL ? res->id : -1;
}
@@ -977,4 +1012,4 @@ int main(int argc, char *argv[]) {
free_transformer(&transformer);
return 0;
}
#endif
#endif