
Commit 3d6cc4a

Merge pull request #282 from yangheng95/dev

2.1.12

2 parents 14e32b7 + 89b6d5b

File tree: 6 files changed, +100 -67 lines

examples-v2/aspect_opinion_sentiment_category_extraction/inference.py (+3 -1)

@@ -10,7 +10,9 @@
 from pyabsa import ABSAInstruction

 if __name__ == "__main__":
-    generator = ABSAInstruction.ABSAGenerator("multilingual")
+    generator = ABSAInstruction.ABSAGenerator(
+        "checkpoints/multitask/googleflan-t5-base-instruction/checkpoint-2745"
+    )
     example = [
         "The food is good, but the service is bad.",
         "The laptop is good, but the battery life is bad.",

examples-v2/aspect_opinion_sentiment_category_extraction/multitask_train.py (+11 -10)

@@ -11,17 +11,18 @@

 import findfile
 from pyabsa import ABSAInstruction as absa_instruction
+
 warnings.filterwarnings("ignore")
 import pandas as pd


 task_name = "multitask"
 experiment_name = "instruction"
 # model_checkpoint = 'allenai/tk-instruct-base-def-pos'
-model_checkpoint = "kevinscaria/ate_tk-instruct-base-def-pos-neg-neut-combined"
+# model_checkpoint = "kevinscaria/ate_tk-instruct-base-def-pos-neg-neut-combined"
 # model_checkpoint = 'allenai/tk-instruct-large-def-pos'
 # model_checkpoint = 'allenai/tk-instruct-3b-def-pos'
-# model_checkpoint = 'google/mt5-base'
+model_checkpoint = "google/flan-t5-base"

 print("Experiment Name: ", experiment_name)
 model_out_path = "checkpoints"

@@ -33,12 +34,12 @@
 # Load the data
 # id_train_file_path = './integrated_datasets'
 # id_test_file_path = './integrated_datasets'
-# id_train_file_path = "./integrated_datasets/acos_datasets/"
-# id_test_file_path = "./integrated_datasets/acos_datasets"
-id_train_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
-id_test_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
-# id_train_file_path = './integrated_datasets/acos_datasets/504.Restaurant16'
-# id_test_file_path = './integrated_datasets/acos_datasets/504.Restaurant16'
+id_train_file_path = "./integrated_datasets/acos_datasets/"
+id_test_file_path = "./integrated_datasets/acos_datasets"
+# id_train_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
+# id_test_file_path = './integrated_datasets/acos_datasets/501.Laptop14'
+# id_train_file_path = './integrated_datasets/acos_datasets/502.Restaurant14'
+# id_test_file_path = './integrated_datasets/acos_datasets/502.Restaurant14'


 id_tr_df = absa_instruction.data_utils.read_json(id_train_file_path, "train")

@@ -72,9 +73,9 @@
     "evaluation_strategy": "epoch",
     "save_strategy": "epoch",
     "learning_rate": 5e-5,
-    "per_device_train_batch_size": 4,
+    "per_device_train_batch_size": 16,
     "per_device_eval_batch_size": 16,
-    "num_train_epochs": 6,
+    "num_train_epochs": 3,
     "weight_decay": 0.01,
     "warmup_ratio": 0.1,
     "load_best_model_at_end": True,

pyabsa/__init__.py (+1 -1)

@@ -7,7 +7,7 @@
 # Copyright (C) 2021. All Rights Reserved.

 __name__ = "pyabsa"
-__version__ = "2.1.11"
+__version__ = "2.1.12"

 from pyabsa.framework.flag_class import *

pyabsa/tasks/ABSAInstruction/data_utils.py (+36 -25)

@@ -48,57 +48,67 @@ def prepare_instruction_dataloader(self, df):
         cat_instructor = CategoryInstruction()
         alldata = []
         for i, data in df.iterrows():
-            _aspects = [label["aspect"] for label in data["labels"]]
+            _aspects = ["aspect:" + label["aspect"] for label in data["labels"]]
             aspects = []
             for asp in _aspects:
                 if asp.strip() not in aspects:
                     aspects.append(asp.strip())
-            aspects = ", ".join(aspects)
-            alldata.append(
-                {"text": ate_instructor.prepare_input(data["text"]), "labels": aspects}
-            )
+            aspects = "|".join(aspects)

-            opinions = ", ".join(
+            polarities = []
+            _polarities = [
+                "{}:{}".format(label["aspect"], label["polarity"])
+                for label in data["labels"]
+            ]
+            for pol in _polarities:
+                if pol not in polarities:
+                    polarities.append(pol)
+            polarities = "|".join(polarities)
+
+            opinions = "|".join(
                 [
                     "{}:{}".format(label["aspect"], label["opinion"])
                     for label in data["labels"]
                 ]
             )
-            alldata.append(
-                {
-                    "text": op_instructor.prepare_input(data["text"], aspects),
-                    "labels": opinions,
-                }
-            )

-            polarities = ", ".join(
+            categories = "|".join(
                 [
-                    "{}:{}".format(label["aspect"], label["polarity"])
+                    "{}:{}".format(label["aspect"], label["category"])
                     for label in data["labels"]
                 ]
             )
+
+            # ATE task
+            alldata.append(
+                {"text": ate_instructor.prepare_input(data["text"]), "labels": aspects}
+            )
+
+            # APC task
             alldata.append(
                 {
                     "text": apc_instructor.prepare_input(data["text"], aspects),
                     "labels": polarities,
                 }
             )

-            categories = ", ".join(
-                [
-                    "{}:{}".format(
-                        label["aspect"], label["category"].replace("NULL", "")
-                    )
-                    for label in data["labels"]
-                ]
-            )
+            # Opinion task
             alldata.append(
                 {
-                    "text": cat_instructor.prepare_input(data["text"], aspects),
-                    "labels": categories,
+                    "text": op_instructor.prepare_input(data["text"], aspects),
+                    "labels": opinions,
                 }
             )
-        # print(alldata[-1]['labels'])
+
+            # Category task
+            if "NULL" not in categories:
+                alldata.append(
+                    {
+                        "text": cat_instructor.prepare_input(data["text"], aspects),
+                        "labels": categories,
+                    }
+                )
+
         alldata = pd.DataFrame(alldata)
         return alldata

@@ -163,6 +173,7 @@ def read_json(data_path, data_type="train"):

     files = findfile.find_files(data_path, [data_type, ".jsonl"], exclude_key=[".txt"])
     for f in files:
+        print(f)
         with open(f, "r", encoding="utf8") as fin:
             for line in fin:
                 data.append(json.loads(line))
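
A toy illustration of the "|"-delimited targets now built per training example; dict.fromkeys is a condensed stand-in for the explicit dedup loops in the diff, and the label values are invented:

labels = [
    {"aspect": "battery life", "polarity": "positive", "opinion": "good",
     "category": "POWER_SUPPLY#GENERAL"},
    {"aspect": "cord", "polarity": "neutral", "opinion": "NULL", "category": "NULL"},
]

aspects = "|".join(dict.fromkeys("aspect:" + lb["aspect"] for lb in labels))
polarities = "|".join(dict.fromkeys("{}:{}".format(lb["aspect"], lb["polarity"]) for lb in labels))
opinions = "|".join("{}:{}".format(lb["aspect"], lb["opinion"]) for lb in labels)
categories = "|".join("{}:{}".format(lb["aspect"], lb["category"]) for lb in labels)

print(aspects)     # aspect:battery life|aspect:cord
print(polarities)  # battery life:positive|cord:neutral
print(categories)  # battery life:POWER_SUPPLY#GENERAL|cord:NULL
# Because "NULL" appears in categories, the category record would be skipped here.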

pyabsa/tasks/ABSAInstruction/instruction.py (+8 -8)

@@ -31,12 +31,12 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
 example 1-
 input: I charge it at night and skip taking the cord with me because of the good battery life.
 {self.eos_instruction}
-battery life, cord
+aspect:battery life|aspect:cord

 example 2-
 input: Great food, good size menu, great service and an unpretensious setting.
 {self.eos_instruction}
-food, menu, service, setting
+aspect:food|aspect:menu|aspect:service|aspect:setting

 Now extract aspects from the following example:
 input: """

@@ -64,13 +64,13 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
 input: I charge it at night and skip taking the cord with me because of the good battery life.
 The aspects are: battery life, cord
 {self.eos_instruction}
-battery life:positive, cord:positive
+battery life:positive|cord:positive

 example 2-
 input: Great food, good size menu, great service and an unpretensious setting.
 The aspects are: food, menu, service, setting
 {self.eos_instruction}
-food:positive, menu:positive, service:positive, setting:positive
+food:positive|menu:positive|service:positive|setting:positive

 Now predict aspect sentiments from the following example:

@@ -103,13 +103,13 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
 input: I charge it at night and skip taking the cord with me because of the good battery life.
 The aspects are: battery life, cord
 {self.eos_instruction}
-battery life:good, cord:NULL
+battery life:good|cord:NULL

 example 2-
 input: Great food, good size menu, great service and an unpretensious setting.
 The aspects are: food, menu, service, setting
 {self.eos_instruction}
-food:great, menu:good, service:great, setting:unpretensious
+food:great|menu:good|service:great|setting:unpretensious

 Now extract opinions for the following example:
 input:"""

@@ -141,11 +141,11 @@ def __init__(self, bos_instruction=None, eos_instruction=None):
 input: I charge it at night and skip taking the cord with me because of the good battery life.
 The aspects are: battery life, cord
 {self.eos_instruction}
-battery life:POWER_SUPPLY#GENERAL, cord:NULL
+battery life:POWER_SUPPLY#GENERAL|cord:NULL

 example 2-
 input: Great food, good size menu, great service and an unpretensious setting.
-The aspects are: food, menu, service, setting
+The aspects are: food:FOOD#QUALITY| menu:RESTAURANT#GENERAL|service:SERVICE#GENERAL|setting:SERVICE#GENERAL
 {self.eos_instruction}
 food:FOOD#QUALITY, menu:RESTAURANT#GENERAL, service:SERVICE#GENERAL, setting:SERVICE#GENERAL

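
Completions in the updated prompt format are "|"-delimited, so an ATE answer can be split back into plain aspect terms as sketched below (illustrative only, not the library's own parsing):

completion = "aspect:battery life|aspect:cord"
aspects = [item.partition(":")[2].strip() for item in completion.split("|")]
print(aspects)  # ['battery life', 'cord']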

pyabsa/tasks/ABSAInstruction/model.py (+41 -22)

@@ -1,4 +1,5 @@
 import autocuda
+import sklearn
 import torch
 from pyabsa.framework.checkpoint_class.checkpoint_template import CheckpointManager
 from torch.utils.data import DataLoader

@@ -32,6 +33,7 @@ def __init__(self, checkpoint):

         self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+        self.model.config.max_length = 128
         self.data_collator = DataCollatorForSeq2Seq(self.tokenizer)
         self.device = autocuda.auto_cuda()
         self.model.to(self.device)

@@ -94,7 +96,7 @@ def predict(self, text, **kwargs):
         ate_outputs = self.tokenizer.batch_decode(
             ate_outputs, skip_special_tokens=True
         )[0]
-        result["aspect"] = [asp.strip() for asp in ate_outputs.split(",")]
+        result["aspect"] = [asp.strip() for asp in ate_outputs.split("|")]

         # APC inference
         inputs = self.tokenizer(

@@ -106,7 +108,7 @@ def predict(self, text, **kwargs):
         apc_outputs = self.tokenizer.batch_decode(
             apc_outputs, skip_special_tokens=True
         )[0]
-        result["sentiment"] = [sent.strip() for sent in apc_outputs.split(",")]
+        result["sentiment"] = [sent.strip() for sent in apc_outputs.split("|")]

         # Opinion inference
         inputs = self.tokenizer(

@@ -118,7 +120,7 @@ def predict(self, text, **kwargs):
         op_outputs = self.tokenizer.batch_decode(op_outputs, skip_special_tokens=True)[
             0
         ]
-        result["opinion"] = [op.strip() for op in op_outputs.split(",")]
+        result["opinion"] = [op.strip() for op in op_outputs.split("|")]

         # Category inference
         inputs = self.tokenizer(

@@ -130,7 +132,7 @@ def predict(self, text, **kwargs):
         cat_outputs = self.tokenizer.batch_decode(
             cat_outputs, skip_special_tokens=True
         )[0]
-        result["category"] = [cat.strip() for cat in cat_outputs.split(",")]
+        result["category"] = [cat.strip() for cat in cat_outputs.split("|")]
         ensemble_result = {
             "text": text,
             "Quadruples": [

@@ -207,26 +209,43 @@ def get_aspect_metrics(self, true_aspects, pred_aspects):
         return aspect_p, aspect_r, aspect_f1

     def get_classic_metrics(self, y_true, y_pred):
-        total_pred = 0
-        total_gt = 0
-        tp = 1e-6
+        valid_gts = []
+        valid_preds = []
         for gt, pred in zip(y_true, y_pred):
-            print(gt)
-            print(pred)
-
-            gt_list = gt.split(", ")
-            pred_list = pred.split(", ")
-            total_pred += len(pred_list)
-            total_gt += len(gt_list)
-            for gt_val in gt_list:
+            gt_list = gt.split("|")
+            pred_list = pred.split("|")
+            while gt_list:
+                gt_val = gt_list[-1].strip().lower()
                 for pred_val in pred_list:
-                    gt_val = gt_val.replace(" ", "")
-                    pred_val = pred_val.replace(" ", "")
-                    if pred_val.strip().lower() == gt_val.strip().lower():
-                        tp += 1
-        p = tp / total_pred
-        r = tp / total_gt
-        return {"precision": p, "recall": r, "f1": 2 * p * r / (p + r)}
+                    pred_val = pred_val.strip().lower()
+                    gt_key, _, gt_label = gt_val.partition(":")
+                    pred_key, _, pred_label = pred_val.partition(":")
+                    if gt_key.startswith(pred_key):
+                        if gt_label:
+                            valid_gts.append(gt_label)
+                        else:
+                            break
+                        if pred_label:
+                            valid_preds.append(pred_label)
+                        else:
+                            valid_preds.append("")
+                        break
+
+                gt_list.pop()
+
+        report = sklearn.metrics.classification_report(valid_gts, valid_preds)
+        print(report)
+        accuracy = sklearn.metrics.accuracy_score(valid_gts, valid_preds)
+        precision = precision_score(valid_gts, valid_preds, average="macro")
+        recall = recall_score(valid_gts, valid_preds, average="macro")
+        f1 = f1_score(valid_gts, valid_preds, average="macro")
+
+        return {
+            "accuracy": accuracy,
+            "precision": precision,
+            "recall": recall,
+            "f1": f1,
+        }

     # def get_classic_metrics(self, y_true, y_pred):
     #
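
A simplified, illustrative version of the rewritten metric path: gold and predicted "|"-delimited entries are matched by their aspect key, and the label parts are then scored with scikit-learn (toy strings, not the exact library loop):

from sklearn.metrics import accuracy_score, f1_score

gt = "food:positive|menu:negative|service:positive"
pred = "food:positive|menu:positive|service:negative"

valid_gts, valid_preds = [], []
for gt_val in gt.split("|"):
    gt_key, _, gt_label = gt_val.strip().lower().partition(":")
    for pred_val in pred.split("|"):
        pred_key, _, pred_label = pred_val.strip().lower().partition(":")
        if gt_key.startswith(pred_key):
            valid_gts.append(gt_label)
            valid_preds.append(pred_label or "")
            break

print(accuracy_score(valid_gts, valid_preds))             # 0.333...
print(f1_score(valid_gts, valid_preds, average="macro"))  # 0.25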
