# model.py
import torch
import torch.nn as nn
class Embed(nn.Module):
    """Token embedding lookup with a padding index."""

    def __init__(self, in_dim, embed_dim, pad_idx):
        super().__init__()
        self.embed = nn.Embedding(in_dim, embed_dim, padding_idx=pad_idx)

    def forward(self, x):
        # x: [B, S] token ids -> [B, S, E]
        return self.embed(x)
class SelfAttention(nn.Module):
    """Additive attention pooling over the sequence dimension."""

    def __init__(self, h_dim):
        super().__init__()
        # Scores each timestep of the bidirectional encoder output
        # (feature size h_dim*2) and normalizes the scores over the sequence.
        self.fc = nn.Sequential(
            nn.Linear(h_dim * 2, h_dim),
            nn.Tanh(),
            nn.Linear(h_dim, 1),
            nn.Dropout(),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        # x: [B, S, H*2]
        A = self.fc(x)           # attention weights: [B, S, 1]
        x = x * A                # broadcasts over the feature dim: [B, S, H*2]
        x = torch.sum(x, dim=1)  # weighted sum over the sequence: [B, H*2]
        return x
class simpleNet(nn.Module):
    """Bidirectional GRU encoder with attention pooling and a linear head."""

    def __init__(self, embed_dim, h_dim, out_dim):
        super().__init__()
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=h_dim,
            num_layers=3,
            batch_first=True,
            bidirectional=True,
        )
        self.self_attention = SelfAttention(h_dim=h_dim)
        # The bidirectional GRU doubles the feature size, hence h_dim * 2.
        self.fc = nn.Linear(h_dim * 2, out_dim)
        self.fc2 = nn.Linear(embed_dim, out_dim)  # defined but unused in forward

    def forward(self, x):
        # x: [B, S, E] embedded tokens
        x, _ = self.gru(x)          # [B, S, H*2]
        x = self.self_attention(x)  # [B, H*2]
        x = self.fc(x)              # [B, O] logits
        return x
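

# Usage sketch (not part of the original file): wires Embed and simpleNet
# together on dummy data to show the expected tensor shapes. The vocab size,
# pad index, and dimensions below are illustrative assumptions, not values
# taken from the repository.
if __name__ == "__main__":
    vocab_size, embed_dim, h_dim, out_dim, pad_idx = 1000, 64, 128, 2, 0
    embed = Embed(vocab_size, embed_dim, pad_idx)
    model = simpleNet(embed_dim, h_dim, out_dim)
    tokens = torch.randint(0, vocab_size, (4, 10))  # [B=4, S=10] token ids
    logits = model(embed(tokens))                   # [B, O]
    print(logits.shape)  # torch.Size([4, 2])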