Added BRNN support #3

Closed · wants to merge 49 commits
Commits
5d9a7b4  Changed local/global functions to refer to self for inheritance (Apr 11, 2016)
dec350f  Added BLSTM based on RNN implementation (Apr 11, 2016)
9bfed85  added BLSTM to init (Apr 11, 2016)
fe29e8b  removed self call in makeContiguous (Apr 11, 2016)
bdc97fa  Added rnn test (for basic RELU) (Apr 13, 2016)
e9de79f  Removed c sum check (Apr 13, 2016)
25f5383  Removed print statements (Apr 13, 2016)
d115fac  Fixed call to input grads (Apr 13, 2016)
0df09da  Added test for all RNN types (Apr 13, 2016)
11fe833  Added description (Apr 13, 2016)
6406b2a  Changed tolerance to fix weight difference, small comment change (Apr 13, 2016)
94c1c22  Exposed rnn variable to make easier to replace (Apr 13, 2016)
18ce0fb  Added base test for BLSTM (Apr 13, 2016)
291bd50  Revert "Added description" (Apr 13, 2016)
b09c184  Fixed change of RNN module name (Apr 13, 2016)
12a7696  Added comment, removed BLSTM test (Apr 13, 2016)
1287250  Added numDirections param to RNN (Apr 13, 2016)
c98e24f  Added numDirections param instead of duplicated methods (Apr 13, 2016)
4d83082  Removed *2 on gradInput (Apr 13, 2016)
3c13e7e  Removed hardcoded 3 dimension from cudnn call (Apr 13, 2016)
e80fce1  Added BLSTM test (Apr 13, 2016)
54488bb  Fixed resize of hidden tensors (Apr 14, 2016)
9e58334  Added numDirections to assertions (Apr 14, 2016)
04feef3  Added assertion check to accGradParams (Apr 14, 2016)
9d44393  Added BRNN test (Apr 14, 2016)
1cdff21  Reverted change of class name (Apr 14, 2016)
d8d4f89  Added ReLU/tanh/LSTM/GRU bidirectional tests (Apr 15, 2016)
ea21e9d  Updated modules in tests (Apr 15, 2016)
31f5f17  Added GRU and LSTM (Apr 15, 2016)
242d613  Added GRU/LSTM to init (Apr 15, 2016)
170e90d  Removed direction set (Apr 15, 2016)
2b82dc0  Changed init (Apr 15, 2016)
589fdfd  Fixed module names (Apr 15, 2016)
04a1ca9  added batchfirst param (Apr 15, 2016)
e7ce48b  Put transpose at top of method call (Apr 15, 2016)
a6ab506  Added batchFirst to accGrad (Apr 15, 2016)
f40008b  Fixed transpose name (Apr 15, 2016)
ba105f0  Fixed transpose name for grads (Apr 15, 2016)
d9062ee  Added all transpose operations (Apr 15, 2016)
df632a9  Added batchFirst to test params (Apr 15, 2016)
e210208  Added small comment to clarify batchFirst (Apr 15, 2016)
c420df1  Added description of recurrent modules (Apr 15, 2016)
a334ae2  Added separate RNN modules (Apr 16, 2016)
079a11c  Added modules to init (Apr 16, 2016)
0a9f8ca  Added RNNReLU/Tanh description (Apr 16, 2016)
2cff977  Added batchFirst to params (Apr 16, 2016)
fc7e761  Changed module calls in tests (Apr 16, 2016)
50d4f7c  Use torch sum instead of manual loop (Apr 16, 2016)
b89f6c2  Put resizing of tensors in a better place (Apr 18, 2016)
Files changed
9 changes: 9 additions & 0 deletions BLSTM.lua
@@ -0,0 +1,9 @@
local BLSTM, parent = torch.class('cudnn.BLSTM', 'cudnn.RNN')

function BLSTM:__init(inputSize, hiddenSize, numLayers, batchFirst)
    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst)
    self.bidirectional = 'CUDNN_BIDIRECTIONAL'
    self.mode = 'CUDNN_LSTM'
    self.numDirections = 2
    self:reset()
end
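Since the bidirectional module concatenates both directions, its output feature dimension is twice hiddenSize. A minimal forward-pass sketch, assuming this branch is installed on a CUDA/cuDNN-enabled Torch setup (sizes are illustrative):

```lua
require 'cudnn'

-- Illustrative sizes only.
local seqLength, miniBatch = 5, 4
local inputSize, hiddenSize, numLayers = 10, 20, 2

local blstm = cudnn.BLSTM(inputSize, hiddenSize, numLayers)
local input = torch.CudaTensor(seqLength, miniBatch, inputSize):uniform()

local output = blstm:forward(input)
print(output:size())  -- seqLength x miniBatch x (2 * hiddenSize), i.e. 5 x 4 x 40
```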
7 changes: 7 additions & 0 deletions GRU.lua
@@ -0,0 +1,7 @@
local GRU, parent = torch.class('cudnn.GRU', 'cudnn.RNN')

function GRU:__init(inputSize, hiddenSize, numLayers)
    parent.__init(self,inputSize, hiddenSize, numLayers)
    self.mode = 'CUDNN_GRU'
    self:reset()
end
7 changes: 7 additions & 0 deletions LSTM.lua
@@ -0,0 +1,7 @@
local LSTM, parent = torch.class('cudnn.LSTM', 'cudnn.RNN')

function LSTM:__init(inputSize, hiddenSize, numLayers)
    parent.__init(self,inputSize, hiddenSize, numLayers)
    self.mode = 'CUDNN_LSTM'
    self:reset()
end
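The GRU/LSTM/BLSTM wrappers above are thin subclasses: they preset `mode` (and, for BLSTM, the direction fields) on the `cudnn.RNN` base class, then call `reset()` again so the flat weight tensor is re-sized for the chosen cell. A rough equivalence sketch, assuming this branch (sizes are made up):

```lua
-- Two roughly equivalent ways to build a 2-layer LSTM on this branch.
local lstm1 = cudnn.LSTM(10, 20, 2)

local lstm2 = cudnn.RNN(10, 20, 2)
lstm2.mode = 'CUDNN_LSTM'
lstm2:reset()   -- re-derives the flat weight size for the LSTM gates
```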
9 changes: 9 additions & 0 deletions README.md
@@ -37,6 +37,15 @@ cudnn.SpatialCrossEntropyCriterion() -- A spatial version of LogSoftMax +
cudnn.VolumetricConvolution(nInputPlane, nOutputPlane, kT, kW, kH, dT, dW, dH, padT, padW, padH)
cudnn.VolumetricMaxPooling(kT, kW, kH, dT, dW, dH, padT, padW, padH)
cudnn.VolumetricAveragePooling(kT, kW, kH, dT, dW, dH, padT, padW, padH)

-- Recurrent Modules

-- All inputs have to be 3D. Accepts input of seqLength x batch x inputDim, or batch x seqLength x inputDim if batchFirst set to true.
cudnn.RNNReLU(inputDim, outputDim, numberOfLayers, [batchFirst = false])
cudnn.RNNTanh(inputDim, outputDim, numberOfLayers, [batchFirst = false])
cudnn.LSTM(inputDim, outputDim, numberOfLayers, [batchFirst = false])
cudnn.GRU(inputDim, outputDim, numberOfLayers, [batchFirst = false])
cudnn.BLSTM(inputDim, outputDim, numberOfLayers, [batchFirst = false])
```

### Modes
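To make the `batchFirst` flag documented above concrete, here is a hedged sketch of the two accepted input layouts (made-up sizes; assumes this branch is installed):

```lua
require 'cudnn'

-- Default layout: time-major, seqLength x miniBatch x inputDim.
local rnn = cudnn.RNNTanh(10, 20, 2)
local timeMajor = torch.CudaTensor(7, 3, 10):uniform()
print(rnn:forward(timeMajor):size())             -- 7 x 3 x 20

-- batchFirst = true: batch-major, miniBatch x seqLength x inputDim.
-- The module transposes internally and returns batch-major output.
local rnnBatchFirst = cudnn.RNNTanh(10, 20, 2, true)
local batchMajor = torch.CudaTensor(3, 7, 10):uniform()
print(rnnBatchFirst:forward(batchMajor):size())  -- 3 x 7 x 20
```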
103 changes: 64 additions & 39 deletions RNN.lua
@@ -2,7 +2,7 @@ local RNN, parent = torch.class('cudnn.RNN', 'nn.Module')
local ffi = require 'ffi'
local errcheck = cudnn.errcheck

function RNN:__init(inputSize, hiddenSize, numLayers)
function RNN:__init(inputSize, hiddenSize, numLayers, batchFirst)
parent.__init(self)

self.datatype = 'CUDNN_DATA_FLOAT'
@@ -12,10 +12,12 @@ function RNN:__init(inputSize, hiddenSize, numLayers)
self.miniBatch = 1
self.numLayers = numLayers
self.bidirectional = 'CUDNN_UNIDIRECTIONAL'
self.numDirections = 1 -- set to 2 for bi-directional.
self.inputMode = 'CUDNN_LINEAR_INPUT'
self.mode = 'CUDNN_RNN_RELU'
self.dropout = 0
self.seed = 0x01234567
self.batchFirst = batchFirst or false -- Set to true for batch x time x inputdim.

self.gradInput = torch.CudaTensor()
self.output = torch.CudaTensor()
@@ -50,7 +52,7 @@ function RNN:reset(stdv)
self.gradWeight:resizeAs(self.weight):zero()
end

local function createDescriptors(count, descs_type, create_func, destroy_func)
function RNN:createDescriptors(count, descs_type, create_func, destroy_func)
local ds = ffi.new(descs_type, count)
for i = 0, count - 1 do
errcheck(create_func, ds + i)
@@ -64,37 +66,37 @@ local function createDescriptors(count, descs_type, create_func, destroy_func)
return ds
end

local function createDropoutDescriptors(count)
return createDescriptors(count,
function RNN:createDropoutDescriptors(count)
return self:createDescriptors(count,
'cudnnDropoutDescriptor_t[?]',
'cudnnCreateDropoutDescriptor',
'cudnnDestroyDropoutDescriptor')
end

local function createFilterDescriptors(count)
return createDescriptors(count,
function RNN:createFilterDescriptors(count)
return self:createDescriptors(count,
'cudnnFilterDescriptor_t[?]',
'cudnnCreateFilterDescriptor',
'cudnnDestroyFilterDescriptor')
end

local function createRNNDescriptors(count)
return createDescriptors(count,
function RNN:createRNNDescriptors(count)
return self:createDescriptors(count,
'cudnnRNNDescriptor_t[?]',
'cudnnCreateRNNDescriptor',
'cudnnDestroyRNNDescriptor')
end

local function createTensorDescriptors(count)
return createDescriptors(count,
function RNN:createTensorDescriptors(count)
return self:createDescriptors(count,
'cudnnTensorDescriptor_t[?]',
'cudnnCreateTensorDescriptor',
'cudnnDestroyTensorDescriptor')
end

function RNN:resetDropoutDescriptor()
if not self.dropoutDesc then
self.dropoutDesc = createDropoutDescriptors(1)
self.dropoutDesc = self:createDropoutDescriptors(1)
end

self.dropoutStatesSize = torch.LongTensor(1)
@@ -113,7 +115,7 @@ end

function RNN:resetRNNDescriptor()
if not self.rnnDesc then
self.rnnDesc = createRNNDescriptors(1)
self.rnnDesc = self:createRNNDescriptors(1)
end

errcheck('cudnnSetRNNDescriptor',
@@ -130,7 +132,7 @@ end

function RNN:resetWeightDescriptor()
if not self.wDesc then
self.wDesc = createFilterDescriptors(1)
self.wDesc = self:createFilterDescriptors(1)
end

local dim = torch.IntTensor({self.weight:size(1), 1, 1})
@@ -144,8 +146,8 @@ function RNN:resetWeightDescriptor()
end

function RNN:resetIODescriptors()
self.xDescs = createTensorDescriptors(self.seqLength)
self.yDescs = createTensorDescriptors(self.seqLength)
self.xDescs = self:createTensorDescriptors(self.seqLength)
self.yDescs = self:createTensorDescriptors(self.seqLength)

for i = 0, self.seqLength - 1 do
local dim = torch.IntTensor({self.inputSize, self.miniBatch, self.seqLength})
@@ -157,7 +159,7 @@ function RNN:resetIODescriptors()
dim:data(),
stride:data())

local dim = torch.IntTensor({self.hiddenSize, self.miniBatch, self.seqLength})
local dim = torch.IntTensor({self.hiddenSize * self.numDirections, self.miniBatch, self.seqLength})
local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]})
errcheck('cudnnSetTensorNdDescriptor',
self.yDescs[i],
@@ -169,8 +171,8 @@ function RNN:resetHiddenDescriptors()
end

function RNN:resetHiddenDescriptors()
self.hxDesc = createTensorDescriptors(1)
self.hyDesc = createTensorDescriptors(1)
self.hxDesc = self:createTensorDescriptors(1)
self.hyDesc = self:createTensorDescriptors(1)

local dim = torch.IntTensor({self.hiddenSize, self.miniBatch, self.numLayers})
local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]})
@@ -190,8 +192,8 @@ function RNN:resetHiddenDescriptors()
end

function RNN:resetCellDescriptors()
self.cxDesc = createTensorDescriptors(1)
self.cyDesc = createTensorDescriptors(1)
self.cxDesc = self:createTensorDescriptors(1)
self.cyDesc = self:createTensorDescriptors(1)

local dim = torch.IntTensor({self.hiddenSize, self.miniBatch, self.numLayers})
local stride = torch.IntTensor({1, dim[1], dim[1] * dim[2]})
@@ -210,7 +212,7 @@ function RNN:resetCellDescriptors()
stride:data())
end

local function makeContiguous(self, input, gradOutput)
function RNN:makeContiguous(input, gradOutput)
if not input:isContiguous() then
self._input = self._input or input.new()
self._input:typeAs(input):resizeAs(input):copy(input)
@@ -224,9 +226,19 @@ local function makeContiguous(self, input, gradOutput)
return input, gradOutput
end

function RNN:resizeOutput(tensor)
return tensor:resize(self.seqLength, self.miniBatch, self.hiddenSize * self.numDirections)
end

function RNN:resizeHidden(tensor)
return tensor:resize(self.numLayers * self.numDirections, self.miniBatch, self.hiddenSize)
end

function RNN:updateOutput(input)
if (self.batchFirst) then
input = input:transpose(1, 2)
end
assert(input:dim() == 3, 'input must have 3 dimensions: seqLength, miniBatch, inputSize')

-- Decide which descriptors/tensors need to be updated.
local resetRNN = not self.dropoutDesc or not self.rnnDesc
local resetIO = not self.xDescs or not self.yDescs
@@ -263,26 +275,26 @@ function RNN:updateOutput(input)
self:resetWeightDescriptor()
end

local x = makeContiguous(self, input)
local y = self.output:resize(self.seqLength, self.miniBatch, self.hiddenSize)
local x = self:makeContiguous(input)
local y = self:resizeOutput(self.output)
local w = self.weight
local hy = self.hiddenOutput:resize(self.numLayers, self.miniBatch, self.hiddenSize):zero()
local cy = self.cellOutput:resize(self.numLayers, self.miniBatch, self.hiddenSize):zero()
local hy = self:resizeHidden(self.hiddenOutput):zero()
local cy = self:resizeHidden(self.cellOutput):zero()

-- Optionally use hiddenInput/cellInput parameters
local hx = self.hiddenInput
local cx = self.cellInput

if hx then
assert(hx:dim() == 3, 'hiddenInput must have 3 dimensions: numLayers, miniBatch, hiddenSize')
assert(hx:size(1) == self.numLayers, 'hiddenInput has incorrect number of layers!')
assert(hx:size(1) == self.numLayers * self.numDirections, 'hiddenInput has incorrect number of layers!')
assert(hx:size(2) == self.miniBatch, 'hiddenInput has incorrect number of minibathes!')
assert(hx:size(3) == self.hiddenSize, 'hiddenIinput has incorrect size!')
assert(hx:isContiguous(), 'hiddenInput must be contiguous!') end

if cx then
assert(cx:dim() == 3, 'cellInput must have 3 dimensions: numLayers, miniBatch, hiddenSize')
assert(cx:size(1) == self.numLayers, 'cellInput has incorrect number of layers!')
assert(cx:size(1) == self.numLayers * self.numDirections, 'cellInput has incorrect number of layers!')
assert(cx:size(2) == self.miniBatch, 'cellInput has incorrect number of minibathes!')
assert(cx:size(3) == self.hiddenSize, 'cellInput has incorrect size!')
assert(cx:isContiguous(), 'cellInput must be contiguous!')
@@ -338,11 +350,18 @@ function RNN:updateOutput(input)
self.cyDesc[0], cy:data(),
self.workspace:data(), self.workspace:size(1) * 4) -- sizeof(float)
end

if (self.batchFirst) then
self.output = self.output:transpose(1, 2)
end
return self.output
end

function RNN:updateGradInput(input, gradOutput)
if (self.batchFirst) then
input = input:transpose(1, 2)
gradOutput = gradOutput:transpose(1, 2)
self.output = self.output:transpose(1, 2)
end
assert(input:dim() == 3, 'input should have 3 dimensions: seqLength, miniBatch, inputSize')
assert(input:size(1) == self.seqLength, 'input has incorrect sequence length!')
assert(input:size(2) == self.miniBatch, 'input has incorrect minibatch size!')
@@ -351,29 +370,29 @@ function RNN:updateGradInput(input, gradOutput)
assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
assert(self.train, 'updateGradInput can only be called when training!')

local x, dy = makeContiguous(self, input, gradOutput)
local x, dy = self:makeContiguous(input, gradOutput)
local y = self.output
local w = self.weight
local dx = self.gradInput:resizeAs(input)
local hx = self.hiddenInput
local cx = self.cellInput
local dhy = self.gradHiddenOutput
local dcy = self.gradCellOutput
local dhx = self.gradHiddenInput:resize(self.numLayers, self.miniBatch, self.hiddenSize):zero()
local dcx = self.gradCellInput:resize(self.numLayers, self.miniBatch, self.hiddenSize):zero()
local dhx = self:resizeHidden(self.gradHiddenInput):zero()
local dcx = self:resizeHidden(self.gradCellInput):zero()


if hx then
assert(hx:dim() == 3, 'hiddenInput must have 3 dimensions: numLayers, miniBatch, hiddenSize')
assert(hx:size(1) == self.numLayers, 'hiddenInput has incorrect number of layers!')
assert(hx:size(1) == self.numLayers * self.numDirections, 'hiddenInput has incorrect number of layers!')
assert(hx:size(2) == self.miniBatch, 'hiddenInput has incorrect minibatch size!')
assert(hx:size(3) == self.hiddenSize, 'hiddenInput has incorrect size!')
assert(hx:isContiguous(), 'hiddenInput must be contiguous!')
end

if cx then
assert(cx:dim() == 3, 'cellInput must have 3 dimensions: numLayers, miniBatch, hiddenSize')
assert(cx:size(1) == self.numLayers, 'cellInput has incorrect number of layers!')
assert(cx:size(1) == self.numLayers * self.numDirections, 'cellInput has incorrect number of layers!')
assert(cx:size(2) == self.miniBatch, 'cellInput has incorrect minibatch size!')
assert(cx:size(3) == self.hiddenSize, 'cellInput has incorrect size!')
assert(cx:isContiguous(), 'cellInput must be contiguous!')
@@ -382,7 +401,7 @@ function RNN:updateGradInput(input, gradOutput)
if dhy then
assert(dhy:dim() == 3, 'gradHiddenOutput must have 3 dimensions: ' ..
'numLayers, miniBatch, hiddenSize')
assert(dhy:size(1) == self.numLayers, 'gradHiddenOutput has incorrect number of layers!')
assert(dhy:size(1) == self.numLayers * self.numDirections, 'gradHiddenOutput has incorrect number of layers!')
assert(dhy:size(2) == self.miniBatch, 'gradHiddenOutput has incorrect minibatch size!')
assert(dhy:size(3) == self.hiddenSize, 'gradHiddenOutput has incorrect size!')
assert(dhy:isContiguous(), 'gradHiddenOutput must be contiguous!')
@@ -391,7 +410,7 @@ function RNN:updateGradInput(input, gradOutput)
if dcy then
assert(dcy:dim() == 3, 'gradCellOutput must have 3 dimensions: ' ..
'numLayers, miniBatch, hiddenSize')
assert(dcy:size(1) == self.numLayers, 'gradCellOutput has incorrect number of layers!')
assert(dcy:size(1) == self.numLayers * self.numDirections, 'gradCellOutput has incorrect number of layers!')
assert(dcy:size(2) == self.miniBatch, 'gradCellOutput has incorrect minibatch size!')
assert(dcy:size(3) == self.hiddenSize, 'gradCellOutput has incorrect size!')
assert(dcy:isContiguous(), 'gradCellOutput must be contiguous!')
@@ -412,11 +431,17 @@ function RNN:updateGradInput(input, gradOutput)
self.cxDesc[0], dcx:data(),
self.workspace:data(), self.workspace:size(1) * 4, -- sizeof(float)
self.reserve:data(), self.reserve:size(1) * 4) -- sizeof(float)

if (self.batchFirst) then
self.gradInput = self.gradInput:transpose(1, 2)
end
return self.gradInput
end

function RNN:accGradParameters(input, gradOutput, scale)
if (self.batchFirst) then
input = input:transpose(1, 2)
gradOutput = gradOutput:transpose(1, 2)
end
scale = scale or 1
if scale == 0 then return end

@@ -428,14 +453,14 @@ function RNN:accGradParameters(input, gradOutput, scale)
assert(gradOutput:isSameSizeAs(self.output), 'gradOutput has incorrect size!')
assert(self.train, 'accGradParameters can only be called when training!')

local x, dy = makeContiguous(self, input, gradOutput)
local x, dy = self:makeContiguous(input, gradOutput)
local hx = self.hiddenInput
local y = self.output
local dw = self.gradWeight

if hx then
assert(hx:dim() == 3, 'hiddenInput must have 3 dimensions: numLayers, miniBatch, hiddenSize')
assert(hx:size(1) == self.numLayers, 'hiddenInput has incorrect number of layers!')
assert(hx:size(1) == self.numLayers * self.numDirections, 'hiddenInput has incorrect number of layers!')
assert(hx:size(2) == self.miniBatch, 'hiddenInput has incorrect minibatch size!')
assert(hx:size(3) == self.hiddenSize, 'hiddenIinput has incorrect size!')
assert(hx:isContiguous(), 'hiddenInput must be contiguous!')
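To see how the reworked base class is exercised end to end, a hedged forward/backward sketch (illustrative sizes; assumes this branch and an available GPU; `training()` is called explicitly because `updateGradInput`/`accGradParameters` assert `self.train`):

```lua
require 'cudnn'

local rnn = cudnn.LSTM(10, 20, 2)
rnn:training()

local input = torch.CudaTensor(5, 4, 10):uniform()
local output = rnn:forward(input)                   -- 5 x 4 x 20

local gradOutput = output:clone():uniform()
rnn:zeroGradParameters()
local gradInput = rnn:backward(input, gradOutput)   -- same size as input: 5 x 4 x 10
rnn:updateParameters(0.01)
```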
6 changes: 6 additions & 0 deletions RNNReLU.lua
@@ -0,0 +1,6 @@
local RNNReLU, parent = torch.class('cudnn.RNNReLU', 'cudnn.RNN')

function RNNReLU:__init(inputSize, hiddenSize, numLayers, batchFirst)
    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst)
    self.mode = 'CUDNN_RNN_RELU'
end
6 changes: 6 additions & 0 deletions RNNTanh.lua
@@ -0,0 +1,6 @@
local RNNTanh, parent = torch.class('cudnn.RNNTanh', 'cudnn.RNN')

function RNNTanh:__init(inputSize, hiddenSize, numLayers, batchFirst)
    parent.__init(self,inputSize, hiddenSize, numLayers, batchFirst)
    self.mode = 'CUDNN_RNN_TANH'
end
5 changes: 5 additions & 0 deletions init.lua
@@ -123,6 +123,11 @@ require('cudnn.VolumetricBatchNormalization')
require('cudnn.SpatialCrossEntropyCriterion')
require('cudnn.TemporalConvolution')
require('cudnn.RNN')
require('cudnn.RNNTanh')
require('cudnn.RNNReLU')
require('cudnn.BLSTM')
require('cudnn.LSTM')
require('cudnn.GRU')
require('cudnn.functional')
require('cudnn.convert')
