From 00ad9819db0eb18e1889945ac4fe7feab0e4f39a Mon Sep 17 00:00:00 2001 From: sushil Dubey Date: Tue, 23 Apr 2024 10:36:52 +0530 Subject: [PATCH] Add rouge metric evaluation for llama 70B with orca datasets (#169) * Add rouge metric evaluation for llama 70B with orca datasets Use the rouge metric to evaluate the correctness of the model; it uses the openorca dataset --- examples/text-generation/evaluation.py | 115 ++++++++++++ .../text-generation/requirements_lm_eval.txt | 6 +- examples/text-generation/run_generation.py | 168 +++++++++++++++++- 3 files changed, 286 insertions(+), 3 deletions(-) create mode 100644 examples/text-generation/evaluation.py diff --git a/examples/text-generation/evaluation.py b/examples/text-generation/evaluation.py new file mode 100644 index 0000000000..f8e5e9eb36 --- /dev/null +++ b/examples/text-generation/evaluation.py @@ -0,0 +1,115 @@ +import argparse +from transformers import AutoTokenizer +import nltk +import evaluate +import numpy as np +import json + +###################### Habana internal code ################################## +ACC_TARGET = {"rouge1": 44.4312, "rouge2": 22.0352, "rougeL": 28.6162} + +# See https://github.com/mlcommons/inference/pull/1583 +############################################################################## + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint-path", default="/mnt/weka/data/pytorch/llama2/Llama-2-70b-chat-hf", + help="Path to Llama2-70b-hf-chat checkpoint") + parser.add_argument("--accuracy-file", default="output/accuracy.json", help="path to accuracy.json") + parser.add_argument("--dataset-file", default="/mnt/weka/data/mlperf_inference/llama2/processed-data.pkl", + help="path to processed openorca validation set") + parser.add_argument("--verbose", action="store_true", + help="verbose messages") + parser.add_argument("--dtype", default="int64", + help="dtype of the accuracy log", choices=["int32", "int64", "float"]) + args = parser.parse_args() + return 
args + + +def get_groundtruth(processed_dataset_file): + import pandas as pd + data = pd.read_pickle(processed_dataset_file) + ground_truths = data['output'] + return ground_truths + +def postprocess_text(preds, targets): + preds = [pred.strip() for pred in preds] + targets = [target.strip() for target in targets] + + # rougeLSum expects newline after each sentence + preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] + targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] + + return preds, targets + + +def main(): + + args = get_args() + checkpoint_path = args.checkpoint_path + metric = evaluate.load("rouge") + nltk.download('punkt') + + tokenizer = AutoTokenizer.from_pretrained( + checkpoint_path, + model_max_length=2048, + padding_side="left", + use_fast=False,) + + + targets = get_groundtruth(args.dataset_file) + + target_required = [] + preds_token_ids = [] + + eval_dtype = np.int64 + if args.dtype == "int32": + eval_dtype = np.int32 + elif args.dtype == "float": + eval_dtype = np.float32 + + with open(args.accuracy_file, "r") as f: + results = json.load(f) + + seen = set() + gen_tok_len = 0 + for pred in results: + qsl_idx = pred['qsl_idx'] + if qsl_idx in seen: + continue + + seen.add(qsl_idx) + target = targets[qsl_idx] + target_required.append(target) + pred = np.frombuffer( bytes.fromhex(pred['data']), eval_dtype) + + gen_tok_len += len(pred) + preds_token_ids.append(pred) + + preds_decoded_text = tokenizer.batch_decode( + preds_token_ids, skip_special_tokens=True) + + preds, targets = postprocess_text(preds_decoded_text, target_required) + + result = metric.compute( + predictions=preds, references=targets, use_stemmer=True, use_aggregator=False) + result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} + prediction_lens = [len(pred) for pred in preds] + gen_num = len(preds) + + acc = [result[key] / ACC_TARGET[key] for key in ACC_TARGET] + acc = round(np.min(acc) * 100, 2) + + + result = {**result, + 'gen_len': 
np.sum(prediction_lens), + 'gen_num': gen_num, + 'accuracy': acc # this is Habana internal field + } + + print("\nResults\n") + print(result) + + +if __name__ == "__main__": + main() diff --git a/examples/text-generation/requirements_lm_eval.txt b/examples/text-generation/requirements_lm_eval.txt index 4d18247223..b7112ce78a 100644 --- a/examples/text-generation/requirements_lm_eval.txt +++ b/examples/text-generation/requirements_lm_eval.txt @@ -1 +1,5 @@ -https://github.com/polisettyvarma/lm-evaluation-harness/archive/3cdc8daadad9f4559ae6cdfae96f1d83d6b3c1f4.zip \ No newline at end of file +https://github.com/polisettyvarma/lm-evaluation-harness/archive/3cdc8daadad9f4559ae6cdfae96f1d83d6b3c1f4.zip +evaluate +rouge_score +accelerate +pandas diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index b30c2e4447..ba8844b392 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -26,7 +26,9 @@ import time from itertools import cycle from pathlib import Path - +import pandas as pd +import struct +import contextlib import torch from utils import adjust_batch, count_hpu_graphs, initialize_model @@ -85,6 +87,12 @@ def setup_parser(parser): type=str, help="Optional argument if you want to assess your model on a given dataset of the HF Hub.", ) + parser.add_argument( + "--dataset", + default="/mnt/weka/data/mlperf_inference/llama2/processed-data.pkl", + type=str, + help="path of the dataset to run rouge evaluation and measurement for rouge", + ) parser.add_argument( "--column_name", default=None, @@ -313,8 +321,164 @@ def main(): use_lazy_mode = False import habana_frameworks.torch.hpu as torch_hpu + if args.dataset_name == "openorca": + # Benchmark over the prompts below + def get_ds(args): + ds = pd.read_pickle(args.dataset) + return ds + + + def get_input(ds, batch_size): + queries = [] + tok_input = ds["tok_input"].tolist() + for start in range(0, len(ds), batch_size): + 
end = start + batch_size + batch = tok_input[start:end] + input_ids = [] + attention_mask=[] + for query in batch: + input_ids.append( + [0] * (args.max_input_tokens - len(query)) + query) + attention_mask.append([0] * (args.max_input_tokens - len(query)) + [1] * len(query)) + queries.append({ + 'input_ids': torch.tensor(input_ids, dtype=torch.int32), + 'attention_mask': torch.tensor(attention_mask, dtype=torch.int32) + }) + return queries + + ds = get_ds(args) + input_sentences = get_input(ds, args.batch_size) + + def generate(input_tokens, size=None, reduce_recompile=False): + """Generates sequences from the input sentences and returns them.""" + + t0 = time.perf_counter() + print(f"Step4+ starting time is {t0*1000}", flush=True) + if size is not None: + input_tokens = adjust_batch(input_tokens, size) + + if not reduce_recompile: + # Move inputs to target device(s) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(args.device) + + outputs = model.generate( + **input_tokens, + generation_config=generation_config, + lazy_mode=use_lazy_mode, + hpu_graphs=args.use_hpu_graphs, + profiling_steps=args.profiling_steps, + profiling_warmup_steps=args.profiling_warmup_steps, + ).cpu() + outputs = outputs.tolist() + for i in range(len(outputs)): + outputs[i] = outputs[i][args.max_input_tokens:] + duration = time.perf_counter() - t0 + print(f"Total E2E time of this batch is {duration:.3f}s", flush=True) + return outputs + + from optimum.habana.utils import HabanaProfile - if args.dataset_name is None: + # compilation stage disable profiling + HabanaProfile.disable() + # Compilation + logger.info("Graph compilation...") + dyn_prompt_lens = args.simulate_dyn_prompt + t0 = time.perf_counter() + # The first three iterations take longer because of graph compilation + if dyn_prompt_lens is None or len(set(dyn_prompt_lens)) == 1: + for _ in range(args.warmup): + if dyn_prompt_lens is None: + print("Warming up", flush=True) + 
generate(input_sentences[0], None, args.reduce_recompile) + else: + print("Warming up for shape,", dyn_prompt_lens[0], flush=True) + generate(input_sentences[0], dyn_prompt_lens[0], args.reduce_recompile) + else: + if args.bucket_size > 0: + mn = min(dyn_prompt_lens) + mx = max(dyn_prompt_lens) + + def rounder(x): + return int(math.ceil(x / args.bucket_size) * args.bucket_size) + + min_prompt_len = rounder(mn) + max_sentence_len = rounder(mx) + for _ in range(args.warmup): + lst = list(range(min_prompt_len, max_sentence_len + 1, args.bucket_size)) + for sz in lst: + print("Warming up for shape,", sz - 1, flush=True) + generate(input_sentences[0], sz - 1, args.reduce_recompile) + torch_hpu.synchronize() + compilation_duration = time.perf_counter() - t0 + HabanaProfile.enable() + total_new_tokens_generated = 0 + logger.info("Running generate...") + t0 = time.perf_counter() + # Benchmark over n_iterations iterations + N = len(input_sentences) + if dyn_prompt_lens is None: + for i in range(args.n_iterations): + results = [] + b = 1 + for sentence in input_sentences: + generated = generate(sentence, None, args.reduce_recompile) + results.extend(generated) + print(f"Generatig batch {b}/{N}") + b +=1 + else: + repeated_prompt_len = cycle(dyn_prompt_lens) + for i in range(args.n_iterations): + prompt_len = next(repeated_prompt_len) + print("Generating for shape,", prompt_len) + results = [] + for sentence in input_sentences: + generated = generate(sentence, prompt_len, args.reduce_recompile) + results.extend(generated) + duration = time.perf_counter() - t0 + total_new_tokens_generated = args.n_iterations * args.batch_size * args.max_new_tokens + throughput = total_new_tokens_generated / duration + + # Store results if necessary + if args.output_dir is not None and args.global_rank == 0: + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + #TODO dump in hex format + acc_file = [] + num_token = 0 + for i, idx in enumerate(ds.index): + pred 
= results[i] + eos_token_id = 2 + try: + ind_eos = pred.index(eos_token_id)+1 + except: + ind_eos = len(pred) + pred = pred[:ind_eos] + num_token += len(pred) + acc_file.append({ + "seq_id": idx, + "qsl_idx": idx, + "data": bytes(struct.pack('L' * len(pred), *pred)).hex().upper() + }) + with open(output_dir / "accuracy.json", "w") as outfile: + outfile.write(json.dumps(acc_file)) + + stats = f"Throughput (including tokenization) = {throughput} tokens/second" + stats = stats + f"\nNumber of HPU graphs = {count_hpu_graphs()}" + separator = "-" * len(stats) + print() + print("Stats:") + print(separator) + print(stats) + mem = get_hpu_memory_stats() + for k, v in mem.items(): + print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) + print(f"Graph compilation duration = {compilation_duration} seconds") + print(separator) + print() + elif args.dataset_name is None: # Benchmark over the prompts below if args.prompt: input_sentences = args.prompt