Skip to content

Commit

Permalink
ClassNLLCriterion now can take target as a number or Tensor of dim1 w…
Browse files Browse the repository at this point in the history
…ith 1 element in non-minibatch mode
  • Loading branch information
soumith committed Apr 21, 2015
1 parent 17138b2 commit d3aa258
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions ClassNLLCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ function ClassNLLCriterion:updateOutput(input, target)
if input:type() == 'torch.CudaTensor' and not self.weights then
if input:dim() == 1 then
self._target = self._target or input.new(1)
self._target[1] = target
if type(target) == 'number' then
self._target[1] = target
else
self._target:copy(target)
end
input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
else
input.nn.ClassNLLCriterion_updateOutput(self, input, target)
Expand All @@ -24,6 +28,7 @@ function ClassNLLCriterion:updateOutput(input, target)
end

if input:dim() == 1 then
if torch.isTensor(target) then target = target[1] end
self.output = -input[target]
if self.weights then
self.output = self.output*self.weights[target]
Expand Down Expand Up @@ -54,7 +59,11 @@ function ClassNLLCriterion:updateGradInput(input, target)
if input:type() == 'torch.CudaTensor' and not self.weights then
if input:dim() == 1 then
self._target = self._target or input.new(1)
self._target[1] = target
if type(target) == 'number' then
self._target[1] = target
else
self._target:copy(target)
end
input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
else
input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
Expand All @@ -63,6 +72,7 @@ function ClassNLLCriterion:updateGradInput(input, target)
end

if input:dim() == 1 then
if torch.isTensor(target) then target = target[1] end
self.gradInput[target] = -1
if self.weights then
self.gradInput[target] = self.gradInput[target]*self.weights[target]
Expand Down

0 comments on commit d3aa258

Please sign in to comment.