Added Dependency Tree-LSTM for sentiment.
kaishengtai committed May 30, 2015
1 parent b451977 commit ba11ff0
Showing 11 changed files with 256 additions and 80 deletions.
19 changes: 13 additions & 6 deletions init.lua
@@ -28,13 +28,20 @@ treelstm.data_dir = 'data'
treelstm.models_dir = 'trained_models'
treelstm.predictions_dir = 'predictions'

-- share parameters of nngraph gModule instances
function share_params(cell, src, ...)
for i = 1, #cell.forwardnodes do
local node = cell.forwardnodes[i]
if node.data.module then
node.data.module:share(src.forwardnodes[i].data.module, ...)
-- share module parameters
function share_params(cell, src)
if torch.type(cell) == 'nn.gModule' then
for i = 1, #cell.forwardnodes do
local node = cell.forwardnodes[i]
if node.data.module then
node.data.module:share(src.forwardnodes[i].data.module,
'weight', 'bias', 'gradWeight', 'gradBias')
end
end
elseif torch.isTypeOf(cell, 'nn.Module') then
cell:share(src, 'weight', 'bias', 'gradWeight', 'gradBias')
else
error('parameters cannot be shared for this input')
end
end

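For reference, a minimal usage sketch of the generalized helper, assuming init.lua has been loaded so that share_params is in scope; the Linear layers and the small graph below are hypothetical and not part of this commit:

require('nn')
require('nngraph')

-- plain nn.Module branch: ties dst's weights, biases and gradient buffers to src's
local src = nn.Linear(10, 10)
local dst = nn.Linear(10, 10)
share_params(dst, src)

-- nn.gModule branch: parameters are shared node-by-node over forwardnodes,
-- so both graphs must be built with the same structure
local function make_graph()
  local x = nn.Identity()()
  local h = nn.Tanh()(nn.Linear(10, 10)(x))
  return nn.gModule({x}, {h})
end
local g_src = make_graph()
local g_dst = make_graph()
share_params(g_dst, g_src)
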
17 changes: 5 additions & 12 deletions models/BinaryTreeLSTM.lua
@@ -8,20 +8,13 @@ local BinaryTreeLSTM, parent = torch.class('treelstm.BinaryTreeLSTM', 'treelstm.

function BinaryTreeLSTM:__init(config)
parent.__init(self, config)
self.in_dim = config.in_dim
if self.in_dim == nil then error('input dimension must be specified') end
self.mem_dim = config.mem_dim or 150
self.gate_output = config.gate_output
if self.gate_output == nil then self.gate_output = true end

-- a function that instantiates an output module that takes the hidden state
-- h as input
-- a function that instantiates an output module that takes the hidden state h as input
self.output_module_fn = config.output_module_fn
self.criterion = config.criterion

-- zero vectors for null inputs
self.mem_zeros = torch.zeros(self.mem_dim)

-- leaf input module
self.leaf_module = self:new_leaf_module()
self.leaf_modules = {}
@@ -48,7 +41,7 @@ function BinaryTreeLSTM:new_leaf_module()

local leaf_module = nn.gModule({input}, {c, h})
if self.leaf_module ~= nil then
share_params(leaf_module, self.leaf_module, 'weight', 'bias', 'gradWeight', 'gradBias')
share_params(leaf_module, self.leaf_module)
end
return leaf_module
end
@@ -85,7 +78,7 @@ function BinaryTreeLSTM:new_composer()
{c, h})

if self.composer ~= nil then
share_params(composer, self.composer, 'weight', 'bias', 'gradWeight', 'gradBias')
share_params(composer, self.composer)
end
return composer
end
@@ -94,7 +87,7 @@ function BinaryTreeLSTM:new_output_module()
if self.output_module_fn == nil then return nil end
local output_module = self.output_module_fn()
if self.output_module ~= nil then
output_module:share(self.output_module, 'weight', 'bias', 'gradWeight', 'gradBias')
share_params(output_module, self.output_module)
end
return output_module
end
@@ -137,7 +130,7 @@ end

function BinaryTreeLSTM:_backward(tree, inputs, grad, grad_inputs)
local output_grad = self.mem_zeros
if tree.output ~= nil then
if tree.output ~= nil and tree.gold_label ~= nil then
output_grad = tree.output_module:backward(
tree.state[2], self.criterion:backward(tree.output, tree.gold_label))
end
59 changes: 54 additions & 5 deletions models/ChildSumTreeLSTM.lua
@@ -11,9 +11,17 @@ function ChildSumTreeLSTM:__init(config)
self.gate_output = config.gate_output
if self.gate_output == nil then self.gate_output = true end

-- a function that instantiates an output module that takes the hidden state h as input
self.output_module_fn = config.output_module_fn
self.criterion = config.criterion

-- composition module
self.composer = self:new_composer()
self.composers = {}

-- output module
self.output_module = self:new_output_module()
self.output_modules = {}
end

function ChildSumTreeLSTM:new_composer()
@@ -56,19 +64,38 @@ function ChildSumTreeLSTM:new_composer()

local composer = nn.gModule({input, child_c, child_h}, {c, h})
if self.composer ~= nil then
share_params(composer, self.composer, 'weight', 'bias', 'gradWeight', 'gradBias')
share_params(composer, self.composer)
end
return composer
end

function ChildSumTreeLSTM:new_output_module()
if self.output_module_fn == nil then return nil end
local output_module = self.output_module_fn()
if self.output_module ~= nil then
share_params(output_module, self.output_module)
end
return output_module
end

function ChildSumTreeLSTM:forward(tree, inputs)
local loss = 0
for i = 1, tree.num_children do
self:forward(tree.children[i], inputs)
local _, child_loss = self:forward(tree.children[i], inputs)
loss = loss + child_loss
end
local child_c, child_h = self:get_child_states(tree)
self:allocate_module(tree, 'composer')
tree.state = tree.composer:forward{inputs[tree.idx], child_c, child_h}
return tree.state

if self.output_module ~= nil then
self:allocate_module(tree, 'output_module')
tree.output = tree.output_module:forward(tree.state[2])
if self.train and tree.gold_label ~= nil then
loss = loss + self.criterion:forward(tree.output, tree.gold_label)
end
end
return tree.state, loss
end

function ChildSumTreeLSTM:backward(tree, inputs, grad)
@@ -78,10 +105,21 @@ function ChildSumTreeLSTM:backward(tree, inputs, grad)
end

function ChildSumTreeLSTM:_backward(tree, inputs, grad, grad_inputs)
local output_grad = self.mem_zeros
if tree.output ~= nil and tree.gold_label ~= nil then
output_grad = tree.output_module:backward(
tree.state[2], self.criterion:backward(tree.output, tree.gold_label))
end
self:free_module(tree, 'output_module')
tree.output = nil

local child_c, child_h = self:get_child_states(tree)
local composer_grad = tree.composer:backward({inputs[tree.idx], child_c, child_h}, grad)
local composer_grad = tree.composer:backward(
{inputs[tree.idx], child_c, child_h},
{grad[1], grad[2] + output_grad})
self:free_module(tree, 'composer')
tree.state = nil

grad_inputs[tree.idx] = composer_grad[1]
local child_c_grads, child_h_grads = composer_grad[2], composer_grad[3]
for i = 1, tree.num_children do
@@ -91,14 +129,25 @@ end

function ChildSumTreeLSTM:clean(tree)
self:free_module(tree, 'composer')
self:free_module(tree, 'output_module')
tree.state = nil
tree.output = nil
for i = 1, tree.num_children do
self:clean(tree.children[i])
end
end

function ChildSumTreeLSTM:parameters()
return self.composer:parameters()
local params, grad_params = {}, {}
local cp, cg = self.composer:parameters()
tablex.insertvalues(params, cp)
tablex.insertvalues(grad_params, cg)
if self.output_module ~= nil then
local op, og = self.output_module:parameters()
tablex.insertvalues(params, op)
tablex.insertvalues(grad_params, og)
end
return params, grad_params
end

function ChildSumTreeLSTM:get_child_states(tree)
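To illustrate how the new output-module hooks could be wired up for sentiment, here is a minimal sketch; the embedding size, the five-class Linear + LogSoftMax classifier and the ClassNLLCriterion are assumptions for the example, not taken from this diff:

local num_classes = 5            -- hypothetical fine-grained sentiment labels
local mem_dim = 150

local treelstm_config = {
  in_dim    = 300,               -- hypothetical word embedding dimension
  mem_dim   = mem_dim,
  criterion = nn.ClassNLLCriterion(),
  output_module_fn = function()
    return nn.Sequential()
      :add(nn.Linear(mem_dim, num_classes))
      :add(nn.LogSoftMax())
  end,
}
local model = treelstm.ChildSumTreeLSTM(treelstm_config)

-- with gold_label set on labeled tree nodes, forward now also returns a loss:
-- local state, loss = model:forward(tree, inputs)
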
2 changes: 1 addition & 1 deletion models/LSTM.lua
@@ -94,7 +94,7 @@ function LSTM:new_cell()

-- share parameters
if self.master_cell then
share_params(cell, self.master_cell, 'weight', 'bias', 'gradWeight', 'gradBias')
share_params(cell, self.master_cell)
end
return cell
end
6 changes: 3 additions & 3 deletions relatedness/LSTMSim.lua
@@ -57,11 +57,11 @@ function LSTMSim:__init(config)

-- share must only be called after getParameters, since this changes the
-- location of the parameters
self.rlstm:share(self.llstm, 'weight', 'bias', 'gradWeight', 'gradBias')
share_params(self.rlstm, self.llstm)
if self.structure == 'bilstm' then
-- tying the forward and backward weights improves performance
self.llstm_b:share(self.llstm, 'weight', 'bias', 'gradWeight', 'gradBias')
self.rlstm_b:share(self.llstm, 'weight', 'bias', 'gradWeight', 'gradBias')
share_params(self.llstm_b, self.llstm)
share_params(self.rlstm_b, self.llstm)
end
end

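The ordering constraint noted in the comment above ("share must only be called after getParameters") can be sketched with a hypothetical pair of Linear modules standing in for the two LSTMs:

local llstm = nn.Linear(150, 150)
local rlstm = nn.Linear(150, 150)

-- getParameters() moves llstm's weights into a single flat storage, relocating them...
local params, grad_params = llstm:getParameters()

-- ...so share only afterwards: had rlstm been shared first, the relocation done
-- by getParameters() would leave it pointing at the old, abandoned storages
share_params(rlstm, llstm)
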
3 changes: 2 additions & 1 deletion relatedness/TreeLSTMSim.lua
@@ -12,7 +12,7 @@ function TreeLSTMSim:__init(config)
self.emb_learning_rate = config.emb_learning_rate or 0.0
self.batch_size = config.batch_size or 25
self.reg = config.reg or 1e-4
self.structure = config.structure or 'dependency'
self.structure = config.structure or 'dependency' -- {dependency, constituency}
self.sim_nhidden = config.sim_nhidden or 50

-- word embedding
@@ -35,6 +35,7 @@ function TreeLSTMSim:__init(config)
mem_dim = self.mem_dim,
gate_output = false,
}

if self.structure == 'dependency' then
self.treelstm = treelstm.ChildSumTreeLSTM(treelstm_config)
elseif self.structure == 'constituency' then