diff --git a/dreem/datasets/microscopy_dataset.py b/dreem/datasets/microscopy_dataset.py index bc28f2c5..710681f6 100644 --- a/dreem/datasets/microscopy_dataset.py +++ b/dreem/datasets/microscopy_dataset.py @@ -62,7 +62,7 @@ def __init__( seed, ) - self.videos = videos + self.vid_files = videos self.tracks = tracks self.chunk = chunk self.clip_length = clip_length @@ -92,13 +92,19 @@ def __init__( parser(self.tracks[video_idx]) for video_idx in range(len(self.tracks)) ] + self.videos = [] + for vid_file in self.vid_files: + if not isinstance(vid_file, list): + self.videos.append(data_utils.LazyTiffStack(vid_file)) + else: + self.videos.append([Image.open(frame_file) for frame_file in vid_file]) self.frame_idx = [ ( torch.arange(Image.open(video).n_frames) if isinstance(video, str) else torch.arange(len(video)) ) - for video in self.videos + for video in self.vid_files ] # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be @@ -128,9 +134,6 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram video = self.videos[label_idx] - if not isinstance(video, list): - video = data_utils.LazyTiffStack(self.videos[label_idx]) - frames = [] for frame_id in frame_idx: instances, gt_track_ids, centroids = [], [], [] @@ -138,7 +141,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram img = ( video.get_section(frame_id) if not isinstance(video, list) - else np.array(Image.open(video[frame_id])) + else np.array(video[frame_id]) ) lf = labels[labels["FRAME"].astype(int) == frame_id.item()] @@ -202,3 +205,12 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram ) return frames + + def __del__(self): + """Handle file closing before deletion.""" + for vid_reader in self.videos: + if not isinstance(vid_reader, list): + vid_reader.close() + else: + for frame_reader in vid_reader: + frame_reader.close() diff --git a/dreem/datasets/sleap_dataset.py b/dreem/datasets/sleap_dataset.py index f7297a54..67b07691 100644 --- a/dreem/datasets/sleap_dataset.py +++ b/dreem/datasets/sleap_dataset.py @@ -106,7 +106,7 @@ def __init__( # if self.seed is not None: # np.random.seed(self.seed) self.labels = [sio.load_slp(slp_file) for slp_file in self.slp_files] - + self.videos = [imageio.get_reader(vid_file) for vid_file in self.vid_files] # do we need this? would need to update with sleap-io # for label in self.labels: @@ -140,7 +140,7 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram video_name = self.video_files[label_idx] - vid_reader = imageio.get_reader(video_name, "ffmpeg") + vid_reader = self.videos[label_idx] img = vid_reader.get_data(0) @@ -370,3 +370,8 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram frames.append(frame) return frames + + def __del__(self): + """Handle file closing before garbage collection.""" + for reader in self.videos: + reader.close()