#!/usr/bin/env python3
"""
*MULTILINGUAL* Patch-Perplexity (P3L)

This script produces a realistic PPL measurement
for the quantized KV cache system by processing a sequence of
non-overlapping patches of the reference text. Generation of the
consecutive symbols in each patch is governed (forced)
by the reference text.

The initial context size for the system is set by the parameter
"--context-size".

The number of output symbols to generate starting from a given
context is set by the parameter "--sample-size". This variable also
defines the size of the individual patch.

For an N-token reference text that is split into M patches with the
system's context size C, processing takes approximately
M*preload_time + (N-C)*generation_time.

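For example, a 10240-token reference text with a context size of 1024
and a sample size of 512 is split into
M = ceil((10240 - 1024 - 1)/512) = 18 patches, matching the
patch-count computation in main() below.
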
Quick correctness validation tips:

Running the meta-llama/Llama-2-7b-chat-hf model
(
    ./vllm/examples/P3L_mling.py
    --model=meta-llama/Llama-2-7b-chat-hf
    --context-size=1024
    --sample-size=512
)
should result in PPL ~ 8.42927

Running the meta-llama/Llama-2-7b-chat-hf model
(
    ./vllm/examples/P3L_mling.py
    --model=meta-llama/Llama-2-7b-chat-hf
    --context-size=1024
    --sample-size=512
    --patch-size=1
    --lang-script="cmn_Hant"
)
should result in PPL ~ 2.67962

The multi-linguality is implemented through the additional
key "--lang-script", which defaults to English in Latin
script ("eng_Latn").

Please refer to

https://confluence.amd.com/display/MLSE/Multi-Lingual+P3L+Test

for the complete set of possible language-script choices.

"""

import argparse
import dataclasses
import datetime
import json
import math
import os

import pandas
from huggingface_hub import hf_hub_download

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import init_logger

logger = init_logger(__name__)


def get_wikitext2_text(tokenizer):
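    """Download the WikiText-2 raw test split from the Hugging Face Hub
    dataset 'alexei-v-ivanov-amd/wiki', tokenize the newline-joined text,
    delete the local copy, and return (token_encoding, raw_text)."""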
    hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki',
                    repo_type="dataset",
                    filename='wiki.test.raw',
                    local_dir='./')
    with open('./wiki.test.raw') as f:
        test_text = "\n".join(line.strip() for line in f)
        test_enc = tokenizer(test_text)

    os.remove('./wiki.test.raw')

    return test_enc, test_text


def get_flores_plus_text(tokenizer, lng_scrpt):
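    """Download the FLORES+ sentences for the requested language/script code
    (e.g. 'eng_Latn') from the Hugging Face Hub dataset
    'alexei-v-ivanov-amd/flores_plus', join and tokenize them, delete the
    local parquet file, and return (token_encoding, raw_text)."""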
    hf_hub_download(repo_id='alexei-v-ivanov-amd/flores_plus',
                    repo_type="dataset",
                    filename=lng_scrpt + '.parquet',
                    local_dir='./')

    df = pandas.read_parquet('./' + lng_scrpt + '.parquet')
    test_text = "\n\n".join(line.strip() for line in df['text'])
    test_enc = tokenizer(test_text)

    os.remove('./' + lng_scrpt + '.parquet')

    return test_enc, test_text


def vllm_init(args):
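    """Build the vLLM engine from the CLI arguments together with the greedy
    (temperature=0.0) sampling parameters used for forced-decoding scoring.

    Note: 'ppl_measurement' and 'future_context' (and the 'cntr' attribute
    set later in main()) are assumed to be extensions of SamplingParams
    provided by the accompanying vLLM build for this PPL workflow; they are
    not part of the standard SamplingParams interface.
    """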
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    sampling_params = SamplingParams(n=1,
                                     temperature=0.0,
                                     top_p=1,
                                     ignore_eos=True,
                                     ppl_measurement=True,
                                     future_context=[],
                                     prompt_logprobs=1,
                                     logprobs=1,
                                     presence_penalty=0.0)

    return llm, sampling_params


def vllm_predict(CONT, llm, sampl_par):
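    """Run forced generation for the given batch of prompt token lists and
    return the resulting RequestOutput objects."""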
    result = llm.generate(prompt_token_ids=CONT, sampling_params=sampl_par)
    return result


def main(args: argparse.Namespace):

    MESSAGE = f"Initialising @ {datetime.datetime.now()}"
    logger.info(MESSAGE)
    print(MESSAGE)
    my_ppl = 0.0

    logger.info("Initializing the engine.")
    my_llm, my_sampl_par = vllm_init(args)
    my_tokenizer = my_llm.llm_engine.tokenizer.tokenizer
    logger.info(my_sampl_par)
    logger.info("Initialized the engine.")

    my_n_samples = args.sample_size
    my_lang_script = args.lang_script

    if (args.context_size + my_n_samples) > \
            my_llm.llm_engine.model_config.max_model_len:
        MESSAGE = ("Error! The total number of tokens:\n"
                   f" prefix ({args.context_size}) + "
                   f"to be generated ({my_n_samples})"
                   " can't be bigger than the model limit "
                   f"({my_llm.llm_engine.model_config.max_model_len}).")
        logger.info(MESSAGE)
        print(MESSAGE)
        return

    my_test_enc, my_test_text = get_flores_plus_text(my_tokenizer,
                                                     my_lang_script)

    logger.info("Loaded the test data.")

    my_n_patches = math.ceil(
        (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples)
    if args.patch_size is not None:
        my_n_patches = args.patch_size

    num_tokens_generated = 0
    starting_time = datetime.datetime.now()
    MESSAGE = (f"Starting generation @ {starting_time}\n"
               " Have the test sample of "
               f"{len(my_test_enc['input_ids'])} tokens;"
               f" will try to process {my_n_patches} patch(es),"
               f" generating {my_n_samples} tokens in each patch"
               f" from the initial context of {args.context_size} tokens.")

    logger.info(MESSAGE)
    print(MESSAGE)
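    # Each patch c preloads tokens [c*S, c*S + C) as the context and forces
    # generation of the following tokens [c*S + C, (c+1)*S + C), clipped to
    # the end of the text, where S = sample size and C = context size.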
    for c in range(my_n_patches):
        CONTEXT = []
        my_sampl_par.future_context = []
        CONTEXT.append(
            my_test_enc['input_ids'][c * my_n_samples:c * my_n_samples +
                                     args.context_size])
        upper_boundary = min((c + 1) * my_n_samples + args.context_size,
                             len(my_test_enc['input_ids']))
        my_sampl_par.future_context.append(
            my_test_enc['input_ids'][c * my_n_samples +
                                     args.context_size:upper_boundary])
        my_sampl_par.max_tokens = len(my_sampl_par.future_context[0])
        my_sampl_par.cntr = c
        LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
        num_tokens_generated += len(LOGPROBS[0].outputs[0].token_ids)
        if num_tokens_generated < my_n_samples:
            MESSAGE = ("Warning: The number of generated tokens is"
                       f" less than requested ({num_tokens_generated}"
                       f" < {my_n_samples}).")
            logger.info(MESSAGE)
            print(MESSAGE)
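        # Accumulate the negative log-likelihood of the forced tokens; the
        # running perplexity is exp(total_NLL / num_tokens_generated).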
        my_ppl -= LOGPROBS[0].outputs[0].cumulative_logprob
        MESSAGE = (f"Iteration {c+1} of {my_n_patches} Intermediate"
                   " Estimates:\n"
                   f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n"
                   "\tPerplexity_intermediate="
                   f"{math.exp(my_ppl/num_tokens_generated)}")

        logger.info(MESSAGE)
        print(MESSAGE)
    ending_time = datetime.datetime.now()
    MESSAGE = (f"Done @ {ending_time} after processing for"
               f" {ending_time-starting_time};"
               f" generated {num_tokens_generated} tokens.")

    logger.info(MESSAGE)
    print(MESSAGE)

    MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy="
               f"{my_ppl/num_tokens_generated}"
               f"\n\tPPL={math.exp(my_ppl/num_tokens_generated)}")

    if args.output_json:
        results = {
            "integral_cross_entropy": my_ppl,
            "average_cross_entropy": my_ppl / num_tokens_generated,
            "ppl": math.exp(my_ppl / num_tokens_generated),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

    logger.info(MESSAGE)
    print(MESSAGE)
    return


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Measure the PPPL (P3L) score of a given model.')
    parser.add_argument(
        '--data',
        type=str,
        default='./wikitext/wikitext-2-v1/test-00000-of-00001.parquet')
    parser.add_argument('--context-size', type=int, default=4096)
    parser.add_argument('--sample-size', type=int, default=512)
    parser.add_argument('--patch-size', type=int, default=None)
    parser.add_argument('--lang-script', type=str, default="eng_Latn")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the perplexity results in JSON format.')

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    main(args)