# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from paddle.metric import Metric, Accuracy
from paddlenlp.transformers import PPMiniLMForSequenceClassification, PPMiniLMTokenizer
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer

MODEL_CLASSES = {
    "ppminilm": (PPMiniLMForSequenceClassification, PPMiniLMTokenizer),
    "bert": (BertForSequenceClassification, BertTokenizer)
}

METRIC_CLASSES = {
    "afqmc": Accuracy,
    "tnews": Accuracy,
    "iflytek": Accuracy,
    "ocnli": Accuracy,
    "cmnli": Accuracy,
    "cluewsc2020": Accuracy,
    "csl": Accuracy,
}
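
# Usage sketch (an assumption, not part of the original script): a training
# entry point would presumably resolve the task metric and the model/tokenizer
# pair by name, along the lines of
#
#     metric = METRIC_CLASSES["afqmc"]()                      # paddle.metric.Accuracy
#     model_class, tokenizer_class = MODEL_CLASSES["ppminilm"]
#     tokenizer = tokenizer_class.from_pretrained("ppminilm-6l-768h")  # assumed checkpoint name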

def convert_example(example,
                    label_list,
                    tokenizer=None,
                    is_test=False,
                    max_seq_length=512,
                    **kwargs):
    """Convert a CLUE example into the features expected by the model."""
    if not is_test:
        # `label_list == None` is for regression tasks.
        label_dtype = "int64" if label_list else "float32"
        # Get the label.
        example['label'] = np.array(example["label"], dtype=label_dtype)
        label = example['label']
    # Convert raw text to features.
    if 'keyword' in example:  # CSL
        sentence1 = " ".join(example['keyword'])
        example = {
            'sentence1': sentence1,
            'sentence2': example['abst'],
            'label': example['label']
        }
    elif 'target' in example:  # CLUEWSC2020
        text = example['text']
        query = example['target']['span1_text']
        query_idx = example['target']['span1_index']
        pronoun = example['target']['span2_text']
        pronoun_idx = example['target']['span2_index']
        assert text[pronoun_idx:pronoun_idx + len(pronoun)] == pronoun, \
            "pronoun: {}".format(pronoun)
        assert text[query_idx:query_idx + len(query)] == query, \
            "query: {}".format(query)
        # Mark the query span with "_..._" and the pronoun span with "[...]".
        # The span that occurs earlier in the text is marked first; the later
        # span's indices are then shifted by the two characters already
        # inserted, hence the extra "+ 2" offsets below.
        text_list = list(text)
        if pronoun_idx > query_idx:
            text_list.insert(query_idx, "_")
            text_list.insert(query_idx + len(query) + 1, "_")
            text_list.insert(pronoun_idx + 2, "[")
            text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
        else:
            text_list.insert(pronoun_idx, "[")
            text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
            text_list.insert(query_idx + 2, "_")
            text_list.insert(query_idx + len(query) + 2 + 1, "_")
        text = "".join(text_list)
        example['sentence'] = text
    if tokenizer is None:
        return example
    if 'sentence' in example:
        example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
    elif 'sentence1' in example:
        example = tokenizer(
            example['sentence1'],
            text_pair=example['sentence2'],
            max_seq_len=max_seq_length)
    if not is_test:
        return example['input_ids'], example['token_type_ids'], label
    else:
        return example['input_ids'], example['token_type_ids']
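

if __name__ == "__main__":
    # A minimal sanity check, added here as a sketch and not part of the
    # original PaddleNLP script: with `tokenizer=None`, `convert_example`
    # only normalizes the raw fields and returns early, so it can be
    # exercised without downloading any pretrained weights. The record below
    # mimics a CSL example; the field values are invented.
    csl_example = {
        "keyword": ["深度学习", "预训练"],
        "abst": "本文研究预训练语言模型的压缩方法。",
        "label": 1,
    }
    features = convert_example(csl_example, label_list=[0, 1], tokenizer=None)
    # Expected shape of the output: {'sentence1': ..., 'sentence2': ..., 'label': array(1)}
    print(features)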