Skip to content

Commit f70138b

Browse files
committed
Use solenoid time for trials.reward_times if available
- fixes #147
1 parent 6efde07 commit f70138b

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

src/npc_sessions/sessions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,8 @@ def set_lazy_eval(
11401140
kwargs |= {"sync": self.sync_data}
11411141
if self.is_ephys and self.is_sync:
11421142
kwargs |= {"ephys_recording_dirs": self.ephys_recording_dirs}
1143-
1143+
if (reward_times := getattr(self, "_reward_times_with_duration", None)) is not None:
1144+
kwargs |= {"reward_times_with_duration": reward_times.timestamps}
11441145
# set items in LazyDict for postponed evaluation
11451146
if "RFMapping" in stim_filename:
11461147
# create two separate trials tables

src/npc_sessions/trials/TaskControl/DynamicRouting1.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -540,20 +540,32 @@ def response_time(self) -> npt.NDArray[np.float64]:
540540
@npc_io.cached_property
541541
def reward_time(self) -> npt.NDArray[np.floating]:
542542
"""delivery time of water reward, for contingent and non-contingent rewards"""
543-
all_reward_times = npc_stim.safe_index(self._flip_times, self._sam.rewardFrames)
544-
all_reward_times = all_reward_times[all_reward_times <= self.stop_time[-1]]
545-
all_reward_trials = (
543+
if (
544+
(solenoid_times := getattr(self, "_reward_times_with_duration", None)) is not None
545+
and len(solenoid_times) >= len(np.where(self.is_rewarded)[0])
546+
):
547+
logger.info(f'Using solenoid opening time on sync for `reward_time`')
548+
all_reward_times = solenoid_times
549+
else:
550+
logger.info(f'Using flip time of each TaskControl frame for `reward_time`')
551+
all_reward_times = npc_stim.safe_index(self._flip_times, self._sam.rewardFrames)
552+
all_reward_times = all_reward_times[
553+
(all_reward_times >= self.start_time[0]) &
554+
(all_reward_times <= self.stop_time[-1])
555+
]
556+
trial_idx_from_rewards = (
546557
np.searchsorted(
547558
self.start_time,
548559
all_reward_times,
549560
side="right",
550561
)
551562
- 1
552563
)
564+
assert len(is_rewarded := np.where(self.is_rewarded)[0]) <= len(trial_idx_from_rewards)
553565
reward_time = np.full(self._len, np.nan)
554-
if np.all(np.where(self.is_rewarded)[0] == all_reward_trials):
555-
# expected single reward per trial
556-
reward_time[all_reward_trials] = all_reward_times
566+
if np.all(is_rewarded == trial_idx_from_rewards):
567+
# expected case: single reward per trial
568+
reward_time[trial_idx_from_rewards] = all_reward_times
557569
else:
558570
# mismatch between reward times and trials that are marked as having rewards
559571
for trial_idx in np.where(self.is_rewarded)[0]:

0 commit comments

Comments
 (0)