Ordered Neuron LSTM (#854)
Summary:
Pull Request resolved: #854

Implement an Ordered Neuron LSTM; it currently supports multiple layers and dropout.

Reviewed By: anchit

Differential Revision: D16363259

fbshipit-source-id: d3c35393a7afa5a71520e7d255875dd3c5949c4c
Victor Ling authored and facebook-github-bot committed Jul 30, 2019
1 parent bc78c39 commit cddeb7a
Showing 5 changed files with 261 additions and 7 deletions.
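For orientation before the diffs, here is a minimal usage sketch of how the new representation is constructed and called directly. It is not taken from the commit; it only follows the constructor and forward signatures in the new module below, and the batch, sequence, and embedding sizes are arbitrary:

import torch
from pytext.models.representations.ordered_neuron_lstm import OrderedNeuronLSTM

# Configure the ON-LSTM: hidden size, number of stacked layers, dropout.
config = OrderedNeuronLSTM.Config()
config.lstm_dim = 32
config.num_layers = 2
config.dropout = 0.0

# embed_dim is the size of the incoming token embeddings.
lstm = OrderedNeuronLSTM(config, embed_dim=31)

# rep: (batch size, seq length, embed dim); seq_lengths: one length per example.
rep = torch.randn(3, 17, 31)
seq_lengths = torch.tensor([17, 16, 15])

# output: (batch size, seq length, lstm_dim)
# hidden and context: (batch size, num layers, lstm_dim)
output, (hidden, context) = lstm(rep, seq_lengths)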
10 changes: 6 additions & 4 deletions demo/atis_joint_model/atis_joint_config.json
@@ -6,10 +6,12 @@
"representation": {
"BiLSTMDocSlotAttention": {
"lstm": {
"dropout": 0.5,
"lstm_dim": 366,
"num_layers": 2,
"bidirectional": true
"BiLSTM": {
"dropout": 0.5,
"lstm_dim": 366,
"num_layers": 2,
"bidirectional": true
}
},
"pooling": {
"SelfAttention": {
6 changes: 4 additions & 2 deletions pytext/exporters/test/text_model_exporter_test.py
@@ -43,8 +43,10 @@
"representation": {
"BiLSTMDocSlotAttention": {
"lstm": {
"lstm_dim": 30,
"num_layers": 1
"BiLSTM": {
"lstm_dim": 30,
"num_layers": 1
}
},
"pooling": {
"SelfAttention": {
3 changes: 2 additions & 1 deletion pytext/models/representations/bilstm_doc_slot_attention.py
@@ -9,6 +9,7 @@
 from pytext.models.module import create_module
 
 from .bilstm import BiLSTM
+from .ordered_neuron_lstm import OrderedNeuronLSTM
 from .pooling import MaxPool, MeanPool, SelfAttention
 from .representation_base import RepresentationBase
 from .slot_attention import SlotAttention
@@ -49,7 +50,7 @@ class BiLSTMDocSlotAttention(RepresentationBase):

     class Config(RepresentationBase.Config, ConfigBase):
         dropout: float = 0.4
-        lstm: BiLSTM.Config = BiLSTM.Config()
+        lstm: Union[BiLSTM.Config, OrderedNeuronLSTM.Config] = BiLSTM.Config()
         pooling: Optional[
             Union[SelfAttention.Config, MaxPool.Config, MeanPool.Config]
         ] = None
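A hedged illustration of what the widened Union enables: assuming PyText's usual pattern where a Union-typed config member is selected by its sub-config class (mirroring the "BiLSTM" nesting in the JSON diffs above), the ON-LSTM can presumably be swapped in programmatically like this; the dimension values are only examples:

from pytext.models.representations.bilstm_doc_slot_attention import BiLSTMDocSlotAttention
from pytext.models.representations.ordered_neuron_lstm import OrderedNeuronLSTM

config = BiLSTMDocSlotAttention.Config()
# The lstm field now accepts either sub-config; pick the ON-LSTM variant.
config.lstm = OrderedNeuronLSTM.Config()
config.lstm.lstm_dim = 366
config.lstm.num_layers = 2

In a JSON config, the equivalent would presumably be an "OrderedNeuronLSTM" block in place of the "BiLSTM" block shown in the demo config above.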
176 changes: 176 additions & 0 deletions pytext/models/representations/ordered_neuron_lstm.py
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytext.config import ConfigBase
from pytext.models.module import Module
from pytext.utils import cuda

from .representation_base import RepresentationBase


# A single layer of an Ordered Neuron LSTM
class OrderedNeuronLSTMLayer(Module):
    def __init__(
        self, embed_dim: int, lstm_dim: int, padding_value: float, dropout: float
    ) -> None:
        super().__init__()
        self.lstm_dim = lstm_dim
        self.padding_value = padding_value
        self.dropout = nn.Dropout(dropout)

        total_size = embed_dim + lstm_dim
        self.f_gate = nn.Linear(total_size, lstm_dim)
        self.i_gate = nn.Linear(total_size, lstm_dim)
        self.o_gate = nn.Linear(total_size, lstm_dim)
        self.c_hat_gate = nn.Linear(total_size, lstm_dim)
        self.master_forget_no_cumax_gate = nn.Linear(total_size, lstm_dim)
        self.master_input_no_cumax_gate = nn.Linear(total_size, lstm_dim)

    # embedded_tokens has shape (seq length, batch size, embed size)
    # states = (hidden, context), where both hidden and context have
    # shape (batch size, hidden size)
    def forward(
        self,
        embedded_tokens: torch.Tensor,
        states: Tuple[torch.Tensor, torch.Tensor],
        seq_lengths: List[int],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        hidden, context = states
        batch_size = hidden.size(0)
        all_context = []
        all_hidden = []

        if self.dropout.p > 0.0:
            embedded_tokens = self.dropout(embedded_tokens)

        for batch in embedded_tokens:
            # Compute the normal LSTM gates
            combined = torch.cat((batch, hidden), 1)
            ft = self.f_gate(combined).sigmoid()
            it = self.i_gate(combined).sigmoid()
            ot = self.o_gate(combined).sigmoid()
            c_hat = self.c_hat_gate(combined).tanh()

            # Compute the master gates
            master_forget_no_cumax = self.master_forget_no_cumax_gate(combined)
            master_forget = torch.cumsum(
                F.softmax(master_forget_no_cumax, dim=1), dim=1
            )
            master_input_no_cumax = self.master_input_no_cumax_gate(combined)
            master_input = torch.cumsum(F.softmax(master_input_no_cumax, dim=1), dim=1)
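            # Running cumsum over a softmax yields a monotonically
            # non-decreasing vector in [0, 1] (the "cumax" activation), so
            # each master gate splits the hidden units into a low region and
            # a high region; this is what imposes an ordering on the neurons.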

            # Combine master gates with normal LSTM gates
            wt = master_forget * master_input
            f_hat_t = ft * wt + (master_forget - wt)
            i_hat_t = it * wt + (master_input - wt)

            # Compute new context and hidden using final combined gates
            context = f_hat_t * context + i_hat_t * c_hat
            hidden = ot * context
            all_context.append(context)
            all_hidden.append(hidden)

        # Compute what the final state (hidden and context for each element
        # in the batch) should be based on seq_lengths
        state_hidden = []
        state_context = []

        for i in range(batch_size):
            seq_length = seq_lengths[i]
            state_hidden.append(all_hidden[seq_length - 1][i])
            state_context.append(all_context[seq_length - 1][i])

        # Return hidden states across all time steps, and return a tuple
        # containing the hidden and context for the last time step (might
        # be different based on seq_lengths)
        return (
            torch.stack(all_hidden),
            (torch.stack(state_hidden), torch.stack(state_context)),
        )


# Ordered Neuron LSTM with any number of layers
class OrderedNeuronLSTM(RepresentationBase):
    class Config(RepresentationBase.Config, ConfigBase):
        dropout: float = 0.4
        lstm_dim: int = 32
        num_layers: int = 1

    def __init__(
        self, config: Config, embed_dim: int, padding_value: Optional[float] = 0.0
    ) -> None:
        super().__init__(config)
        self.representation_dim = config.lstm_dim
        self.padding_value = padding_value
        lstms = []
        sizes = [embed_dim] + ([config.lstm_dim] * config.num_layers)

        # Create an ONLstm for each hidden size, and chain them together
        # using lstms
        for i in range(len(sizes) - 1):
            lstm = OrderedNeuronLSTMLayer(
                sizes[i], sizes[i + 1], padding_value, config.dropout
            )
            lstms.append(lstm)

        self.lstms = nn.ModuleList(lstms)

    # rep has shape (batch size, seq length, embed dim)
    # seq_lengths has sequence lengths for each case in the batch, used to
    # pick the last hidden and context
    # states is a tuple for initial hidden and context
    def forward(
        self,
        rep: torch.Tensor,
        seq_lengths: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if states is not None:
            # Transpose states so hidden and context both have shape
            # (num layers, batch size, lstm dim)
            states = (
                states[0].transpose(0, 1).contiguous(),
                states[1].transpose(0, 1).contiguous(),
            )
        else:
            # state has shape (num layers, batch size, lstm dim)
            state = torch.zeros(
                self.config.num_layers,
                rep.size(0),
                self.config.lstm_dim,
                device=torch.cuda.current_device() if cuda.CUDA_ENABLED else None,
            )

            states = (state, state)

        # hidden_by_layer is a list of hidden states for each layer of the
        # network, and similarly for context_by_layer
        hidden_by_layer, context_by_layer = states

        # Collect the last hidden and context for each layer
        last_hidden_by_layer = []
        last_context_by_layer = []
        rep = rep.transpose(0, 1).contiguous()

        for lstm, hidden, context in zip(self.lstms, hidden_by_layer, context_by_layer):
            state = (hidden, context)

            # The layer's final state is recorded per layer (and returned
            # below) but is not used to initialize the next layer; only the
            # full sequence output (rep) is passed up the stack
            rep, (last_hidden, last_context) = lstm(rep, state, seq_lengths)
            last_hidden_by_layer.append(last_hidden)
            last_context_by_layer.append(last_context)

        # Make rep have shape (batch size, seq length, hidden size)
        rep = rep.transpose(0, 1).contiguous()

        # Make last_hidden and last_context have shape
        # (batch size, num layers, hidden size)
        last_hidden = torch.stack(last_hidden_by_layer).transpose(0, 1)
        last_context = torch.stack(last_context_by_layer).transpose(0, 1)
        return rep, (last_hidden, last_context)
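For reference, the per-time-step arithmetic in OrderedNeuronLSTMLayer.forward can be written out as follows. This is a summary of the code above, not taken from the commit message; biases of the linear layers are omitted, [x_t; h_{t-1}] denotes concatenation, \odot is elementwise multiplication, and cumax(z) = cumsum(softmax(z)) as computed with F.softmax and torch.cumsum:

\begin{aligned}
f_t &= \sigma(W_f [x_t; h_{t-1}]), \quad
i_t = \sigma(W_i [x_t; h_{t-1}]), \quad
o_t = \sigma(W_o [x_t; h_{t-1}]), \quad
\hat{c}_t = \tanh(W_c [x_t; h_{t-1}]) \\
\tilde{f}_t &= \operatorname{cumax}(W_{\tilde{f}} [x_t; h_{t-1}]), \qquad
\tilde{i}_t = \operatorname{cumax}(W_{\tilde{i}} [x_t; h_{t-1}]) \\
w_t &= \tilde{f}_t \odot \tilde{i}_t, \qquad
\hat{f}_t = f_t \odot w_t + (\tilde{f}_t - w_t), \qquad
\hat{i}_t = i_t \odot w_t + (\tilde{i}_t - w_t) \\
c_t &= \hat{f}_t \odot c_{t-1} + \hat{i}_t \odot \hat{c}_t, \qquad
h_t = o_t \odot c_t
\end{aligned}

Note that the hidden state here is o_t \odot c_t with no extra tanh, matching the code rather than some other formulations.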
73 changes: 73 additions & 0 deletions pytext/models/representations/test/ordered_neuron_lstm_test.py
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest

import torch
from pytext.models.representations.ordered_neuron_lstm import OrderedNeuronLSTM


class OrderedNeuronLSTMTest(unittest.TestCase):
    def _test_shape(self, dropout, num_layers):
        config = OrderedNeuronLSTM.Config()
        config.dropout = dropout
        config.num_layers = num_layers

        batch_size = 3
        time_step = 17
        input_size = 31
        lstm = OrderedNeuronLSTM(config, input_size)

        input_tensor = torch.randn(batch_size, time_step, input_size)
        input_length = torch.zeros((batch_size,)).long()
        input_states = (
            torch.randn(batch_size, config.num_layers, config.lstm_dim),
            torch.randn(batch_size, config.num_layers, config.lstm_dim),
        )

        for i in range(batch_size):
            input_length[i] = time_step - i

        for inp_state in [None, input_states]:
            output, (hidden_state, cell_state) = lstm(
                input_tensor, input_length, inp_state
            )

            # Test Shapes
            self.assertEqual(
                hidden_state.size(), (batch_size, config.num_layers, config.lstm_dim)
            )
            self.assertEqual(
                hidden_state.size(), (batch_size, config.num_layers, config.lstm_dim)
            )
            self.assertEqual(
                cell_state.size(), (batch_size, config.num_layers, config.lstm_dim)
            )

            # Make sure gradients propagate correctly
            output_agg = output.sum()
            output_agg.backward()
            for param in lstm.parameters():
                self.assertEqual(torch.isnan(param).long().sum(), 0)
                self.assertEqual(torch.isinf(param).long().sum(), 0)

            # Make sure dropout actually does something
            s_output, (s_hidden_state, s_cell_state) = lstm(
                input_tensor, input_length, inp_state
            )

            if config.dropout == 0.0:
                assert torch.all(
                    torch.lt(torch.abs(torch.add(s_output, -output)), 1e-12)
                )
            else:
                assert not torch.all(torch.eq(s_output, output))

    def test_ordered_neuron_lstm(self):
        # test every configuration
        for num_layers in [1, 2, 3]:
            for dropout in [0.0, 0.5]:
                self._test_shape(dropout, num_layers)


if __name__ == "__main__":
    unittest.main()
