Skip to content

Commit 1dda093

Browse files
authored
Merge pull request #230 from pollytur/encoder3d_shifter_fix
Added shifter and updated 3d-encoder nonlinearity
2 parents a57dc39 + 584c646 commit 1dda093

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

neuralpredictors/layers/encoders/encoder3d.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,30 @@ def __init__(self, core, readout, readout_nonlinearity, elu_xshift, elu_yshift):
1212
else:
1313
self.nonlinearity = core.nonlinearities[readout_nonlinearity]()
1414

15-
def forward(self, x, data_key=None):
15+
def forward(self, x, data_key=None, pupil_center=None, trial_idx=None, shift=None, detach_core=False, **kwargs):
1616
out_core = self.core(x)
17+
if detach_core:
18+
out_core = out_core.detach()
19+
20+
if self.shifter:
21+
if pupil_center is None:
22+
raise ValueError("pupil_center is not given")
23+
if shift is None:
24+
time_points = x.shape[1]
25+
pupil_center = pupil_center[:, :, -time_points:]
26+
pupil_center = torch.transpose(pupil_center, 1, 2)
27+
pupil_center = pupil_center.reshape(((-1,) + pupil_center.size()[2:]))
28+
shift = self.shifter[data_key](pupil_center, trial_idx)
29+
1730
out_core = torch.transpose(out_core, 1, 2)
1831
# the expected readout is 2d whereas the core can output 3d matrices
1932
# therefore, the first two dimensions (representing depth and batch size) are flattened and then passed
2033
# through the readout
2134
out_core = out_core.reshape(((-1,) + out_core.size()[2:]))
35+
readout_out = self.readout(out_core, data_key=data_key, shift=shift, **kwargs)
2236

23-
readout_out = self.readout(out_core)
24-
out = self.nonlinearity(readout_out)
37+
if self.nonlinearity_type == "elu":
38+
out = self.nonlinearity_fn(readout_out + self.offset) + 1
39+
else:
40+
out = self.nonlinearity_fn(readout_out)
2541
return out

0 commit comments

Comments
 (0)