Skip to content

Commit

Permalink
Merge pull request #211 from kklurz/main
Browse files Browse the repository at this point in the history
Minor bug fixes for zero_inflation losses
  • Loading branch information
kklurz authored Aug 25, 2023
2 parents 8ee8bff + d241fc1 commit 32b03fe
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 2 deletions.
7 changes: 7 additions & 0 deletions neuralpredictors/layers/encoders/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from torch import nn


Expand Down Expand Up @@ -73,6 +74,12 @@ def forward(
else:
x = self.readout(x, data_key=data_key, shift=shift)

# keep batch dimension if only one image was passed
params = []
for param in x:
params.append(param[None, ...] if len(param.shape) == 1 else param)
x = torch.stack(params)

if self.modulator:
x = self.modulator[data_key](x, behavior=behavior)

Expand Down
1 change: 1 addition & 0 deletions neuralpredictors/layers/encoders/firing_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def forward(
shift = self.shifter[data_key](pupil_center, trial_idx)

x = self.readout(x, data_key=data_key, shift=shift, **kwargs)
x = x[None, ...] if len(x.shape) == 1 else x # keep dimensions if only one image was passed

if self.modulator:
if behavior is None:
Expand Down
6 changes: 6 additions & 0 deletions neuralpredictors/layers/encoders/zero_inflation_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ def forward_base(
else:
x = self.readout(x, data_key=data_key, shift=shift)

# keep batch dimension if only one image was passed
params = []
for param in x:
params.append(param[None, ...] if len(param.shape) == 1 else param)
x = torch.stack(params)

if self.modulator:
x = self.modulator[data_key](x, behavior=behavior)

Expand Down
4 changes: 2 additions & 2 deletions neuralpredictors/measures/zero_inflated_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def forward(self, target, output, **kwargs):
if loc.requires_grad:
self.multi_clamp(loc, [0.0] * neurons_n, target.max(dim=0)[0])

zero_mask = target < loc
nonzero_mask = target >= loc
zero_mask = target <= loc
nonzero_mask = target > loc

# spike loss
spike_logl = torch.log(1 - q) - torch.log(loc)
Expand Down

0 comments on commit 32b03fe

Please sign in to comment.