model.py
import torch
import torch.nn as nn
from transformers import RobertaPreTrainedModel
from torch.cuda.amp import autocast

from roberta import RobertaModel


class REModel(RobertaPreTrainedModel):
    """Relation extraction model: classifies the relation between a subject
    and an object entity from their start-token embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config=config)
        hidden_size = config.hidden_size
        self.loss_fnt = nn.CrossEntropyLoss()
        # Classification head over the concatenated subject/object embeddings.
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(p=config.dropout_prob),
            nn.Linear(hidden_size, config.num_class),
        )

    @autocast()
    def forward(self, input_ids=None, attention_mask=None, labels=None,
                ss=None, os=None, entity_mask=None):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            entity_mask=entity_mask,
        )
        pooled_output = outputs[0]  # last hidden states: (batch, seq_len, hidden)
        # Gather the hidden state at the subject (ss) and object (os) start positions.
        idx = torch.arange(input_ids.size(0)).to(input_ids.device)
        ss_emb = pooled_output[idx, ss]
        os_emb = pooled_output[idx, os]
        h = torch.cat((ss_emb, os_emb), dim=-1)
        logits = self.classifier(h)
        outputs = (logits,)
        if labels is not None:
            # Compute the loss in fp32 for numerical stability under autocast.
            loss = self.loss_fnt(logits.float(), labels)
            outputs = (loss,) + outputs
        return outputs
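
# For context, a minimal usage sketch. This is not part of the repository:
# the config attributes (dropout_prob, num_class), the token indices, the
# label value, and the assumption that the local roberta.RobertaModel
# tolerates entity_mask=None are all hypothetical, chosen for illustration.
#
#     import torch
#     from transformers import AutoConfig, AutoTokenizer
#
#     config = AutoConfig.from_pretrained("roberta-base")
#     config.dropout_prob = 0.1  # hypothetical value
#     config.num_class = 42      # hypothetical number of relation classes
#
#     model = REModel(config).cuda()
#     tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#
#     enc = tokenizer("Barack Obama was born in Hawaii .", return_tensors="pt")
#     input_ids = enc["input_ids"].cuda()
#     attention_mask = enc["attention_mask"].cuda()
#     ss = torch.tensor([1]).cuda()      # subject start index (hypothetical)
#     os = torch.tensor([6]).cuda()      # object start index (hypothetical)
#     labels = torch.tensor([3]).cuda()  # gold relation id (hypothetical)
#
#     loss, logits = model(input_ids, attention_mask=attention_mask,
#                          labels=labels, ss=ss, os=os)
#
# With labels the model returns (loss, logits); without, it returns (logits,).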