1
1
import autocuda
2
+ import sklearn
2
3
import torch
3
4
from pyabsa .framework .checkpoint_class .checkpoint_template import CheckpointManager
4
5
from torch .utils .data import DataLoader
@@ -32,6 +33,7 @@ def __init__(self, checkpoint):
32
33
33
34
self .tokenizer = AutoTokenizer .from_pretrained (checkpoint )
34
35
self .model = AutoModelForSeq2SeqLM .from_pretrained (checkpoint )
36
+ self .model .config .max_length = 128
35
37
self .data_collator = DataCollatorForSeq2Seq (self .tokenizer )
36
38
self .device = autocuda .auto_cuda ()
37
39
self .model .to (self .device )
@@ -94,7 +96,7 @@ def predict(self, text, **kwargs):
94
96
ate_outputs = self .tokenizer .batch_decode (
95
97
ate_outputs , skip_special_tokens = True
96
98
)[0 ]
97
- result ["aspect" ] = [asp .strip () for asp in ate_outputs .split (", " )]
99
+ result ["aspect" ] = [asp .strip () for asp in ate_outputs .split ("| " )]
98
100
99
101
# APC inference
100
102
inputs = self .tokenizer (
@@ -106,7 +108,7 @@ def predict(self, text, **kwargs):
106
108
apc_outputs = self .tokenizer .batch_decode (
107
109
apc_outputs , skip_special_tokens = True
108
110
)[0 ]
109
- result ["sentiment" ] = [sent .strip () for sent in apc_outputs .split (", " )]
111
+ result ["sentiment" ] = [sent .strip () for sent in apc_outputs .split ("| " )]
110
112
111
113
# Opinion inference
112
114
inputs = self .tokenizer (
@@ -118,7 +120,7 @@ def predict(self, text, **kwargs):
118
120
op_outputs = self .tokenizer .batch_decode (op_outputs , skip_special_tokens = True )[
119
121
0
120
122
]
121
- result ["opinion" ] = [op .strip () for op in op_outputs .split (", " )]
123
+ result ["opinion" ] = [op .strip () for op in op_outputs .split ("| " )]
122
124
123
125
# Category inference
124
126
inputs = self .tokenizer (
@@ -130,7 +132,7 @@ def predict(self, text, **kwargs):
130
132
cat_outputs = self .tokenizer .batch_decode (
131
133
cat_outputs , skip_special_tokens = True
132
134
)[0 ]
133
- result ["category" ] = [cat .strip () for cat in cat_outputs .split (", " )]
135
+ result ["category" ] = [cat .strip () for cat in cat_outputs .split ("| " )]
134
136
ensemble_result = {
135
137
"text" : text ,
136
138
"Quadruples" : [
@@ -207,26 +209,43 @@ def get_aspect_metrics(self, true_aspects, pred_aspects):
207
209
return aspect_p , aspect_r , aspect_f1
208
210
209
211
def get_classic_metrics(self, y_true, y_pred):
    """Compute accuracy and macro precision/recall/F1 over label strings.

    Each element of *y_true* / *y_pred* is a "|"-separated string of
    "key:label" segments (e.g. "food:positive|service:negative"); a
    ground-truth segment is matched to the first prediction segment whose
    key is a prefix of the ground-truth key, then the labels of matched
    pairs are scored with sklearn metrics.

    :param y_true: iterable of gold label strings.
    :param y_pred: iterable of predicted label strings, aligned with y_true.
    :return: dict with "accuracy", "precision", "recall", "f1"
             (precision/recall/f1 are macro-averaged).
    """
    # BUG FIX (import): `import sklearn` alone does not guarantee the
    # `sklearn.metrics` submodule is loaded, and precision_score /
    # recall_score / f1_score were referenced unqualified without any
    # import (NameError). Import the submodule locally and qualify all
    # metric calls through it.
    from sklearn import metrics

    valid_gts = []
    valid_preds = []
    for gt, pred in zip(y_true, y_pred):
        gt_list = gt.split("|")
        pred_list = pred.split("|")
        # Consume ground-truth segments from the end of the list; for each,
        # scan the predictions for the first key-prefix match.
        while gt_list:
            gt_val = gt_list[-1].strip().lower()
            for pred_val in pred_list:
                pred_val = pred_val.strip().lower()
                gt_key, _, gt_label = gt_val.partition(":")
                pred_key, _, pred_label = pred_val.partition(":")
                if gt_key.startswith(pred_key):
                    if gt_label:
                        valid_gts.append(gt_label)
                    else:
                        # Malformed gt segment without a ":label" part —
                        # skip it entirely so the two lists stay aligned.
                        break
                    # An empty predicted label still occupies a slot so it
                    # is scored as a wrong prediction.
                    valid_preds.append(pred_label if pred_label else "")
                    break
            # NOTE(review): a gt segment with no matching pred key is silently
            # dropped instead of being counted as a miss — confirm intended.
            gt_list.pop()

    report = metrics.classification_report(valid_gts, valid_preds)
    print(report)
    accuracy = metrics.accuracy_score(valid_gts, valid_preds)
    precision = metrics.precision_score(valid_gts, valid_preds, average="macro")
    recall = metrics.recall_score(valid_gts, valid_preds, average="macro")
    f1 = metrics.f1_score(valid_gts, valid_preds, average="macro")

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
230
249
231
250
# def get_classic_metrics(self, y_true, y_pred):
232
251
#
0 commit comments