def postprocess_golang(code: str) -> str:
    """Strip Go boilerplate from a generated solution.

    Removes, in this order:
      * every literal ``package main``,
      * parenthesised import blocks (``import (`` ... ``)``),
      * single-line ``import "..."`` statements,
      * the ``func main(...) { ... }`` definition.

    Everything else (the solution functions themselves) is left untouched.

    Args:
        code: Go source text produced by the model.

    Returns:
        The source text with the parts listed above replaced by empty strings.
    """
    # Non-greedy, line-by-line scan up to the first line that starts with ")".
    # Handles blocks with a single entry and cannot run past the closing
    # parenthesis into later code.
    multi_line_imports = re.compile(r"^import \(\n(?:.*\n)*?\)", re.MULTILINE)
    # MULTILINE is required: after "package main" is blanked the import is
    # never at offset 0 of the string, only at the start of a line.
    line_imports = re.compile(r"^import \".*\"", re.MULTILINE)
    # Non-greedy so that only main itself is dropped (up to the first "}" in
    # column 0), not every function defined after it. "\s*\(" keeps functions
    # whose name merely starts with "main" from being matched.
    func_main = re.compile(r"^func main\s*\(.*?^}", re.MULTILINE | re.DOTALL)

    code = code.replace("package main", "")  # Remove package main
    code = multi_line_imports.sub("", code)
    code = line_imports.sub("", code)
    code = func_main.sub("", code)

    return code
def evaluate_mbxp(results, n_workers):
    """Check MBXP solutions in parallel and return pass@1 as a percentage.

    Problems are grouped by their ``"lang"`` entry. Only languages listed in
    ``supported_langs`` are evaluated; problems in any other language are
    dropped and do not count towards the score. Each selected problem is put
    on an input queue consumed by ``n_workers`` processes running ``worker``;
    one ``(key, lang, passed, result, response)`` tuple is read back per
    problem. Per-problem outcomes are written to ``evaluated_test.json`` in
    the current working directory.

    Args:
        results: iterable of problem dicts; each must have a ``"lang"`` key
            (the remaining keys are consumed by ``worker``).
        n_workers: number of worker processes to start.

    Returns:
        ``100 * passed / evaluated``. ``0.0`` when there is nothing to
        evaluate (empty input or no problem in a supported language); in that
        case no process is started and no file is written.
    """
    supported_langs = ("cpp", "python", "php", "javascript", "ruby", "typescript")

    by_lang = {}
    for problem in results:
        by_lang.setdefault(problem["lang"], []).append(problem)

    # Keep the original submission order: language by language, in the order
    # each language first appeared in `results`.
    pending = []
    for lang, problems in by_lang.items():
        if lang in supported_langs:
            pending.extend(problems)

    n_problems = len(pending)
    if n_problems == 0:
        # Without this guard the pass-rate division below raises
        # ZeroDivisionError after the workers have already been spawned.
        print("No MBXP problems in a supported language; reporting 0.0% pass@1")
        return 0.0

    inp_queue = multiprocessing.Queue()
    out_queue = multiprocessing.Queue()
    for problem in pending:
        inp_queue.put(problem)

    start = timeit.default_timer()
    workers = [
        multiprocessing.Process(target=worker, args=(inp_queue, out_queue))
        for _ in range(n_workers)
    ]
    for proc in workers:
        proc.start()

    passes = {}
    n_passed = 0
    lang_passed = {}
    lang_counts = {}
    # Exactly one result tuple is expected per queued problem.
    for _ in tqdm(range(n_problems)):
        key, lang, passed, result, response = out_queue.get()
        passes[key] = {"passed": passed, "result": result, "response": response}
        n_passed += passed
        lang_passed[lang] = lang_passed.get(lang, 0) + passed
        lang_counts[lang] = lang_counts.get(lang, 0) + 1

    end = timeit.default_timer()
    pass_rate = 100 * n_passed / n_problems
    print(f"Processed {n_problems} in {end - start}s")
    print(f"{pass_rate: .02f}% pass@1")
    print(lang_passed, lang_counts)
    with open("evaluated_test.json", "w") as f:
        json.dump(passes, f, indent=2)

    return pass_rate
