
Commit eb4abfd

Merge remote-tracking branch 'origin/main' into upstream_merge_25_01_13

2 parents 16f8680 + 113274a

File tree

1 file changed: +250 −0 lines changed

benchmarks/P3L_mling.py
@@ -0,0 +1,250 @@
#!/usr/bin/env python3
"""
*MULTILINGUAL* Patch-Perplexity (P3L)

This is a script that produces a realistic PPL measurement
for the quantized KV cache system by processing a sequence of
non-overlapping patches of the reference text. Generation of the
consecutive symbols in each patch is governed (forced)
by the reference text.

The initial context size for the system is set by the parameter
"--context-size".

The number of output symbols to generate starting from a given
context is set by the parameter "--sample-size". This variable also
defines the size of the individual patch.

For an N-token reference text split into M patches with the
system's context size C, the full measurement takes
M*preload + (N-C)*generation time.

Quick correctness validation tips:

Running the Llama-2-7b model
(
./vllm/examples/P3L_mling.py
--model=meta-llama/Llama-2-7b-chat-hf
--context-size=1024
--sample-size=512
)

should result in PPL ~ 8.42927

Running the Llama-2-7b model
(
./vllm/examples/P3L_mling.py
--model=meta-llama/Llama-2-7b-chat-hf
--context-size=1024
--sample-size=512
--patch-size=1
--lang-script="cmn_Hant"
)

should result in PPL ~ 2.67962

The multi-linguality is implemented through the additional
key "--lang-script", which defaults to English in Latin
script ("eng_Latn").

Please refer to

https://confluence.amd.com/display/MLSE/Multi-Lingual+P3L+Test

for the complete set of possible language-script choices.

"""

import argparse
import dataclasses
import datetime
import json
import math
import os

import pandas
from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import init_logger

logger = init_logger(__name__)


def get_wikitext2_text(tokenizer):
    hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki',
                    repo_type="dataset",
                    filename='wiki.test.raw',
                    local_dir='./')
    with open('./wiki.test.raw') as f:
        test_text = "\n".join(line.strip() for line in f)
        test_enc = tokenizer(test_text)

    os.remove('./wiki.test.raw')

    return test_enc, test_text


def get_flores_plus_text(tokenizer, lng_scrpt):
    hf_hub_download(repo_id='alexei-v-ivanov-amd/flores_plus',
                    repo_type="dataset",
                    filename=lng_scrpt + '.parquet',
                    local_dir='./')

    df = pandas.read_parquet('./' + lng_scrpt + '.parquet')
    test_text = "\n\n".join(line.strip() for line in df['text'])
    test_enc = tokenizer(test_text)

    os.remove('./' + lng_scrpt + '.parquet')

    return test_enc, test_text
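
# A minimal usage sketch for the loaders above (illustrative only: it needs
# network access to the HF Hub, and the model name is just an example):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#     enc, text = get_flores_plus_text(tok, "eng_Latn")
#     len(enc['input_ids'])  # number of reference tokens in the sample
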
def vllm_init(args):
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    # Greedy decoding (temperature=0.0) with teacher forcing: the
    # `ppl_measurement` / `future_context` fields are the PPL-mode
    # extensions of SamplingParams provided by this vLLM build.
    sampling_params = SamplingParams(n=1,
                                     temperature=0.0,
                                     top_p=1,
                                     ignore_eos=True,
                                     ppl_measurement=True,
                                     future_context=[],
                                     prompt_logprobs=1,
                                     logprobs=1,
                                     presence_penalty=0.0)

    return llm, sampling_params


def vllm_predict(CONT, llm, sampl_par):
    result = llm.generate(prompt_token_ids=CONT, sampling_params=sampl_par)
    return result


def main(args: argparse.Namespace):

    MESSAGE = f"Initialising @ {datetime.datetime.now()}"
    logger.info(MESSAGE)
    print(MESSAGE)
    my_ppl = 0.0

    logger.info("Initializing the engine.")
    my_llm, my_sampl_par = vllm_init(args)
    my_tokenizer = my_llm.llm_engine.tokenizer.tokenizer
    logger.info(my_sampl_par)
    logger.info("Initialized the engine.")

    my_n_samples = args.sample_size
    my_lang_script = args.lang_script

    if (args.context_size + my_n_samples) > \
            my_llm.llm_engine.model_config.max_model_len:
        MESSAGE = ("Error! The total number of tokens:\n"
                   f" prefix ({args.context_size}) + "
                   f"to be generated ({my_n_samples})"
                   " can't be bigger than the model limit "
                   f"({my_llm.llm_engine.model_config.max_model_len}).")
        logger.info(MESSAGE)
        print(MESSAGE)
        return

    my_test_enc, my_test_text = get_flores_plus_text(my_tokenizer,
                                                     my_lang_script)

    logger.info("Loaded the test data.")

    # Cover the reference text with non-overlapping generation patches,
    # unless the user overrode the patch count via --patch-size.
    my_n_patches = math.ceil(
        (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples)
    if args.patch_size is not None:
        my_n_patches = args.patch_size

    num_tokens_generated = 0
    starting_time = datetime.datetime.now()
    MESSAGE = (f"Starting generation @ {starting_time}\n"
               " Have the test sample of "
               f"{len(my_test_enc['input_ids'])} tokens,"
               f" will try to process {my_n_patches} patch(es),"
               f" generating {my_n_samples} tokens in each patch"
               f" from the initial context of {args.context_size} tokens.")

    logger.info(MESSAGE)
    print(MESSAGE)
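
    # Sliding-window layout of the loop below (illustrative numbers, not from
    # the original): with context_size=4, sample_size=2 and 10 reference
    # tokens, patch c is forced with context ids[2*c : 2*c + 4] and
    # continuation ids[2*c + 4 : min(2*c + 6, 10)], i.e.
    # [0:4]->[4:6], [2:6]->[6:8], [4:8]->[8:10].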
    for c in range(my_n_patches):
        CONTEXT = []
        my_sampl_par.future_context = []
        CONTEXT.append(
            my_test_enc['input_ids'][c * my_n_samples:c * my_n_samples +
                                     args.context_size])
        upper_boundary = min((c + 1) * my_n_samples + args.context_size,
                             len(my_test_enc['input_ids']))
        my_sampl_par.future_context.append(
            my_test_enc['input_ids'][c * my_n_samples +
                                     args.context_size:upper_boundary])
        my_sampl_par.max_tokens = len(my_sampl_par.future_context[0])
        my_sampl_par.cntr = c
        LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
        num_tokens_generated += len(LOGPROBS[0].outputs[0].token_ids)
        if (num_tokens_generated < my_n_samples):
            MESSAGE = (f"Warning: The number of generated tokens is"
                       f" less than requested ({num_tokens_generated}"
                       f" < {my_n_samples}).")
            logger.info(MESSAGE)
            print(MESSAGE)
        # cumulative_logprob sums the log-probabilities of the forced tokens
        # (each <= 0), so subtracting it accumulates the cross-entropy in nats.
        my_ppl -= LOGPROBS[0].outputs[0].cumulative_logprob
        MESSAGE = (f"Iteration {c+1} of {my_n_patches} Intermediate"
                   " Estimates:\n"
                   f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n"
                   f"\tPerplexity_intermediate="
                   f"{math.exp(my_ppl/num_tokens_generated)}")

        logger.info(MESSAGE)
        print(MESSAGE)
    ending_time = datetime.datetime.now()
    MESSAGE = (f"Done @ {ending_time} after processing for"
               f" {ending_time-starting_time};"
               f" generated {num_tokens_generated} tokens.")

    logger.info(MESSAGE)
    print(MESSAGE)

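    # Final figures: with T = num_tokens_generated forced tokens, my_ppl holds
    # -sum_t log p(token_t | preceding context), so the reported perplexity is
    # PPL = exp(my_ppl / T).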
MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy=" \
213+
f"{my_ppl/num_tokens_generated}" \
214+
f"\n\tPPL={math.exp(my_ppl/num_tokens_generated)}")
215+
216+
if args.output_json:
217+
results = {
218+
"integral_cross_entropy": my_ppl,
219+
"average_cross_entropy": my_ppl / num_tokens_generated,
220+
"ppl": math.exp(my_ppl / num_tokens_generated),
221+
}
222+
with open(args.output_json, "w") as f:
223+
json.dump(results, f, indent=4)
224+
225+
logger.info(MESSAGE)
226+
print(MESSAGE)
227+
return
228+
229+
230+
if __name__ == "__main__":
231+
parser = argparse.ArgumentParser(
232+
description='Measure the PPPL (P3L) score of a given model.')
233+
parser.add_argument(
234+
'--data',
235+
type=str,
236+
default='./wikitext/wikitext-2-v1/test-00000-of-00001.parquet')
237+
parser.add_argument('--context-size', type=int, default=4096)
238+
parser.add_argument('--sample-size', type=int, default=512)
239+
parser.add_argument('--patch-size', type=int, default=None)
240+
parser.add_argument('--lang-script', type=str, default="eng_Latn")
241+
parser.add_argument(
242+
'--output-json',
243+
type=str,
244+
default=None,
245+
help='Path to save the latency results in JSON format.')
246+
247+
parser = EngineArgs.add_cli_args(parser)
248+
args = parser.parse_args()
249+
250+
main(args)
