-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_kv_q80.py
58 lines (52 loc) · 2.08 KB
/
run_kv_q80.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
'''
Driver script for the OpenELM q8_0 KV-cache C inference library.

Build the shared library first:
    gcc --shared -fPIC -o openelm_kv_q80.so openelm_kv_q80.c -lm -fopenmp
'''
import time
from transformers import AutoTokenizer
from ctypes import CDLL
from ctypes import c_int, POINTER
# Load the compiled inference library from the current working directory;
# raises OSError at import time if the .so has not been built yet.
openelmlib = CDLL("./openelm_kv_q80.so")
def init(batch: int, max_seq_len: int):
    """Initialize the C-side inference state for the given batch size and context length."""
    c_batch = c_int(batch)
    c_max_seq = c_int(max_seq_len)
    openelmlib.c_init(c_batch, c_max_seq)
def openelm_forward(token, batch, seq_len, pos) -> list:
    """Run one forward pass through the C model and return sampled token ids.

    Args:
        token: flat list of input token ids for this step.
        batch: batch size (one sampled token is returned per batch element).
        seq_len: sequence length of this call.
        pos: current position in the KV cache.

    Returns:
        List of ``batch`` sampled next-token ids as plain Python ints.
    """
    # The C function returns a pointer to `batch` ints; restype must be set
    # per call because the array length depends on the batch size.
    openelmlib.c_openelm_forward.restype = POINTER(c_int * batch)
    sample = openelmlib.c_openelm_forward(
        c_int(batch), c_int(seq_len), (c_int * len(token))(*token), c_int(pos)
    )
    # Copy out of the C-owned buffer into a plain Python list.
    return [int(i) for i in sample.contents]
def generate(batch: int, data: list, steps: int):
    """Delegate batched generation from a token prompt to the C library."""
    prompt_len = len(data)
    c_prompt = (c_int * prompt_len)(*data)
    openelmlib.c_generate(c_int(batch), c_int(prompt_len), c_prompt, c_int(steps))
if __name__ == '__main__':
    import os  # local import: only needed when run as a script

    tokenizer_name = "meta-llama/Llama-2-7b-hf"
    # SECURITY: never hard-code an access token in source control. Read the
    # Hugging Face token from the environment instead (export HF_TOKEN=...).
    hf_access_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=hf_access_token, use_fast=False)
    tokenized_prompt = tokenizer.encode("Once upon a time there was a man named John")
    print([tokenized_prompt, tokenized_prompt])
    batch = 2
    seq_len = len(tokenized_prompt)
    tokenized_prompt_c = [tokenized_prompt[0], tokenized_prompt[0]]
    print(tokenized_prompt_c)
    max_seq_len = seq_len + 5  # generate 5 tokens beyond the prompt
    pos = 0
    init(batch, max_seq_len)
    output = []
    next_tokens = []  # tokens sampled on the previous step (renamed: `next` shadowed the builtin)
    begin = time.time()
    while pos < max_seq_len:
        if pos < seq_len:
            # Still consuming the prompt: feed the same prompt token to both
            # batch elements, one position at a time.
            tokenized_prompt_c = [tokenized_prompt[pos], tokenized_prompt[pos]]
        else:
            # Past the prompt: feed back the tokens sampled on the last step.
            tokenized_prompt_c = next_tokens
        next_tokens = openelm_forward(tokenized_prompt_c, batch, 1, pos)
        print(f"pos:{pos} {tokenized_prompt_c}")
        output.append(tokenized_prompt_c[0])
        pos += 1
    end = time.time()
    print(f"total time is: {end - begin}s, tokens: {max_seq_len} {max_seq_len / (end - begin)} tokens/s")
    # Include the final sampled token, which the loop above never appended.
    output.append(next_tokens[0])
    output_text = tokenizer.decode(
        output,
        skip_special_tokens=True
    )
    print(output_text)