@@ -12,14 +12,30 @@ def __init__(self, core, readout, readout_nonlinearity, elu_xshift, elu_yshift):
12
12
else :
13
13
self .nonlinearity = core .nonlinearities [readout_nonlinearity ]()
14
14
15
- def forward (self , x , data_key = None ):
15
+ def forward (self , x , data_key = None , pupil_center = None , trial_idx = None , shift = None , detach_core = False , ** kwargs ):
16
16
out_core = self .core (x )
17
+ if detach_core :
18
+ out_core = out_core .detach ()
19
+
20
+ if self .shifter :
21
+ if pupil_center is None :
22
+ raise ValueError ("pupil_center is not given" )
23
+ if shift is None :
24
+ time_points = x .shape [1 ]
25
+ pupil_center = pupil_center [:, :, - time_points :]
26
+ pupil_center = torch .transpose (pupil_center , 1 , 2 )
27
+ pupil_center = pupil_center .reshape (((- 1 ,) + pupil_center .size ()[2 :]))
28
+ shift = self .shifter [data_key ](pupil_center , trial_idx )
29
+
17
30
out_core = torch .transpose (out_core , 1 , 2 )
18
31
# the expected readout is 2d whereas the core can output 3d matrices
19
32
# therefore, the first two dimensions (representing depth and batch size) are flattened and then passed
20
33
# through the readout
21
34
out_core = out_core .reshape (((- 1 ,) + out_core .size ()[2 :]))
35
+ readout_out = self .readout (out_core , data_key = data_key , shift = shift , ** kwargs )
22
36
23
- readout_out = self .readout (out_core )
24
- out = self .nonlinearity (readout_out )
37
+ if self .nonlinearity_type == "elu" :
38
+ out = self .nonlinearity_fn (readout_out + self .offset ) + 1
39
+ else :
40
+ out = self .nonlinearity_fn (readout_out )
25
41
return out
0 commit comments