@@ -540,20 +540,32 @@ def response_time(self) -> npt.NDArray[np.float64]:
540
540
@npc_io .cached_property
541
541
def reward_time (self ) -> npt .NDArray [np .floating ]:
542
542
"""delivery time of water reward, for contingent and non-contingent rewards"""
543
- all_reward_times = npc_stim .safe_index (self ._flip_times , self ._sam .rewardFrames )
544
- all_reward_times = all_reward_times [all_reward_times <= self .stop_time [- 1 ]]
545
- all_reward_trials = (
543
+ if (
544
+ (solenoid_times := getattr (self , "_reward_times_with_duration" , None )) is not None
545
+ and len (solenoid_times ) >= len (np .where (self .is_rewarded )[0 ])
546
+ ):
547
+ logger .info (f'Using solenoid opening time on sync for `reward_time`' )
548
+ all_reward_times = solenoid_times
549
+ else :
550
+ logger .info (f'Using flip time of each TaskControl frame for `reward_time`' )
551
+ all_reward_times = npc_stim .safe_index (self ._flip_times , self ._sam .rewardFrames )
552
+ all_reward_times = all_reward_times [
553
+ (all_reward_times >= self .start_time [0 ]) &
554
+ (all_reward_times <= self .stop_time [- 1 ])
555
+ ]
556
+ trial_idx_from_rewards = (
546
557
np .searchsorted (
547
558
self .start_time ,
548
559
all_reward_times ,
549
560
side = "right" ,
550
561
)
551
562
- 1
552
563
)
564
+ assert len (is_rewarded := np .where (self .is_rewarded )[0 ]) <= len (trial_idx_from_rewards )
553
565
reward_time = np .full (self ._len , np .nan )
554
- if np .all (np . where ( self . is_rewarded )[ 0 ] == all_reward_trials ):
555
- # expected single reward per trial
556
- reward_time [all_reward_trials ] = all_reward_times
566
+ if np .all (is_rewarded == trial_idx_from_rewards ):
567
+ # expected case: single reward per trial
568
+ reward_time [trial_idx_from_rewards ] = all_reward_times
557
569
else :
558
570
# mismatch between reward times and trials that are marked as having rewards
559
571
for trial_idx in np .where (self .is_rewarded )[0 ]:
0 commit comments