-- DataList.lua -- forked from clementfarabet/lua---nnx
--------------------------------------------------------------------------------
-- DataList: a container for plain DataSets.
-- Each sub dataset represents elements from only one class.
--
-- Authors: Corda, Farabet
--------------------------------------------------------------------------------
-- Declare the torch class 'nn.DataList' with 'nn.DataSet' as its parent class.
-- `DataList` receives the methods defined below; `parent` is used by the
-- constructor to run the parent initializer.
local DataList, parent = torch.class('nn.DataList','nn.DataSet')
-- Constructor: runs the parent initializer, then resets every field owned by
-- the list so that it starts out empty (no datasets, no classes, no samples).
function DataList:__init()
   parent.__init(self)

   -- target-encoding switches, consulted by __index__ when it builds targets
   self.spatialTarget = false
   self.targetIsProbability = false

   -- containers: appended sub datasets, and the class-name bookkeeping table
   self.datasets = {}
   self.ClassName = {}

   -- counters, all refreshed by appendDataSet
   self.nbClass = 0
   self.ClassMax = 0
   self.nbSamples = 0
end
-- Builds the printable summary of the list.
--
-- Returns a string made of a 'DataList:' header line followed by the number
-- of samples and the number of classes.
--
-- The text is produced as a single expression so that no intermediate
-- variable is needed: assigning to an undeclared `str` would write to the
-- global environment each time a DataList is printed.
function DataList:__tostring__()
   return 'DataList:\n'
      .. ' + nb samples : ' .. self.nbSamples .. '\n'
      .. ' + nb classes : ' .. self.nbClass
end
-- Numeric indexing of the list: list[key] -> sample.
--
-- Sample indices are interleaved across classes: key 1 maps to class 1,
-- key 2 to class 2, ... and the class wraps around after nbClass. Inside a
-- class the element index wraps modulo the size of that class dataset, so
-- a class with fewer elements than the largest one repeats its elements.
--
-- Side effect: slot 2 of the selected element (its target) is overwritten in
-- the underlying dataset with a new tensor of nbClass entries (shape
-- nbClass x 1 x 1 when spatialTarget is set). The entry of the selected class
-- is 1; all others are 0 when targetIsProbability is set, -1 otherwise.
--
-- key: the index being looked up.
-- Returns `sample, true` when key is an integer in [1, nbSamples] and at
-- least one class has been appended; otherwise returns rawget(self, key).
function DataList:__index__(key)
   -- Only whole numbers in [1, nbSamples] designate a sample. Without the
   -- lower bound, 0 and negative keys would wrap around (via the modulo
   -- arithmetic below) onto an unrelated sample, and fractional keys would
   -- compute a fractional class index and fail on a nil dataset.
   local isSampleIndex = type(key) == 'number'
      and self.nbClass > 0
      and key >= 1 and key <= self.nbSamples
      and key == math.floor(key)
   if not isSampleIndex then
      -- if key is not a number this should return nil
      return rawget(self, key)
   end

   -- split the flat index into (class, element within class)
   local class = ((key-1) % self.nbClass) + 1
   local dataset = self.datasets[class]
   local elmt = (math.floor((key-1)/self.nbClass) % dataset:size()) + 1

   -- create target vector on the fly
   -- NOTE(review): dataset[elmt] is indexed again for each step below; this
   -- presumes the sub dataset hands back the same table on every access.
   local background = self.targetIsProbability and 0 or -1
   if self.spatialTarget then
      dataset[elmt][2] = torch.Tensor(self.nbClass,1,1):fill(background)
      dataset[elmt][2][class][1][1] = 1
   else
      dataset[elmt][2] = torch.Tensor(self.nbClass):fill(background)
      dataset[elmt][2][class] = 1
   end

   -- apply hook on sample (the hook's return value replaces the sample)
   local sample = dataset[elmt]
   if self.hookOnSample then
      sample = self.hookOnSample(self, sample)
   end

   -- auto conversion to CUDA: when the default tensor type is CudaTensor,
   -- copy the input into a newly allocated tensor of that default type
   if torch.getdefaulttensortype() == 'torch.CudaTensor' then
      sample[1] = torch.Tensor(sample[1]:size()):copy(sample[1])
   end

   return sample, true
end
-- Adds a sub dataset to the list and refreshes the aggregate counters.
--
-- dataSet:   object answering :size(); stored at the end of self.datasets.
-- className: label of the class held by dataSet. self.ClassName keeps two
--            kinds of entries: label -> accumulated sample count, and
--            class index -> label.
--
-- After the call, ClassMax is the (floored) largest accumulated class count
-- seen so far and nbSamples is ClassMax * nbClass.
--
-- NOTE(review): every call appends to self.datasets, whereas nbClass only
-- grows for labels not seen before, and __index__ picks datasets by class
-- index -- confirm that re-appending an existing label behaves as intended.
function DataList:appendDataSet(dataSet,className)
   self.datasets[#self.datasets + 1] = dataSet

   -- you can append the same class several times with this mechanism
   local registry = self.ClassName
   local seen = registry[className]
   if seen then
      -- known label: only its accumulated sample count grows
      registry[className] = seen + dataSet:size()
   else
      -- new label: record its count, then register index -> label
      registry[className] = dataSet:size()
      self.nbClass = self.nbClass + 1
      table.insert(registry, self.nbClass, className)
   end

   -- the most populated class drives the overall number of samples
   local largest = math.max(self.ClassMax, registry[className])
   self.ClassMax = math.floor(largest)
   self.nbSamples = self.ClassMax * self.nbClass
end