Skip to content

Commit

Permalink
Add Emscripten interface to llama2.c
Browse files Browse the repository at this point in the history
  • Loading branch information
gohai committed Aug 2, 2023
1 parent ab39930 commit ec3bac6
Show file tree
Hide file tree
Showing 10 changed files with 1,098 additions and 195 deletions.
39 changes: 39 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,45 @@ rungnu:
runompgnu:
$(CC) -Ofast -fopenmp -std=gnu11 run.c -lm -o run

# includes model & tokenizer
.PHONY: emscripten
emscripten: run.c
emcc -O3 run.c \
-o web/src/llama2.js \
-s EXPORTED_FUNCTIONS='["_main", "_main_loop", "_malloc", "_free", "_register_callback", "_set_parameters", "_generate", "_manual_start", "_manual_next", "_get_vocab", "_get_vocab_size"]' \
-s EXPORTED_RUNTIME_METHODS='["ccall", "addFunction", "UTF8ToString"]' \
-s ALLOW_MEMORY_GROWTH=1 \
-s ALLOW_TABLE_GROWTH=1 \
-s MODULARIZE \
-s EXPORT_NAME='Llama2' \
--preload-file model.bin \
--preload-file tokenizer.bin

# includes tokenizer only, model loaded from URL
.PHONY: emscripten-small
emscripten-small: run.c
emcc -O3 run.c \
-o web/src/llama2.js \
-s EXPORTED_FUNCTIONS='["_main", "_main_loop", "_malloc", "_free", "_register_callback", "_set_parameters", "_generate", "_manual_start", "_manual_next", "_get_vocab", "_get_vocab_size"]' \
-s EXPORTED_RUNTIME_METHODS='["ccall", "addFunction", "UTF8ToString"]' \
-s ALLOW_MEMORY_GROWTH=1 \
-s ALLOW_TABLE_GROWTH=1 \
-s MODULARIZE \
-s EXPORT_NAME='Llama2' \
--preload-file tokenizer.bin

# model & tokenizer loaded from URL
.PHONY: emscripten-min
emscripten-min: run.c
emcc -O3 run.c \
-o web/src/llama2.js \
-s EXPORTED_FUNCTIONS='["_main", "_main_loop", "_malloc", "_free", "_register_callback", "_set_parameters", "_generate", "_manual_start", "_manual_next", "_get_vocab", "_get_vocab_size"]' \
-s EXPORTED_RUNTIME_METHODS='["ccall", "addFunction", "UTF8ToString"]' \
-s ALLOW_MEMORY_GROWTH=1 \
-s ALLOW_TABLE_GROWTH=1 \
-s MODULARIZE \
-s EXPORT_NAME='Llama2'

.PHONY: clean
clean:
rm -f run
258 changes: 105 additions & 153 deletions README.md

Large diffs are not rendered by default.

219 changes: 177 additions & 42 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ Then run with:
#include <unistd.h>
#include <sys/mman.h>
#endif
#if defined __EMSCRIPTEN__
#include <emscripten.h>
#endif

// ----------------------------------------------------------------------------
// Transformer and RunState structs, and related memory management

Expand Down Expand Up @@ -448,12 +452,72 @@ int argmax(float* v, int n) {
}
// ----------------------------------------------------------------------------


float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
Config config;
TransformerWeights weights;
char** vocab;
float* vocab_scores;
unsigned int max_token_length;
RunState state;
int *prompt_tokens = NULL;
int num_prompt_tokens = 0;
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence


void* on_token_callback = NULL;


void main_loop(void * dummy) {
#if defined __EMSCRIPTEN__
if (pos >= steps) {
emscripten_pause_main_loop();
return; // pointers might be invalid when this resumes
}
#endif

// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);

if(pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
} else {
// sample the next token
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = argmax(state.logits, config.vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
// we sample from this distribution to get the next token
next = sample(state.logits, config.vocab_size);
}
}

if (on_token_callback) {
((void (*)(char*, int, float, int))on_token_callback)(vocab[next], next, state.logits[next], (pos+1 >= steps));
}

// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
printf("%s", token_str);
fflush(stdout);

// advance forward
token = next;
pos++;
}