def try_float(x: str):
    """Convert ``x`` to ``float``, returning ``None`` when it cannot be converted.

    Args:
        x: value to convert, normally a numeric string such as ``"5600"``.

    Returns:
        ``float(x)``, or ``None`` if the conversion fails.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are swallowed. KeyboardInterrupt and
        # SystemExit must keep propagating so the run can be aborted.
        return None
def main():
    """Score an accuracy log against the reference dataset and print the metrics.

    Steps:
      1. Extend ``PATH`` with the language toolchain directories and clear
         ``LD_PRELOAD``.
      2. Load the ROUGE metric, the tokenizer, the target accuracies and the
         processed dataset.
      3. Read the accuracy log and decode every first occurrence of a
         ``qsl_idx`` from its hex ``"data"`` field, routing it to the
         OpenOrca, GSM8K or MBXP bucket.
      4. Compute ROUGE for OpenOrca and, with ``--dataset-mix``, the GSM8K
         exact-match rate and the MBXP pass rate.
      5. Add the internal ``accuracy`` and ``performance`` fields and print
         the resulting dict.
    """
    # Adding language specific paths to PATH
    tool_dirs = (
        "/usr/local/swift-5.7-RELEASE-ubuntu20.04/usr/bin:/usr/local/go/bin:/usr/local/bin"
        ":/root/.nvm/versions/node/v16.10.0/bin"
        # evaluation_setup/ubuntu.sh runs `nvm install 20.17.0`, so the node/tsc
        # binaries it provides are not in the v16.10.0 directory above.
        # NOTE(review): assumes nvm's default layout under /root/.nvm - confirm.
        ":/root/.nvm/versions/node/v20.17.0/bin"
    )
    current_path = os.environ.get("PATH")
    # An unset PATH previously produced the literal string "None:..."
    os.environ["PATH"] = f"{current_path}:{tool_dirs}" if current_path else tool_dirs
    # This is important for PHP language tests
    os.environ["LD_PRELOAD"] = ""

    args = get_args()
    metric = evaluate.load("rouge")
    nltk.download("punkt")

    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path,
        model_max_length=2048,
        padding_side="left",
        use_fast=False,
    )

    with open(args.target_file, "r") as f:
        acc_json = json.load(f)

    data = get_groundtruth(args.dataset_file)
    if args.dataset_mix:
        acc_target = acc_json["mix"]
        query_types, gt_outputs = data["dataset"], data["gt_output"]
    else:
        acc_target = acc_json["openorca"]
        gt_outputs = data["output"]

    target_required_OpenOrca = []
    preds_token_ids_OpenOrca = []
    target_required_GSM8K = []
    preds_token_GSM8K = []
    results_MBXP = []

    # dtype of the token ids stored in the log; int64 unless overridden.
    eval_dtype = {"int32": np.int32, "float": np.float32}.get(args.dtype, np.int64)

    with open(args.accuracy_file, "r") as f:
        results = json.load(f)

    seen = set()
    gen_tok_len = 0
    gen_num = 0
    for entry in results:
        # gen_num counts every log entry, including repeated qsl_idx values
        # that are skipped below.
        gen_num += 1
        qsl_idx = entry["qsl_idx"]
        if qsl_idx in seen:
            continue
        seen.add(qsl_idx)

        query_type = query_types.iloc[qsl_idx] if args.dataset_mix else "OpenOrca"

        # Every bucket decodes the hex payload the same way.
        tokens = np.frombuffer(bytes.fromhex(entry["data"]), eval_dtype)
        gen_tok_len += len(tokens)

        if query_type == "GSM8K":
            target_required_GSM8K.append(gt_outputs.iloc[qsl_idx])
            preds_token_GSM8K.append(tokens)
        elif query_type == "OpenOrca":
            if args.dataset_mix:
                target = gt_outputs.iloc[qsl_idx]
            else:
                target = gt_outputs[qsl_idx]
            target_required_OpenOrca.append(target)
            preds_token_ids_OpenOrca.append(tokens)
        else:
            # Anything that is neither GSM8K nor OpenOrca is treated as MBXP.
            pred_str = tokenizer.decode(tokens, skip_special_tokens=True)
            results_MBXP.append(create_mbxp_dict(data.iloc[qsl_idx], pred_str))

    # OpenOrca metric
    preds_decoded_text = tokenizer.batch_decode(preds_token_ids_OpenOrca, skip_special_tokens=True)
    preds, targets = postprocess_text(preds_decoded_text, target_required_OpenOrca)
    result = metric.compute(predictions=preds, references=targets, use_stemmer=True, use_aggregator=False)
    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
    prediction_lens = [len(pred) for pred in preds]

    if args.dataset_mix:
        # GSM8K metric
        preds_decoded_text = tokenizer.batch_decode(preds_token_GSM8K, skip_special_tokens=True)
        # Only the text before a follow-up "\nQ:" is searched for the answer.
        pred_nums = [maybe_remove_comma(find_number(pred_text.split("\nQ:")[0])) for pred_text in preds_decoded_text]
        gsm8k_total = len(target_required_GSM8K)
        correct = 0
        for target, pred_num in zip(target_required_GSM8K, pred_nums):
            predicted = try_float(pred_num)
            if predicted is None:
                # No parsable number in the prediction: counts as wrong.
                continue
            correct += try_float(target) == predicted

        # A log without GSM8K samples used to raise ZeroDivisionError here.
        result["gsm8k"] = 100.0 * correct / gsm8k_total if gsm8k_total else 0.0

        # MBXP metric
        from evaluate_mbxp import evaluate_mbxp

        result["mbxp"] = evaluate_mbxp(results_MBXP, N_WORKERS)

    ################## Habana internal code ##################################
    # It does not impact values reported as in the reference implementation.
    # It adds additional "accuracy" field which is used for internal testing.
    acc = [result[key] / acc_target[key] for key in acc_target]
    acc = round(np.min(acc) * 100, 2)
    performance = get_estimated_performance(args.performance_file)
    ##########################################################################

    result = {
        **result,
        "gen_len": np.sum(prediction_lens),
        "gen_num": gen_num,
        "gen_tok_len": gen_tok_len,
        "tokens_per_sample": round(gen_tok_len / gen_num, 1) if gen_num else 0.0,
        "performance": performance,
        "accuracy": acc,
    }

    print("\nResults\n")
    print(result)
#!/bin/bash
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

# Installs mxeval (the MBXP execution harness) and the language runtimes it
# needs. All paths are relative, so run this from the directory that contains
# the mbxp_evaluation/ folder.

set -xe

apt-get update

# Clone only once: with `set -e` a second run would otherwise abort here
# because the mxeval directory already exists.
if [ ! -d mxeval ]; then
    git clone https://github.com/amazon-science/mxeval.git
fi
pip install -e mxeval

# Call the globally installed `tsc` instead of going through `npx`.
sed -i 's/npx tsc/tsc/g' mxeval/mxeval/execution.py

# Replace mxeval's language setup script with the local variant, then run it.
cp mbxp_evaluation/evaluation_setup/ubuntu.sh mxeval/language_setup/ubuntu.sh
PATH="$HOME/.rbenv/bin:$PATH" bash mxeval/language_setup/ubuntu.sh