
Commit f0e2a2c

Update ner_eval.py
1 parent c4427eb commit f0e2a2c

File tree: ner_eval.py

1 file changed: +67 −1 lines changed


ner_eval.py

+67 −1
@@ -43,11 +43,77 @@ def eval_ner_cluner(inputfile):
     print("ner soft metrics: f1: %.4f, precision: %.4f, recall: %.4f" % (e_f1, e_precision, e_recall))
 
 
+def eval_ner_cluner4(inputfile):
+    X, Y, Z = 1e-10, 1e-10, 1e-10
+    labels_list = []
+    c_data = []
+    def extract_ner(text):
+        res = {}
+        p1 = re.compile(r'<name>(.*?)</name>')
+        p2 = re.compile(r'<organization>(.*?)</organization>')
+        p3 = re.compile(r'<scene>(.*?)</scene>')
+        p4 = re.compile(r'<company>(.*?)</company>')
+        p5 = re.compile(r'<movie>(.*?)</movie>')
+        p6 = re.compile(r'<book>(.*?)</book>')
+        p7 = re.compile(r'<government>(.*?)</government>')
+        p8 = re.compile(r'<position>(.*?)</position>')
+        p9 = re.compile(r'<address>(.*?)</address>')
+        p10 = re.compile(r'<game>(.*?)</game>')
+
+        if p1.findall(text):
+            for ent in p1.findall(text):
+                res[ent] = 'name'
+        if p2.findall(text):
+            for ent in p2.findall(text):
+                res[ent] = 'org'
+        if p3.findall(text):
+            for ent in p3.findall(text):
+                res[ent] = 'scene'
+        if p4.findall(text):
+            for ent in p4.findall(text):
+                res[ent] = 'com'
+        if p5.findall(text):
+            for ent in p5.findall(text):
+                res[ent] = 'movie'
+        if p6.findall(text):
+            for ent in p6.findall(text):
+                res[ent] = 'book'
+        if p7.findall(text):
+            for ent in p7.findall(text):
+                res[ent] = 'gov'
+        if p8.findall(text):
+            for ent in p8.findall(text):
+                res[ent] = 'pos'
+        if p9.findall(text):
+            for ent in p9.findall(text):
+                res[ent] = 'loc'
+        if p10.findall(text):
+            for ent in p10.findall(text):
+                res[ent] = 'game'
+        return res
+
+    with open(inputfile,'r',encoding='utf-8') as f:
+        for line in f:
+            line = json.loads(line)
+            label = line['labels']
+            predict = line['output']
+            gold_list = extract_ner(label)
+            pred_list = extract_ner(predict)
+            # print(gold_list)
+            # print(pred_list)
+            Z += len(gold_list)
+            Y += len(pred_list)
+            for k1 in gold_list:
+                for k2 in pred_list:
+                    if k1 == k2 and gold_list[k1] == pred_list[k2]:
+                        X += 1
+    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
+    print("ner metrics: f1: %.4f, precision: %.4f, recall: %.4f" % (f1, precision, recall))
 
 
 
 if __name__ =="__main__":
 
     #evaluate ner
     inputfile = r'E:\openlab\ChatGLM2-6B\ptuning\output\ner\model1\checkpoint-3000\generated_predictions.txt'
-    eval_ner_cluner(inputfile)
+    eval_ner_cluner(inputfile)
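
For context, the added eval_ner_cluner4 reads a JSONL predictions file where each line is a JSON object whose 'labels' field holds the gold text and whose 'output' field holds the model text, both marking entities with inline <type>...</type> tags; it then scores exact entity matches with micro-averaged precision, recall, and F1 (the 1e-10 initializers only guard against division by zero). Below is a minimal sketch of such a record with made-up values, not taken from the repository's data:

import json

# Hypothetical record for illustration only; the field names follow the code
# in the diff above, the entity values are invented.
sample = {
    "labels": "<name>李明</name>在<company>华为</company>担任<position>工程师</position>",
    "output": "<name>李明</name>在<organization>华为</organization>担任<position>工程师</position>",
}
print(json.dumps(sample, ensure_ascii=False))

# For this record eval_ner_cluner4 would count 3 gold entities, 3 predicted
# entities, and 2 exact matches (李明/name and 工程师/pos); 华为 is tagged
# company ('com') in the gold text but organization ('org') in the prediction,
# so precision = recall = f1 ≈ 0.667.

Note that extract_ner keys its result dict by entity string, so repeated mentions of the same entity within one line collapse to a single entry before counting.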

0 commit comments
