Skip to content

Commit

Permalink
[text_classification_retrieval_based] fix bug of evaluate.py (#7844)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeyuTeng96 authored Jan 16, 2024
1 parent 04142e3 commit c9d22db
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,23 @@
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similar pair file")
parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file")
parser.add_argument(
"--recall_num", type=int, default=10, help="Most similair number of doc recalled from corpus per query"
"--similar_text_pair",
type=str,
default="",
help="The full path of similar pair file",
)
parser.add_argument(
"--recall_result_file",
type=str,
default="",
help="The full path of recall result file",
)
parser.add_argument(
"--recall_num",
type=int,
default=10,
help="Most similair number of doc recalled from corpus per query",
)
args = parser.parse_args()

Expand Down Expand Up @@ -57,17 +70,24 @@ def recall(rs, N=10):
with open(args.recall_result_file, "r", encoding="utf-8") as f:
relevance_labels = []
for index, line in enumerate(f):

if index % args.recall_num == 0 and index != 0:
rs.append(relevance_labels)
relevance_labels = []
text_arr = line.rstrip().split("\t")
text_title, text_para, recalled_title, recalled_para, label, cosine_sim = text_arr
(
text_title,
text_para,
recalled_title,
recalled_para,
label,
cosine_sim,
) = text_arr
if text2similar["\t".join([text_title, text_para])] == label:
relevance_labels.append(1)
else:
relevance_labels.append(0)

if (index + 1) % args.recall_num == 0:
rs.append(relevance_labels)
relevance_labels = []

recall_N = []
recall_num = [1, 5, 10, 20, 50]
for topN in recall_num:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,23 @@
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similar pair file")
parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file")
parser.add_argument(
"--recall_num", type=int, default=10, help="Most similar number of doc recalled from corpus per query"
"--similar_text_pair",
type=str,
default="",
help="The full path of similar pair file",
)
parser.add_argument(
"--recall_result_file",
type=str,
default="",
help="The full path of recall result file",
)
parser.add_argument(
"--recall_num",
type=int,
default=10,
help="Most similar number of doc recalled from corpus per query",
)
args = parser.parse_args()

Expand Down Expand Up @@ -57,17 +70,24 @@ def recall(rs, N=10):
with open(args.recall_result_file, "r", encoding="utf-8") as f:
relevance_labels = []
for index, line in enumerate(f):

if index % args.recall_num == 0 and index != 0:
rs.append(relevance_labels)
relevance_labels = []
text_arr = line.rstrip().split("\t")
text_title, text_para, recalled_title, recalled_para, label, cosine_sim = text_arr
(
text_title,
text_para,
recalled_title,
recalled_para,
label,
cosine_sim,
) = text_arr
if text2similar["\t".join([text_title, text_para])] == label:
relevance_labels.append(1)
else:
relevance_labels.append(0)

if (index + 1) % args.recall_num == 0:
rs.append(relevance_labels)
relevance_labels = []

recall_N = []
recall_num = [1, 5, 10, 20, 50]
for topN in recall_num:
Expand Down

0 comments on commit c9d22db

Please sign in to comment.