from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, AutoModel

-from pyabsa.tasks.AspectPolarityClassification.models.__classic__ import GloVeAPCModelList
-from pyabsa.tasks.AspectPolarityClassification.models.__lcf__ import APCModelList
-from pyabsa.tasks.AspectPolarityClassification.models.__plm__ import BERTBaselineAPCModelList
-from pyabsa.tasks.AspectPolarityClassification.dataset_utils.__classic__.data_utils_for_training import GloVeABSADataset
-from pyabsa.tasks.AspectPolarityClassification.dataset_utils.__lcf__.data_utils_for_training import ABSADataset
-from pyabsa.tasks.AspectPolarityClassification.dataset_utils.__plm__.data_utils_for_training import BERTBaselineABSADataset
+from ..models.__classic__ import GloVeAPCModelList
+from ..models.__lcf__ import APCModelList
+from ..models.__plm__ import BERTBaselineAPCModelList
+from ..dataset_utils.__classic__.data_utils_for_training import GloVeABSADataset
+from ..dataset_utils.__lcf__.data_utils_for_training import ABSADataset
+from ..dataset_utils.__plm__.data_utils_for_training import BERTBaselineABSADataset

from pyabsa.framework.tokenizer_class.tokenizer_class import PretrainedTokenizer, Tokenizer, build_embedding_matrix

@@ -119,7 +119,6 @@ def __init__(self, config, load_dataset=True, **kwargs):
                self.valid_set = GloVeABSADataset(self.config, self.tokenizer, dataset_type='valid') if not self.valid_set else self.valid_set

                self.models.append(models[i](copy.deepcopy(self.embedding_matrix) if self.config.deep_ensemble else self.embedding_matrix, self.config))
-            self.config.tokenizer = self.tokenizer
            self.config.embedding_matrix = self.embedding_matrix

            if self.config.cache_dataset and not os.path.exists(cache_path) and not self.config.overwrite_cache:
@@ -137,6 +136,8 @@ def __init__(self, config, load_dataset=True, **kwargs):
                valid_sampler = SequentialSampler(self.valid_set if not self.valid_set else self.valid_set)
                self.valid_dataloader = DataLoader(self.valid_set, batch_size=self.config.batch_size, pin_memory=True, sampler=valid_sampler)

+        self.config.tokenizer = self.tokenizer
+
        self.dense = nn.Linear(config.output_dim * len(models), config.output_dim)

    def forward(self, inputs):