int main(int argc, char *argv[]) {

// poor man's C argparse
char *checkpoint = NULL; // e.g. out/model.bin
float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
char *prompt = NULL; // prompt string

// 'checkpoint' is necessary arg
Expand All @@ -479,8 +543,6 @@ int main(int argc, char *argv[]) {
rng_seed = (unsigned int)time(NULL);

// read in the model.bin file
Config config;
TransformerWeights weights;
int fd = 0; // file descriptor for memory mapping
float* data = NULL; // memory mapped data pointer
long file_size; // size of the checkpoint file in bytes
Expand Down Expand Up @@ -508,9 +570,8 @@ int main(int argc, char *argv[]) {
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }

// read in the tokenizer.bin file
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
unsigned int max_token_length;
vocab = (char**)malloc(config.vocab_size * sizeof(char*));
vocab_scores = (float*)malloc(config.vocab_size * sizeof(float));
{
FILE *file = fopen("tokenizer.bin", "rb");
if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
Expand All @@ -527,57 +588,29 @@ int main(int argc, char *argv[]) {
}

// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);

// process the prompt, if any
int *prompt_tokens = NULL;
int num_prompt_tokens = 0;
if (prompt != NULL) {
prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
}

// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
while (pos < steps) {

// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);

if(pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
} else {
// sample the next token
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = argmax(state.logits, config.vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
// we sample from this distribution to get the next token
next = sample(state.logits, config.vocab_size);
}
}

// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
printf("%s", token_str);
fflush(stdout);
#if defined __EMSCRIPTEN__
pos = steps; // to make the loop pause initially
emscripten_set_main_loop_arg(main_loop, NULL, 0, 1);
#else
while (pos < steps) {
main_loop(NULL);

// advance forward
token = next;
pos++;
// init our timer here because the first iteration is slow due to memmap
if (start == 0) { start = time_in_ms(); }
}
#endif

// report achieved tok/s
long end = time_in_ms();
Expand All @@ -593,3 +626,105 @@ int main(int argc, char *argv[]) {
if (fd != -1) close(fd);
return 0;
}


#if defined __EMSCRIPTEN__

void register_callback(void* _on_token_callback) {
on_token_callback = _on_token_callback;
}

void set_parameters(float _tempature, int _steps) {
temperature = _tempature;
if (_steps <= 0 || _steps > config.seq_len) {
steps = config.seq_len;
} else {
steps = _steps;
}
}

void generate(char* prompt) {
// reset state
free_run_state(&state);
if (prompt_tokens != NULL) {
free(prompt_tokens);
prompt_tokens = NULL;
}
malloc_run_state(&state, &config);

// process prompt
if (prompt != NULL) {
prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
}

token = 1;
pos = 0;

// (re-) start the main loop for generation
emscripten_resume_main_loop();
}

//
// Besides generate(), which will use the main loop to invoke a
// callback function for every token, the manual_ functions below
// let the caller pick the next token synchronously. You'd want
// to use one or the other.
//

char** get_vocab() {
return vocab;
}

int get_vocab_size() {
return config.vocab_size;
}

int manual_start(char* prompt) {
// stop the main loop of any prior generate()
// the manual_ functions aren't using it
emscripten_pause_main_loop();

// reset state
free_run_state(&state);
if (prompt_tokens != NULL) {
free(prompt_tokens);
}
malloc_run_state(&state, &config);

// process prompt
if (prompt != NULL) {
prompt_tokens = (int*)malloc(config.seq_len * sizeof(int));
bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens);
}

token = 1;
pos = 0;

// run the transformer over the prompt
while (pos < num_prompt_tokens) {
transformer(token, pos, &config, &state, &weights);
token = prompt_tokens[pos];
pos++;
}

return token; // return the first token to pass to _next()
}

float* manual_next(int _token) {
token = _token;

transformer(token, pos, &config, &state, &weights);

if (temperature != 0.0f) {
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
softmax(state.logits, config.vocab_size);
}

token = 0;
pos++;

return state.logits;
}

#endif
39 changes: 39 additions & 0 deletions web/dist/basic.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>llama2.c-emscripten example</title>
<script src="llama2.js"></script>
</head>
<body>
<input id="prompt" placeholder="Prompt"></input> <input type="button" id="generate" value="Generate"></input>
<div id="output"></div>

<script>
(async () => {

const llama2 = await new LLAMA2();

document.querySelector('#generate').addEventListener('click', async function() {
const prompt = document.querySelector('#prompt').value;

const options = {
temperature: 0.9,
};

const out = await llama2.generate(prompt);
document.querySelector('#output').innerHTML = out;
});

llama2.on('token', () => {
console.log('token', llama2.tokens[llama2.tokens.length-1]);
});

llama2.on('word', (word) => {
console.log('word', word);
});

})();
</script>
</body>
</html>
Loading

0 comments on commit ec3bac6

Please sign in to comment.