Trying to stop Lua running out of memory with large tables
Mike Lewis committed May 8, 2014
1 parent 4380273 commit 0dd6b47
Showing 3 changed files with 18 additions and 18 deletions.
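
The memory pressure this commit targets comes from building the dataset as a plain Lua table holding one small tensor per example, so every example carries its own table entry and tensor allocation. Preallocating a single large torch.Tensor up front and filling its rows avoids that per-example overhead. A minimal sketch of the pattern, assuming the torch package is loaded; size and width are illustrative names, not identifiers from this repository:

-- Before: one table entry and one small tensor per example (many allocations).
local exampleTable = {}
for i = 1, size do
  exampleTable[i] = torch.Tensor(width)
end

-- After: one contiguous allocation up front; row i is a view into it.
local exampleTensor = torch.Tensor(size, width)
for i = 1, size do
  exampleTensor[i]:fill(0)   -- shares the preallocated storage, no new allocation per example
end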
22 changes: 8 additions & 14 deletions training/SGD.lua
@@ -10,18 +10,17 @@ function SGD:__init(module, criterion, folder)
self.folder = folder
self.log=io.open(folder .. '/log',"w")
self.log:setvbuf("no")


end


function SGD:eval(validation)
right = 0.0
for t = 1,validation:size() do
--local example = validation[t]
--local input = example[1]
--local target = example[2]
local input = validation[1][t]
local target = validation[2][t]

input = nn.SplitTable(1):forward(nn.Reshape(3,window):forward(input))
output = self.module:forward(input)

outputLabel = 1;
@@ -71,26 +70,21 @@ function SGD:train(dataset, validation)
end
end

self.log:write("shuffled indices" .. '\n')


bestScore = -1;
it = 0;
local itsSinceImprovement = 0;
while true do
self.log:write(it .. " ")
it = it + 1
self.log:write("Iteration " .. it .. " ")


local currentError = 0
for t = 1,dataset:size() do
-- local example = dataset[shuffledIndices[t]]
-- local input = example[1]
-- local target = example[2]

local input = dataset[1][shuffledIndices[t]]
local target = dataset[2][shuffledIndices[t]]

input = nn.SplitTable(1):forward(nn.Reshape(3,window):forward(input))
currentError = currentError + criterion:forward(module:forward(input), target)

module:updateGradInput(input, criterion:updateGradInput(module.output, target))
@@ -110,7 +104,7 @@ function SGD:train(dataset, validation)

acc=self:eval(validation)

self.log:write("Validation Accuracy: " .. acc .. '\n')
self.log:write("Development Set Accuracy: " .. acc .. '\n')


if acc > bestScore then
@@ -128,8 +122,8 @@ function SGD:train(dataset, validation)
break
end

if itsSinceImprovement == 3 then
self.log:write("# SGD: no improvement for 3 iterations" .. '\n')
if itsSinceImprovement == 1 then
self.log:write("# SGD: no improvement for 1 iteration. Stopping training." .. '\n')
break
end
end
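
In both SGD:eval and SGD:train above, example t is now read through two parallel lookups, validation[1][t] / validation[2][t] and dataset[1][shuffledIndices[t]] / dataset[2][shuffledIndices[t]], instead of a per-example {input, target} pair (the commented-out lines). A sketch of one way to build an object with the dataset:size() / dataset[1][t] / dataset[2][t] interface those loops rely on; makeDataset is an illustrative helper, not a function from this repository, and the real construction happens in Train:loadDataset:

-- Wrap an inputs tensor and a targets storage behind the interface SGD expects.
local function makeDataset(inputs, targets, n)
  local dataset = {inputs, targets}   -- dataset[1] = all inputs, dataset[2] = all targets
  function dataset:size() return n end
  return dataset
end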
2 changes: 2 additions & 0 deletions training/features.lua
@@ -23,6 +23,8 @@ end
function Features:fileToTable(path, minIndex, normalize)
local file = io.open(path)

if not file then error("Unable to load file: " .. path) end

local result = {}
local reverse = {}
local count = minIndex
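
io.open returns nil plus an error message when the file cannot be opened, so guarding the handle before calling file:lines() turns an opaque "attempt to index a nil value" crash into a readable error, which is what the added check above does. A standalone sketch of the same guard; the path is illustrative, not from this repository:

local path = "data/features.txt"
local file, err = io.open(path)
if not file then
  error("Unable to load file: " .. path .. " (" .. (err or "unknown error") .. ")")
end
for line in file:lines() do
  -- process one line of the feature file
end
file:close()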
12 changes: 8 additions & 4 deletions training/train.lua
@@ -26,8 +26,8 @@ function Train:loadDataset(path, features, windowBackward, windowForward, includ
local index = 0
local window = windowBackward + windowForward + 1

local inputData = {}
local targetData = torch.IntStorage(size)--{}
local inputData = torch.Tensor(size, window * 3)
local targetData = torch.IntStorage(size)

local lineNum = 1;
for line in file:lines() do
@@ -56,14 +56,18 @@ function Train:loadDataset(path, features, windowBackward, windowForward, includ
for i = 1,numWords do

local input = features:getFeatures(words, i, windowBackward, windowForward);
local newInput = nn.SplitTable(1):forward(nn.Reshape(3,window):forward(input))
--local newInput = nn.SplitTable(1):forward(nn.Reshape(3,window):forward(input))
local label = features:getCategoryIndex(cats[i])

if label > 0 or (includeRareCategories) then
--Label 0 is rare categories. These are used for evaluation, but not for training.
index = index + 1

inputData[index] = newInput
for j=1,input:size()[1] do
value = input[j]
inputData[index][j] = value
end

targetData[index] = label
end
end
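
Because loadDataset now stores the raw feature vector (the Reshape/SplitTable step is commented out here, and the same transform shows up in SGD.lua instead), each example is written into one row of the preallocated inputData tensor. The element-wise j-loop above does that; Torch tensors can also copy a whole row in one call. A hedged alternative, assuming input is a 1-D tensor with window * 3 elements, as features:getFeatures appears to return:

-- Same effect as the j-loop: copy the feature vector into row `index`.
inputData[index]:copy(input)
targetData[index] = label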
