This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Ordered Neuron LSTM #854

Closed
wants to merge 1 commit into from
10 changes: 6 additions & 4 deletions demo/atis_joint_model/atis_joint_config.json
@@ -6,10 +6,12 @@
   "representation": {
     "BiLSTMDocSlotAttention": {
       "lstm": {
-        "dropout": 0.5,
-        "lstm_dim": 366,
-        "num_layers": 2,
-        "bidirectional": true
+        "BiLSTM": {
+          "dropout": 0.5,
+          "lstm_dim": 366,
+          "num_layers": 2,
+          "bidirectional": true
+        }
       },
       "pooling": {
         "SelfAttention": {
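To opt into the new representation from the same config, the union member would be keyed by its class name, mirroring the "BiLSTM" wrapper above. A minimal sketch (assuming class-name keying works for OrderedNeuronLSTM exactly as it does for BiLSTM; the fields come from OrderedNeuronLSTM.Config in this PR, which has no "bidirectional" option):

"lstm": {
  "OrderedNeuronLSTM": {
    "dropout": 0.5,
    "lstm_dim": 366,
    "num_layers": 2
  }
}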
6 changes: 4 additions & 2 deletions pytext/exporters/test/text_model_exporter_test.py
@@ -43,8 +43,10 @@
   "representation": {
     "BiLSTMDocSlotAttention": {
       "lstm": {
-        "lstm_dim": 30,
-        "num_layers": 1
+        "BiLSTM": {
+          "lstm_dim": 30,
+          "num_layers": 1
+        }
       },
       "pooling": {
         "SelfAttention": {
3 changes: 2 additions & 1 deletion pytext/models/representations/bilstm_doc_slot_attention.py
@@ -9,6 +9,7 @@
 from pytext.models.module import create_module

 from .bilstm import BiLSTM
+from .ordered_neuron_lstm import OrderedNeuronLSTM
 from .pooling import MaxPool, MeanPool, SelfAttention
 from .representation_base import RepresentationBase
 from .slot_attention import SlotAttention
@@ -49,7 +50,7 @@ class BiLSTMDocSlotAttention(RepresentationBase):

     class Config(RepresentationBase.Config, ConfigBase):
         dropout: float = 0.4
-        lstm: BiLSTM.Config = BiLSTM.Config()
+        lstm: Union[BiLSTM.Config, OrderedNeuronLSTM.Config] = BiLSTM.Config()
         pooling: Optional[
             Union[SelfAttention.Config, MaxPool.Config, MeanPool.Config]
         ] = None
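Because the lstm field is now a Union, the ON-LSTM encoder can also be selected when building the config in Python. A minimal sketch, assuming the Config classes can be default-constructed and have their fields assigned (the same pattern the unit test below uses for OrderedNeuronLSTM.Config):

from pytext.models.representations.bilstm_doc_slot_attention import (
    BiLSTMDocSlotAttention,
)
from pytext.models.representations.ordered_neuron_lstm import OrderedNeuronLSTM

# Swap the default BiLSTM sub-config for the new OrderedNeuronLSTM one
lstm_config = OrderedNeuronLSTM.Config()
lstm_config.lstm_dim = 366
lstm_config.num_layers = 2

config = BiLSTMDocSlotAttention.Config()
config.lstm = lstm_config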
176 changes: 176 additions & 0 deletions pytext/models/representations/ordered_neuron_lstm.py
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytext.config import ConfigBase
from pytext.models.module import Module
from pytext.utils import cuda

from .representation_base import RepresentationBase


# A single layer of an Ordered Neuron LSTM (ON-LSTM; Shen et al.,
# "Ordered Neurons: Integrating Tree Structures into Recurrent Neural
# Networks", ICLR 2019)
class OrderedNeuronLSTMLayer(Module):
def __init__(
self, embed_dim: int, lstm_dim: int, padding_value: float, dropout: float
) -> None:
super().__init__()
self.lstm_dim = lstm_dim
self.padding_value = padding_value
self.dropout = nn.Dropout(dropout)

total_size = embed_dim + lstm_dim
self.f_gate = nn.Linear(total_size, lstm_dim)
self.i_gate = nn.Linear(total_size, lstm_dim)
self.o_gate = nn.Linear(total_size, lstm_dim)
self.c_hat_gate = nn.Linear(total_size, lstm_dim)
self.master_forget_no_cumax_gate = nn.Linear(total_size, lstm_dim)
self.master_input_no_cumax_gate = nn.Linear(total_size, lstm_dim)

# embedded_tokens has shape (seq length, batch size, embed size)
# states = (hidden, context), where both hidden and context have
# shape (batch size, hidden size)
def forward(
self,
embedded_tokens: torch.Tensor,
states: Tuple[torch.Tensor, torch.Tensor],
seq_lengths: List[int],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
hidden, context = states
batch_size = hidden.size(0)
all_context = []
all_hidden = []

if self.dropout.p > 0.0:
embedded_tokens = self.dropout(embedded_tokens)

        # Iterate over time steps; each step_input has shape
        # (batch size, embed size)
        for step_input in embedded_tokens:
            # Compute the standard LSTM gates
            combined = torch.cat((step_input, hidden), 1)
ft = self.f_gate(combined).sigmoid()
it = self.i_gate(combined).sigmoid()
ot = self.o_gate(combined).sigmoid()
c_hat = self.c_hat_gate(combined).tanh()

            # Compute the master gates via cumax(x) = cumsum(softmax(x)), which
            # produces monotonically non-decreasing activations in [0, 1]
            master_forget_no_cumax = self.master_forget_no_cumax_gate(combined)
            master_forget = torch.cumsum(
                F.softmax(master_forget_no_cumax, dim=1), dim=1
            )
            master_input_no_cumax = self.master_input_no_cumax_gate(combined)
            master_input = torch.cumsum(F.softmax(master_input_no_cumax, dim=1), dim=1)

# Combine master gates with normal LSTM gates
wt = master_forget * master_input
f_hat_t = ft * wt + (master_forget - wt)
i_hat_t = it * wt + (master_input - wt)

# Compute new context and hidden using final combined gates
context = f_hat_t * context + i_hat_t * c_hat
hidden = ot * context
all_context.append(context)
all_hidden.append(hidden)

# Compute what the final state (hidden and context for each element
# in the batch) should be based on seq_lengths
state_hidden = []
state_context = []

for i in range(batch_size):
seq_length = seq_lengths[i]
state_hidden.append(all_hidden[seq_length - 1][i])
state_context.append(all_context[seq_length - 1][i])

# Return hidden states across all time steps, and return a tuple
# containing the hidden and context for the last time step (might
# be different based on seq_lengths)
return (
torch.stack(all_hidden),
(torch.stack(state_hidden), torch.stack(state_context)),
)


# Ordered Neuron LSTM with any number of layers
class OrderedNeuronLSTM(RepresentationBase):
class Config(RepresentationBase.Config, ConfigBase):
dropout: float = 0.4
lstm_dim: int = 32
num_layers: int = 1

def __init__(
self, config: Config, embed_dim: int, padding_value: Optional[float] = 0.0
) -> None:
super().__init__(config)
self.representation_dim = config.lstm_dim
self.padding_value = padding_value
lstms = []
sizes = [embed_dim] + ([config.lstm_dim] * config.num_layers)

        # Create an OrderedNeuronLSTMLayer for each consecutive pair of sizes
        # and chain them together in an nn.ModuleList
for i in range(len(sizes) - 1):
lstm = OrderedNeuronLSTMLayer(
sizes[i], sizes[i + 1], padding_value, config.dropout
)
lstms.append(lstm)

self.lstms = nn.ModuleList(lstms)

    # rep has shape (batch size, seq length, embed dim)
    # seq_lengths holds the sequence length of each example in the batch and is
    # used to pick the last valid hidden and context per example
    # states is an optional tuple of initial (hidden, context), each with shape
    # (batch size, num layers, lstm dim)
def forward(
self,
rep: torch.Tensor,
seq_lengths: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if states is not None:
# Transpose states so hidden and context both have shape
# (num layers, batch size, lstm dim)
states = (
states[0].transpose(0, 1).contiguous(),
states[1].transpose(0, 1).contiguous(),
)
else:
# state has shape (num layers, batch size, lstm dim)
state = torch.zeros(
self.config.num_layers,
rep.size(0),
self.config.lstm_dim,
device=torch.cuda.current_device() if cuda.CUDA_ENABLED else None,
)

states = (state, state)

# hidden_by_layer is a list of hidden states for each layer of the
# network, and similarly for context_by_layer
hidden_by_layer, context_by_layer = states

# Collect the last hidden and context for each layer
last_hidden_by_layer = []
last_context_by_layer = []
rep = rep.transpose(0, 1).contiguous()

for lstm, hidden, context in zip(self.lstms, hidden_by_layer, context_by_layer):
state = (hidden, context)

            # Feed this layer's full hidden sequence into the next layer, and
            # record this layer's final hidden and context for the return value
            rep, (last_hidden, last_context) = lstm(rep, state, seq_lengths)
last_hidden_by_layer.append(last_hidden)
last_context_by_layer.append(last_context)

        # Make rep have shape (batch size, seq length, lstm dim)
rep = rep.transpose(0, 1).contiguous()

# Make last_hidden and last_context have shape
# (batch size, num layers, hidden size)
last_hidden = torch.stack(last_hidden_by_layer).transpose(0, 1)
last_context = torch.stack(last_context_by_layer).transpose(0, 1)
return rep, (last_hidden, last_context)
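For reference, a minimal standalone usage sketch of the new representation (shapes follow the comments in forward above; the config is built the same way the unit test below builds it, and the module is constructed directly rather than through pytext's config-driven module creation):

import torch
from pytext.models.representations.ordered_neuron_lstm import OrderedNeuronLSTM

config = OrderedNeuronLSTM.Config()
config.dropout = 0.0
config.lstm_dim = 32
config.num_layers = 2

embed_dim = 64
onlstm = OrderedNeuronLSTM(config, embed_dim)

batch_size, seq_len = 4, 10
rep = torch.randn(batch_size, seq_len, embed_dim)
seq_lengths = torch.tensor([seq_len] * batch_size)

outputs, (last_hidden, last_context) = onlstm(rep, seq_lengths)
# outputs:      (batch_size, seq_len, config.lstm_dim)
# last_hidden:  (batch_size, config.num_layers, config.lstm_dim)
# last_context: (batch_size, config.num_layers, config.lstm_dim)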
73 changes: 73 additions & 0 deletions pytext/models/representations/test/ordered_neuron_lstm_test.py
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest

import torch
from pytext.models.representations.ordered_neuron_lstm import OrderedNeuronLSTM


class OrderedNeuronLSTMTest(unittest.TestCase):
def _test_shape(self, dropout, num_layers):
config = OrderedNeuronLSTM.Config()
config.dropout = dropout
config.num_layers = num_layers

batch_size = 3
time_step = 17
input_size = 31
lstm = OrderedNeuronLSTM(config, input_size)

input_tensor = torch.randn(batch_size, time_step, input_size)
input_length = torch.zeros((batch_size,)).long()
input_states = (
torch.randn(batch_size, config.num_layers, config.lstm_dim),
torch.randn(batch_size, config.num_layers, config.lstm_dim),
)

for i in range(batch_size):
input_length[i] = time_step - i

for inp_state in [None, input_states]:
output, (hidden_state, cell_state) = lstm(
input_tensor, input_length, inp_state
)

            # Test shapes
            self.assertEqual(
                output.size(), (batch_size, time_step, config.lstm_dim)
            )
            self.assertEqual(
                hidden_state.size(), (batch_size, config.num_layers, config.lstm_dim)
            )
            self.assertEqual(
                cell_state.size(), (batch_size, config.num_layers, config.lstm_dim)
            )

            # Make sure gradients propagate without NaNs or infs
            output_agg = output.sum()
            output_agg.backward()
            for param in lstm.parameters():
                self.assertIsNotNone(param.grad)
                self.assertEqual(torch.isnan(param.grad).long().sum(), 0)
                self.assertEqual(torch.isinf(param.grad).long().sum(), 0)

            # Make sure dropout actually does something (the module is in
            # training mode, so dropout is active between forward passes)
            s_output, (s_hidden_state, s_cell_state) = lstm(
                input_tensor, input_length, inp_state
            )

            if config.dropout == 0.0:
                self.assertTrue(
                    torch.all(torch.lt(torch.abs(s_output - output), 1e-12))
                )
            else:
                self.assertFalse(torch.all(torch.eq(s_output, output)))

def test_ordered_neuron_lstm(self):
# test every configuration
for num_layers in [1, 2, 3]:
for dropout in [0.0, 0.5]:
self._test_shape(dropout, num_layers)


if __name__ == "__main__":
unittest.main()