From c9d22dba8783c4267edfd72581541275a7873f44 Mon Sep 17 00:00:00 2001 From: ZeyuTeng96 <96521059+ZeyuTeng96@users.noreply.github.com> Date: Tue, 16 Jan 2024 12:17:27 +0800 Subject: [PATCH] [text_classification_retrieval_based] fix bug of evaluate.py (#7844) --- .../hierarchical/retrieval_based/evaluate.py | 36 ++++++++++++++----- .../multi_class/retrieval_based/evaluate.py | 36 ++++++++++++++----- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/applications/text_classification/hierarchical/retrieval_based/evaluate.py b/applications/text_classification/hierarchical/retrieval_based/evaluate.py index c315b0d8b129..d4fe06353237 100644 --- a/applications/text_classification/hierarchical/retrieval_based/evaluate.py +++ b/applications/text_classification/hierarchical/retrieval_based/evaluate.py @@ -18,10 +18,23 @@ import numpy as np parser = argparse.ArgumentParser() -parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similar pair file") -parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file") parser.add_argument( - "--recall_num", type=int, default=10, help="Most similair number of doc recalled from corpus per query" + "--similar_text_pair", + type=str, + default="", + help="The full path of similar pair file", +) +parser.add_argument( + "--recall_result_file", + type=str, + default="", + help="The full path of recall result file", +) +parser.add_argument( + "--recall_num", + type=int, + default=10, + help="Most similair number of doc recalled from corpus per query", ) args = parser.parse_args() @@ -57,17 +70,24 @@ def recall(rs, N=10): with open(args.recall_result_file, "r", encoding="utf-8") as f: relevance_labels = [] for index, line in enumerate(f): - - if index % args.recall_num == 0 and index != 0: - rs.append(relevance_labels) - relevance_labels = [] text_arr = line.rstrip().split("\t") - text_title, text_para, recalled_title, recalled_para, label, cosine_sim = text_arr + ( + text_title, + text_para, + recalled_title, + recalled_para, + label, + cosine_sim, + ) = text_arr if text2similar["\t".join([text_title, text_para])] == label: relevance_labels.append(1) else: relevance_labels.append(0) + if (index + 1) % args.recall_num == 0: + rs.append(relevance_labels) + relevance_labels = [] + recall_N = [] recall_num = [1, 5, 10, 20, 50] for topN in recall_num: diff --git a/applications/text_classification/multi_class/retrieval_based/evaluate.py b/applications/text_classification/multi_class/retrieval_based/evaluate.py index bff2cfb814aa..944df9157a30 100644 --- a/applications/text_classification/multi_class/retrieval_based/evaluate.py +++ b/applications/text_classification/multi_class/retrieval_based/evaluate.py @@ -18,10 +18,23 @@ import numpy as np parser = argparse.ArgumentParser() -parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similar pair file") -parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file") parser.add_argument( - "--recall_num", type=int, default=10, help="Most similar number of doc recalled from corpus per query" + "--similar_text_pair", + type=str, + default="", + help="The full path of similar pair file", +) +parser.add_argument( + "--recall_result_file", + type=str, + default="", + help="The full path of recall result file", +) +parser.add_argument( + "--recall_num", + type=int, + default=10, + help="Most similar number of doc recalled from corpus per query", ) args = parser.parse_args() @@ -57,17 +70,24 @@ def recall(rs, N=10): with open(args.recall_result_file, "r", encoding="utf-8") as f: relevance_labels = [] for index, line in enumerate(f): - - if index % args.recall_num == 0 and index != 0: - rs.append(relevance_labels) - relevance_labels = [] text_arr = line.rstrip().split("\t") - text_title, text_para, recalled_title, recalled_para, label, cosine_sim = text_arr + ( + text_title, + text_para, + recalled_title, + recalled_para, + label, + cosine_sim, + ) = text_arr if text2similar["\t".join([text_title, text_para])] == label: relevance_labels.append(1) else: relevance_labels.append(0) + if (index + 1) % args.recall_num == 0: + rs.append(relevance_labels) + relevance_labels = [] + recall_N = [] recall_num = [1, 5, 10, 20, 50] for topN in recall_num: