# 13_1_rnn_classification_basics.py
# (forked from hunkim/PyTorchZeroToAll)
# Original code is from https://github.com/spro/practical-pytorch
import time
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from name_dataset import NameDataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
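# Note: NameDataset, DataLoader, and the pack/pad utilities above are not
# actually used in this basic script; batching is done by hand with
# pad_sequences() below.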
# Parameters
HIDDEN_SIZE = 100
N_CHARS = 128 # ASCII
N_CLASSES = 18
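# 18 output classes, matching the number of name-origin categories used in the
# original practical-pytorch name dataset (only the output size matters for
# this shape demo).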
class RNNClassifier(nn.Module):

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNNClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        # Note: we run this all at once (over the whole input sequence).

        # input: B x S, so size(0) = B
        batch_size = input.size(0)

        # input: B x S -- (transpose) --> S x B
        input = input.t()

        # Embedding: S x B -> S x B x I (embedding size)
        print(" input", input.size())
        embedded = self.embedding(input)
        print(" embedding", embedded.size())

        # Make an initial hidden state
        hidden = self._init_hidden(batch_size)

        output, hidden = self.gru(embedded, hidden)
        print(" gru hidden output", hidden.size())

        # Use the last layer's final hidden state as the FC input.
        # No need to unpack, since we use hidden rather than output.
        fc_output = self.fc(hidden[-1])
        print(" fc output", fc_output.size())
        return fc_output

    def _init_hidden(self, batch_size):
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_size)
        return Variable(hidden)
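
# Sketch (not used in this script): with the packing utilities imported above,
# the forward pass could instead run the GRU only over the real characters of
# each padded sequence. Roughly, assuming seq_lengths is available and sorted
# in decreasing order:
#
#     embedded = self.embedding(input)                  # S x B x hidden_size
#     packed = pack_padded_sequence(embedded, seq_lengths)
#     output, hidden = self.gru(packed, self._init_hidden(batch_size))
#     fc_output = self.fc(hidden[-1])
#
# With packing, hidden[-1] corresponds to each name's last real character
# rather than to a padded position.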

# Helper functions
def str2ascii_arr(msg):
    arr = [ord(c) for c in msg]
    return arr, len(arr)
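# Example: str2ascii_arr("san") returns ([115, 97, 110], 3).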

# Pad sequences to the length of the longest one in the batch
def pad_sequences(vectorized_seqs, seq_lengths):
    seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long()
    for idx, (seq, seq_len) in enumerate(zip(vectorized_seqs, seq_lengths)):
        seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
    return seq_tensor
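# Example: the four demo names below (lengths 6, 5, 4, 3) pad to a 4 x 6
# LongTensor, with trailing zeros for the shorter names.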

# Create padded input Variables from a list of name strings
def make_variables(names):
    sequence_and_length = [str2ascii_arr(name) for name in names]
    vectorized_seqs = [sl[0] for sl in sequence_and_length]
    seq_lengths = torch.LongTensor([sl[1] for sl in sequence_and_length])
    return pad_sequences(vectorized_seqs, seq_lengths)

if __name__ == '__main__':
    names = ['adylov', 'solan', 'hard', 'san']
    classifier = RNNClassifier(N_CHARS, HIDDEN_SIZE, N_CLASSES)

    # Run each name through the classifier one at a time
    for name in names:
        arr, _ = str2ascii_arr(name)
        inp = Variable(torch.LongTensor([arr]))
        out = classifier(inp)
        print("in", inp.size(), "out", out.size())

    # Run all four names as one padded batch
    inputs = make_variables(names)
    out = classifier(inputs)
    print("batch in", inputs.size(), "batch out", out.size())